import torch
import sys
sys.path.append("/home/cameronsmith/repos/controll3r/")
from vis import *


import pybullet as p
import pybullet_data
# Headless (DIRECT, no GUI) PyBullet session with gravity set and the Franka
# Panda URDF loaded from pybullet_data's search path.
physicsClient = p.connect(p.DIRECT)
p.setGravity(0, 0, -9.8)
p.setAdditionalSearchPath(pybullet_data.getDataPath())
rob_id = p.loadURDF("franka_panda/panda.urdf")

# The saved file holds two dicts (model input, model output); every tensor in
# both is detached from autograd and moved to CPU.
model_input_, model_output_ = (
    {name: tensor.detach().cpu() for name, tensor in sample.items()}
    for sample in torch.load("scripts/sample_inp_outp.pt")
)

# Multiply both point maps of both dicts by the input's "fg" tensor
# (presumably a foreground mask that zeroes background points -- confirm).
for sample in (model_input_, model_output_):
    for key in ("points_cam", "points_link"):
        sample[key] = sample[key] * model_input_["fg"]
from tqdm import tqdm
def hom(x):
    """Append a column of ones to the 2-D array *x* ((N, C) -> (N, C + 1))."""
    return np.concatenate([x, np.ones((x.shape[0], 1))], axis=1)


