import os,io,shutil
import geometry
import wandb
from matplotlib import cm
import cv2
from tqdm import tqdm
import torchvision
import time
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import kornia
import numpy as np
import torch
import flow_vis
import flow_vis_torch
import matplotlib.pyplot as plt
from einops import rearrange, repeat
import piqa
import imageio
from PIL import Image
#import splines.quaternion
#from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
from scipy import spatial
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import geometry
import pybullet as p
import data

def ch_fst(src, x=None):
    """Reshape a flattened channels-last map ``(..., x*y, c)`` into ``(..., c, x, y)``.

    If ``x`` (the first spatial extent) is not given, the grid is assumed to be
    square and ``x`` is taken as ``int(n ** 0.5)``, where ``n = src.size(-2)`` is
    the flattened pixel count. ``src`` must then provide ``.size(-2)`` (as torch
    tensors do).
    """
    if x is None:
        x = int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def ch_sec(x):
    """Flatten a channels-first map ``(..., c, x, y)`` into ``(..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")

def wandb_summary(loss, model_output, model_input, ground_truth, resolution,
                  prefix="", suffix="", step=0, losses_agg=None, dont_log=False):
    """Build, and optionally log to wandb, image panels for the first batch element.

    Panels produced (in this order):
      * ``ref/<k>`` for every ``model_input`` key containing ``"img"`` (its first
        two dims are merged into the grid batch),
      * ``ref/seg``: a colour-coded overlay of ``model_input["segs"]`` on white,
      * ``ref/seg_raw``: the raw per-segment masks,
      * ``ref/points_cam``, ``ref/points_link`` and ``est/points_link``: pointmaps
        rendered with ``make_grid(normalize=True)``.

    Args:
        model_output: dict; must contain ``"points_link"``.
        model_input: dict; must contain ``"segs"``, ``"points_cam"`` and
            ``"points_link"``. Assumed shapes (TODO confirm against the data
            loader): ``segs`` is (B, S, H, W) with boolean masks, and
            ``points_cam`` is (B, 3, H, W).
        prefix: string prepended to every key of the returned/logged dict.
        dont_log: if True, build the dict but skip ``wandb.log``.
        loss, ground_truth, resolution, suffix, step, losses_agg: accepted for
            call-site compatibility; not read by this function.

    Returns:
        dict mapping ``prefix + panel_key`` to ``wandb.Image``.
    """
    # Keep only the first batch element. Values without array dimensions
    # (0-d tensors, scalars, strings, ...) pass through unchanged instead of raising.
    model_output, model_input = [
        {k: (v[:1] if getattr(v, "ndim", 0) else v) for k, v in d.items()}
        for d in (model_output, model_input)
    ]

    wandb_out = {}

    # Reference images.
    for k, v in model_input.items():
        if "img" in k:
            wandb_out["ref/" + k] = make_grid(v.flatten(0, 1))

    # Segmentation overlay: paint each segment's pixels with its own colour.
    segs = model_input["segs"]
    points_cam = model_input["points_cam"]
    n_segs = segs.size(1)
    # plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
    seg_cmap = plt.get_cmap("tab20", n_segs)
    seg_colors = [torch.tensor(seg_cmap(i))[:3] for i in range(n_segs)]
    seg_img = ch_sec(torch.ones_like(points_cam))[0]  # white background, (H*W, C)
    for seg, seg_color in zip(segs.unbind(1), seg_colors):
        # Follow seg_img's device/dtype rather than hard-coding .cuda(), so CPU runs work.
        seg_img[ch_sec(seg).squeeze(1)] = seg_color.to(device=seg_img.device, dtype=seg_img.dtype)
    wandb_out["ref/seg"] = make_grid(ch_fst(seg_img, points_cam.size(-2)), normalize=False)
    wandb_out["ref/seg_raw"] = make_grid(segs.flatten(0, 1).unsqueeze(1), normalize=False)

    # Pointmaps.
    wandb_out["ref/points_cam"] = make_grid(points_cam, normalize=True)
    wandb_out["ref/points_link"] = make_grid(model_input["points_link"], normalize=True)
    wandb_out["est/points_link"] = make_grid(model_output["points_link"], normalize=True)

    for k, v in wandb_out.items():
        print(k, v.shape)
    # wandb.Image expects HWC arrays in [0, 1].
    wandb_imgdict = {
        prefix + k: wandb.Image(v.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy())
        for k, v in wandb_out.items()
    }
    if not dont_log:
        wandb.log(wandb_imgdict)
    print("done logging images")
    return wandb_imgdict
