import os,io,shutil
import geometry
import wandb
from matplotlib import cm
import cv2
from tqdm import tqdm
import torchvision
import time
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import kornia
import numpy as np
import torch
import flow_vis
import flow_vis_torch
import matplotlib.pyplot as plt
from einops import rearrange, repeat
import piqa
import imageio
from PIL import Image
#import splines.quaternion
#from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
from scipy import spatial
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import geometry
import pybullet as p
import data

def ch_fst(src, x=None):
    """Unflatten the second-to-last axis of ``src`` and move channels first.

    Applies the einops pattern ``"... (x y) c -> ... c x y"``: the
    second-to-last axis is split into ``(x, y)`` and the trailing channel
    axis is moved in front of them.

    Args:
        src: Array/tensor exposing ``.size(dim)``.
        x: Length of the first spatial axis. When ``None`` it is taken as the
            truncated square root of ``src.size(-2)``, i.e. a square grid is
            assumed.
    """
    side = int(src.size(-2) ** (.5)) if x is None else x
    return rearrange(src, "... (x y) c -> ... c x y", x=side)
def ch_sec(x):
    """Flatten the last two axes of ``x`` into one and move channels last.

    Applies the einops pattern ``"... c x y -> ... (x y) c"``, the inverse
    layout change of ``ch_fst``.
    """
    return rearrange(x, "... c x y -> ... (x y) c")

def _dump_summary_locally(wandb_out, model_input, model_output):
    """Write every summary image to ``output/img`` and save the raw tensors.

    Prints the value range and shape of each image, stores each one as
    ``output/img/<key>.png`` and finally saves ``(model_input, model_output)``
    to ``output/inpout.pt``.
    """
    for key, img in wandb_out.items():
        print(key, img.max(), img.min())
    for key, img in wandb_out.items():
        print(key, img.shape)
        path = "output/img/%s.png" % key
        # Keys contain "/" (e.g. "ref/joint_sdfs"), so nested folders are needed.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        plt.imsave(path, img.float().permute(1, 2, 0)[..., :3].detach().cpu().numpy().clip(0, 1))
    print("saving locally")
    torch.save((model_input, model_output), "output/inpout.pt")


def wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix="",suffix="",step=0,losses_agg=[],dont_log=False,*,save_local=True):
    """Build diagnostic images for one sample and optionally log them to wandb.

    For the first batch element this collects the input images and the
    reference/estimated joint SDF maps, estimates a camera pose from the joint
    SDFs (once from ``model_input``, once from ``model_output``), re-renders
    the robot with that pose, and draws a 3D plot comparing the estimated and
    ground-truth camera frusta.

    Args:
        loss: Unused.
        model_output: Dict of tensors; ``"joint_sdfs"`` is read.
        model_input: Dict of tensors; reads ``"joint_sdfs"``, ``"points_cam"``,
            ``"fg"``, ``"img"``, ``"link_positions"``, ``"cam2world_cv"``,
            ``"joint_states"``, ``"rob_id"`` and every key containing ``"img"``.
        ground_truth: Dict of tensors; only sliced, otherwise unused.
        resolution: Unused.
        prefix: String prepended to every key of the returned dict.
        suffix: Unused.
        step: Unused.
        losses_agg: Unused (never read or mutated).
        dont_log: When true, skip the ``wandb.log`` call.
        save_local: When true (default, matching the previous hard-coded
            behaviour), also dump the images and the input/output dicts under
            ``output/``.

    Returns:
        Dict mapping ``prefix + key`` to a ``wandb.Image``.
    """
    # Keep only the first batch element; entries without dimensions pass through.
    model_output, model_input, ground_truth = [
        {k: (v[:1] if len(v.shape) else v) for k, v in tensors.items()}
        for tensors in (model_output, model_input, ground_truth)
    ]

    wandb_out = {}

    # Log images
    for key, value in model_input.items():
        if "img" in key:
            wandb_out["ref/" + key] = make_grid(value.flatten(0, 1), normalize=True)

    # Plot joint SDF predictions
    wandb_out["ref/joint_sdfs"] = make_grid(model_input["joint_sdfs"].flatten(0, 1), normalize=True)
    wandb_out["est/joint_sdfs"] = make_grid(model_output["joint_sdfs"].flatten(0, 1), normalize=True)

    # Quantities shared by the reference and estimated passes.
    # NOTE(review): indexing suggests points_cam is (batch, frame, C, H, W) and
    # fg is a boolean (batch, frame, H, W) mask -- confirm against the dataset.
    points_cam = ch_sec(model_input["points_cam"])[0, 0].detach()
    fg = model_input["fg"][0, 0].detach().flatten()
    gt_view_matrix_ = data.Tc @ model_input["cam2world_cv"].cpu().numpy()
    fg_points = points_cam[fg][::1].T.cpu().numpy()
    fg_colors = ch_sec(model_input["img"])[0][fg].cpu().numpy() * .5 + .5

    # Camera estimation from pointmaps, once per joint-SDF source.
    for tag, src in (("ref", model_input), ("est", model_output)):
        joint_sdfs = ch_sec(src["joint_sdfs"])[0].detach()
        # Foreground pixels with a short offset vector get the largest weight.
        weights = fg[None, :] * (1 / (joint_sdfs + 1e-4).norm(dim=-1) ** 2)
        best_weight, best_pixel = weights.max(dim=1)

        # Each keypoint is read from the single highest-weight pixel of its
        # joint. (A weight-normalised average over all pixels was a disabled
        # alternative in the original code.)
        candidates = joint_sdfs + points_cam
        keypoint_estimates = torch.stack([candidates[i, px] for i, px in enumerate(best_pixel)])

        # NOTE(review): weights.max() is 0 when the foreground mask is empty,
        # which makes these confidences NaN -- confirm callers never hit that.
        est_pose = geometry.efficient_procrustes(
            model_input["link_positions"].float(),
            keypoint_estimates[None].float(),
            weights=(best_weight / weights.max())[None, :, None].float(),
        )[1]

        est_view_matrix_ = data.Tc @ est_pose.cpu().numpy()

        # Rerender mesh image with estimated extrinsics
        wandb_out["%s/cam_rerender" % tag] = data.format_pybullet_render(
            model_input["joint_states"][0].cpu().numpy(),
            est_view_matrix_[0],
            model_input["rob_id"][0].item(),
            cam2world_cv=est_pose.cpu().numpy()[0],
            urdf=1,
        )["img"][0] * .5 + .5

        # Plot estimated keypoints, the foreground point cloud and both cameras.
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.view_init(elev=160, azim=-90)
        ax.scatter(*keypoint_estimates.cpu().numpy().T, marker='o', alpha=.99, label='Point Cloud')
        ax.scatter(*fg_points, c=fg_colors, marker='o', alpha=.15)
        geometry.plot_frustum(np.linalg.inv(est_view_matrix_[0]) @ data.Tc, scale=.1, ax=ax, color="red", label="Pred Cam2Base")
        geometry.plot_frustum(np.linalg.inv(gt_view_matrix_[0]) @ data.Tc, scale=.1, ax=ax, color="blue", label="GT Cam2Base")
        ax.legend()
        try:
            # The plot is best-effort: a failure here must not abort the summary.
            os.makedirs("output/img", exist_ok=True)
            fig.savefig("output/img/tmp.png")
            img_arr = plt.imread("output/img/tmp.png")
            wandb_out["%s/cam_plot" % tag] = torch.from_numpy(img_arr[..., :3]).permute(2, 0, 1)
        except Exception as err:
            print("failed saving/logging 3d plot image", err)
        finally:
            # Always release the figure so repeated calls do not accumulate them.
            plt.close(fig)

    wandb_out["est/render_overlay"] = wandb_out["ref/cam_rerender"] * .5 + wandb_out["est/cam_rerender"] * .5

    if save_local:
        _dump_summary_locally(wandb_out, model_input, model_output)

    for key, img in wandb_out.items():
        print(key, img.shape)
    wandb_imgdict = {
        prefix + key: wandb.Image(img.permute(1, 2, 0)[..., :3].float().detach().clip(0, 1).cpu().numpy())
        for key, img in wandb_out.items()
    }
    if not dont_log:
        wandb.log(wandb_imgdict)
    print("done logging images")
    return wandb_imgdict