# For each number of links used in the extrinsics solve, estimate the camera
# view matrix per sample and report the mean L1 error against ground truth.
for n_link_used in range(1, 12):
    errors = []  # L1 view-matrix error, collected for the network-output source only
    for b_i in range(15):
        for weighted_procrustes in [0]:
            for use_gt_pointcam in [0]:
                plot = 0

                # TODO: stop using the GT segmentation here, even if just for visualization.
                # Keep only sample b_i (with a batch dim of 1); 0-d tensors are kept whole.
                model_output, model_input = [
                    {k: (v[b_i:b_i + 1] if len(v.shape) else v) for k, v in x.items()}
                    for x in (model_output_, model_input_)
                ]

                wandb_out = {}

                #wandb_out["ref/fg_seg"]= make_grid(model_input["fg"],normalize=False)
                #wandb_out["est/fg_seg"]= make_grid(model_output["fg"],normalize=False)

                # Threshold the predicted segmentation scores at 0.8.
                model_output["segs_thresh"] = est_seg_img = model_output["segs"] > .8
                if plot:
                    n_segs = model_input["segs"].size(1)
                    # plt.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains.
                    seg_cmap = plt.get_cmap('tab20', n_segs)
                    seg_colors = [torch.tensor(seg_cmap(i))[:3] for i in range(n_segs)]
                    seg_img = ch_sec(torch.ones_like(model_input["points_cam"]))[0]
                    for seg, seg_color in zip(model_input["segs"].unbind(1), seg_colors):
                        seg_img[ch_sec(seg).squeeze(1)] = seg_color
                    wandb_out["ref/seg"] = make_grid(ch_fst(seg_img, model_input["points_cam"].size(-2)), normalize=False)
                    seg_img = ch_sec(torch.ones_like(model_input["points_cam"]))[0]
                    for seg, seg_color in zip(est_seg_img.unbind(1), seg_colors):
                        seg_img[ch_sec(seg).squeeze(1)] = seg_color
                    wandb_out["est/seg"] = make_grid(ch_fst(seg_img, model_input["points_cam"].size(-2)), normalize=False)
                    #wandb_out["est/seg_raw"]= make_grid(model_output["segs"].flatten(0,1).unsqueeze(1),normalize=False)
                    #wandb_out["ref/seg_raw"]= make_grid(model_input["segs"].flatten(0,1).unsqueeze(1),normalize=False)

                    # Plot pointmaps
                    wandb_out["ref/points_cam"] = make_grid(model_input["points_cam"], normalize=True)
                    wandb_out["est/points_cam"] = make_grid(model_output["points_cam"], normalize=True)
                    wandb_out["ref/points_link"] = make_grid(model_input["points_link"], normalize=True)
                    wandb_out["est/points_link"] = make_grid(model_output["points_link"], normalize=True)

                # Solve for extrinsics from the pointmaps: src_i == 0 uses the
                # ground-truth input, src_i == 1 uses the network output.
                for src_i, src in enumerate([model_input, model_output]):
                    # Per-link masks flattened over dims 1-2; channel 0 is the added
                    # background segment and is dropped.
                    segs = (src["segs_thresh"] if src_i else src["segs"])[0].detach().cpu().flatten(1, 2).numpy()[1:]

                    points_cam = ch_sec((model_input if use_gt_pointcam else src)["points_cam"][0].detach().cpu().numpy())
                    points_link = ch_sec(src["points_link"][0].detach().cpu().numpy())

                    # Pair camera-frame points (x1) with the same points mapped through the
                    # known per-link transforms (x2).
                    ptsx, ptsy, seg_errs = [], [], []
                    for seg, curr_link_transform in zip(segs[:n_link_used], model_input["curr_link_transforms"][0]):
                        x1 = points_cam[seg][..., :3]
                        x2 = np.einsum('ij,nj->ni', curr_link_transform.detach().cpu().numpy(), hom(points_link[seg]))[:, :3]
                        # Per-link fit; its first output is kept for the optional weighting below.
                        seg_fit = geometry.procrustes(torch.from_numpy(x1)[None].float(), torch.from_numpy(x2)[None].float())
                        seg_errs.append(seg_fit[0])
                        ptsx.append(x1)
                        ptsy.append(x2)

                    if not ptsx:
                        est_view_matrix_ = np.eye(4)
                        print("no points, using identity matrix")
                    elif weighted_procrustes:
                        seg_errs_ = torch.concat([x[0].T for x in seg_errs]).sum(dim=1) + 1e-20
                        # seg_errs is a plain list; the emptiness check must use the tensor.
                        if seg_errs_.shape[0] == 0:
                            est_view_matrix_ = np.eye(4)
                            print("no seg err pts, using identity matrix")
                        else:
                            # Weight each point by 1 - err/max_err (the worst point gets weight 0).
                            weights = 1 - seg_errs_ / seg_errs_.max()
                            est_view_tc = geometry.procrustes(
                                torch.from_numpy(np.concatenate(ptsx)[:, :3])[None].float(),
                                torch.from_numpy(np.concatenate(ptsy)[:, :3])[None].float(),
                                weights=weights[None, :, None],
                            )[1].numpy()[0]
                            est_view_matrix_ = np.linalg.inv(est_view_tc @ np.linalg.inv(data.Tc))
                            if src_i:
                                errors.append(np.sum(np.abs(est_view_matrix_ - model_input["view_matrix_"][0].cpu().numpy())))
                    else:
                        est_view_tc = geometry.procrustes_umeyama(np.concatenate(ptsx)[:, :3], np.concatenate(ptsy)[:, :3])
                        est_view_matrix_ = np.linalg.inv(est_view_tc @ np.linalg.inv(data.Tc))
                        if src_i:
                            errors.append(np.sum(np.abs(est_view_matrix_ - model_input["view_matrix_"][0].cpu().numpy())))

                    if not plot:
                        continue

                    est_cam = np.linalg.inv(est_view_matrix_) @ data.Tc
                    gt_cam = np.linalg.inv(model_input["view_matrix_"][0].cpu().numpy()) @ data.Tc
                    tag = ["ref", "est"][src_i]

                    # Rerender mesh image with estimated extrinsics
                    wandb_out["%s/cam_rerender" % tag] = data.format_pybullet_render(model_input["joint_states"][0].cpu().numpy(), est_view_matrix_, 0)["img"][0]

                    # Plot the per-link camera-frame point clouds and both camera frusta.
                    fig = plt.figure()
                    ax = fig.add_subplot(111, projection='3d')
                    ax.view_init(elev=160, azim=-90)
                    s = 1
                    for joint_i in range(len(segs)):
                        ax.scatter(*points_cam[segs[joint_i].reshape(-1)][::s].T, c=seg_colors[joint_i], marker='o', alpha=.99)
                    geometry.plot_frustum(est_cam, scale=.1, ax=ax, color="red", label="Pred Cam2Base")
                    geometry.plot_frustum(gt_cam, scale=.1, ax=ax, color="blue", label="GT Cam2Base")
                    plt.legend()
                    plt.savefig("output/img/tmp.png")
                    plt.close()
                    img_arr = plt.imread("output/img/tmp.png")
                    wandb_out["%s/cam_plot" % tag] = F.interpolate(torch.from_numpy(img_arr[..., :3]).permute(2, 0, 1)[None], (256, 256))[0]

                if plot:
                    wandb_out["est/img"] = wandb_out["ref/cam_rerender"] * .5 + wandb_out["est/cam_rerender"] * .5
                    for k, v in model_input.items():
                        if "img" in k:
                            wandb_out["ref/" + k] = make_grid(v.flatten(0, 1))
                    all_img = torch.cat((torch.cat([v for k, v in wandb_out.items() if "ref" in k], 2), torch.cat([v for k, v in wandb_out.items() if "est" in k], 2)), 1)
                    plt.imsave("/home/cameronsmith/tmp/tmpimg/tmp_%d_%d_%s_%s.png" % (b_i, n_link_used, ["", "weighted"][weighted_procrustes], ["", "gt_cam"][use_gt_pointcam]), all_img.permute(1, 2, 0).cpu().numpy())
    # Apply the format string (previously passed as separate print args) and
    # avoid dividing by zero when no estimate was produced.
    if errors:
        print("Error %d : %f" % (n_link_used, sum(errors) / len(errors)))
    else:
        print("Error %d : no estimates" % n_link_used)

