import os
import geometry
import wandb
from matplotlib import cm
import cv2
import torchvision
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import kornia
import numpy as np
import torch
import flow_vis
import flow_vis_torch
import matplotlib.pyplot as plt; imsave = lambda x,y=0: plt.imsave("/nobackup/users/camsmith/img/tmp%s.png"%y,x.cpu().numpy()); 
from einops import rearrange, repeat
import piqa
import imageio
#import splines.quaternion
#from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
from scipy import spatial

def ch_fst(src, x=None):
    """Reshape channel-last tokens ``(..., x*y, c)`` into channel-first images ``(..., c, x, y)``.

    If ``x`` is not given, the token grid is assumed square and
    ``x = int(sqrt(src.size(-2)))``; ``y`` is inferred by einops.
    """
    if x is None:
        x = int(src.size(-2) ** .5)
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def ch_sec(x):
    """Inverse of ``ch_fst``: flatten channel-first ``(..., c, x, y)`` into tokens ``(..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")

def filter_points(points, iqr_factor=1.5):
    """Compute a boolean inlier mask for a point cloud using per-column IQR fences.

    Note that this does not remove anything itself; index ``points`` with the
    returned mask to drop the outliers.

    Args:
        points: Floating-point tensor of shape ``(N, D)`` (e.g. ``D == 3`` for xyz).
            Quartiles are computed independently for each column.
        iqr_factor: Fence width in multiples of the interquartile range.

    Returns:
        Bool tensor of shape ``(N,)`` that is True for rows whose every coordinate
        lies inside ``[q1 - iqr_factor * iqr, q3 + iqr_factor * iqr]`` (inclusive).
    """
    # Both quartiles in one call; q must share the input's dtype/device.
    quartile_levels = points.new_tensor([0.25, 0.75])
    q1, q3 = torch.quantile(points, quartile_levels, dim=0)
    spread = iqr_factor * (q3 - q1)

    lo, hi = q1 - spread, q3 + spread
    within = (points >= lo) & (points <= hi)
    # A row is an inlier only if every coordinate is inside its column's fences.
    return within.all(dim=1)

def _fig_to_tensor(fig):
    """Rasterize ``fig`` into a float ``(3, H, W)`` tensor in ``[0, 1]`` and close the figure.

    Reads the Agg canvas through ``buffer_rgba()`` (dropping alpha) instead of
    ``tostring_rgb()``, which matplotlib deprecated in 3.8 and removed in 3.10.
    """
    fig.canvas.draw()
    rgb = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    plt.close(fig)
    return torch.from_numpy(rgb).permute(2, 0, 1) / 255


def wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix="",suffix="",step=0):
    """Render diagnostic images for one step and log them to Weights & Biases.

    Every visualization whose source key is present in ``model_output`` /
    ``model_input`` is rendered into a CHW image, collected in a dict and logged
    in a single ``wandb.log`` call under ``prefix + name + suffix`` (no explicit
    step). When ``"c2w"`` is in ``model_input``, the ATE metric is additionally
    logged with ``step=step`` and printed.

    Side effect: ``model_output`` is extended in place with ``"depth_permask"``
    (when ``"rig_masks"`` is present) and, for every key containing ``"depth"``,
    ``"<key>_raw"`` (depth clipped at 0.3) and ``"<key>vis"`` (magma colormap).

    Args:
        loss: Unused; kept for interface compatibility.
        model_output, model_input, ground_truth: Dicts of tensors. Per-pixel
            tensors appear to be laid out as ``(b, t, x*y, c)`` and images as
            ``(b, t, c, x, y)`` -- assumption inferred from the reshapes below, TODO confirm.
        resolution: Unused; kept for interface compatibility.
        prefix, suffix: Strings wrapped around every image key when logging.
        step: W&B step used for the ATE metric.
    """
    nrow = model_input["rgb"].size(1)
    imsl = model_input["rgb"].shape[-2:]

    if "rig_masks" in model_output:
        model_output["depth_permask"] = model_output["depth"].unsqueeze(-3) * model_output["rig_masks"]

    # Convert depths to colormapped 3-channel images (magma of min_depth / depth).
    # TODO: name these "<k>_colored" instead of "<k>vis" to avoid ambiguity with raw depth.
    magma = plt.get_cmap('magma')  # cm.get_cmap was removed in matplotlib 3.9
    for k, v in list(model_output.items()):  # snapshot: new keys are added inside the loop
        if not isinstance(v, list) and len(v.shape):
            v = v.clip(min=.3)
        if "depth" in k:
            model_output[k + "_raw"] = v
        if "depth" in k and "raw" not in k:
            # detach() so .numpy() also works on tensors that still require grad.
            norm_inv_depth = v.min().item() / v.detach().cpu().numpy()
            model_output[k + "vis"] = torch.from_numpy(magma(norm_inv_depth)).squeeze(-2)[..., :3]

    wandb_out = {}

    wandb_out["ref/rgb_gt"] = make_grid(ground_truth["rgb"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)

    # Lie-algebra poses: quaternion part (first 4 channels) shown as axis-angle,
    # translation part (last 3 channels) rescaled into a displayable range.
    for key, tag in (("poses_lie", ""), ("lie_perpix", "_perpix")):
        if key not in model_output:
            continue
        lie = model_output[key]
        rot_vis = kornia.geometry.conversions.quaternion_to_axis_angle(lie[...,:4]).flatten(0,1).permute(0,3,1,2)*.5+.5
        trans_vis = lie[...,-3:].flatten(0,1).permute(0,3,1,2)/5+.5
        wandb_out["est/poses_lie_rot" + tag] = make_grid(rot_vis.detach(),nrow=nrow,normalize=False)
        wandb_out["est/poses_lie_trans" + tag] = make_grid(trans_vis.detach(),nrow=nrow,normalize=False)

    if "rig_masks" in model_output:
        wandb_out["est/rig_masks"] = make_grid(rearrange(model_output["rig_masks"],"b t o (x y) 1 -> (b t o) 1 x y",x=model_input["rgb"].size(-2)).detach(),nrow=model_output["rig_masks"].size(2))

    if "level_scores" in model_output:
        wandb_out["est/level_scores"] = make_grid(rearrange(model_output["level_scores"],"b t l x y -> (b t l) 1 x y",).detach(), nrow=3)
    if "composite_weights" in model_output:
        for i, w in enumerate(model_output["composite_weights"]):
            wandb_out["est/composite_weights_%i"%i] = make_grid(rearrange(w,"b t o (x y) -> (b t o) 1 x y",x=model_input["rgb"].size(-2)).detach(), nrow=w.size(2))
    if "mask_spatial_scores" in model_output:
        for i, s in enumerate(model_output["mask_spatial_scores"]):
            wandb_out["est/mask_spatial_scores_%i"%i] = make_grid(rearrange(s,"b t o x y -> (b t o) 1 x y").detach(), nrow=s.size(2))
    if "is_static" in model_output:
        wandb_out["est/is_static"] = make_grid(rearrange(model_output["is_static"],"b t 1 x y -> (b t) 1 x y").detach(), nrow=model_output["is_static"].size(1))

    if "traj_top_comps" in model_output:
        wandb_out["est/top_comps1"] = make_grid(model_output["traj_top_comps"][:,:3].abs().detach(),normalize=False)
        wandb_out["est/top_comps2"] = make_grid(model_output["traj_top_comps"][:,3:].abs().detach(),normalize=False)

    if "affinity_sim" in model_output:
        sim = rearrange(model_output["affinity_sim"][0,0],"x1 y1 x y -> (x1 y1) 1 x y").detach()
        sim_nrow = model_output["affinity_sim"].size(3)
        wandb_out["est/affinity_sim"] = make_grid(sim, nrow=sim_nrow)
        # Similarity maps modulating the (resized) first-frame RGB.
        wandb_out["est/affinity_sim_rgb"] = make_grid(F.interpolate(model_input["rgb"][:,0]*.5+.5,model_output["affinity_sim"].shape[-2:]) * sim, nrow=sim_nrow)
    if "affinity_emb" in model_output:
        # PCA visualization: stack all frames vertically and project the embedding
        # channels onto the top (up to) 3 principal components.
        features = rearrange(model_output["affinity_emb"].flatten(0,1),"bt c x y -> 1 c (bt x) y")
        B, C, H, W = features.shape
        features = features.reshape(B, C, -1)
        features = features - features.mean(dim=2, keepdim=True)  # center the data
        covariance = torch.bmm(features, features.transpose(1, 2)) / (H * W - 1)
        U, S, V = torch.svd(covariance)
        num_components = min(3, C)
        transformed_features = torch.bmm(U[:, :, :num_components].transpose(1, 2), features)
        transformed_features = transformed_features.view(B, num_components, H, W)
        wandb_out["est/affinity_emb"] = make_grid(transformed_features.detach(), nrow=model_output["affinity_emb"].size(1))

    for k, v in model_input.items():
        if "all_mask" in k:
            wandb_out["ref/%s"%k] = make_grid(rearrange(v,"t m x y -> (t m) 1 x y").detach(),nrow=v.size(1))
    for k, v in model_input.items():
        if "masks_vis" in k:
            wandb_out["ref/%s"%k] = make_grid(rearrange(v/255,"t m 1 x y c -> (t m) c x y").detach(),nrow=v.size(1))

    if "zoe_depthvis" in model_output:
        wandb_out["est/zoe_depth"] = make_grid(model_output["zoe_depthvis"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
    if "res_depthvis" in model_output:
        wandb_out["est/res_depth"] = make_grid(model_output["res_depthvis"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
    if "depthvis" in model_output:
        wandb_out["est/depth"] = make_grid(model_output["depthvis"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
    if "depth_permaskvis" in model_output:
        wandb_out["est/depth_permask"] = make_grid(model_output["depth_permaskvis"].cpu().flatten(0,2).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=model_output["rig_masks"].size(2))

    if "not_sky" in model_input:
        wandb_out["ref/not_sky"] = make_grid(model_input["not_sky"].cpu().flatten(0,1).detach().float(),nrow=nrow,normalize=False)
    if "corr_weights" in model_output:
        wandb_out["est/corr_weights"] = make_grid(model_output["corr_weights"].flatten(0,1).cpu().detach(),normalize=True,nrow=nrow)
    if "bwd_flow" in model_input:
        wandb_out["ref/flow_gt_bwd"] = flow_vis_torch.flow_to_color(make_grid(model_input["bwd_flow"].flatten(0,1),nrow=nrow))/255

    if "warp_rgbs_" in model_output:
        wandb_out["est/warp_rgb"] = make_grid(model_output["warp_rgbs_"].flatten(0,2).permute(0,2,1).unflatten(-1,imsl),nrow=nrow)

    if "flow_from_pose" in model_output and not torch.isnan(model_output["flow_from_pose"]).any():
        wandb_out["est/flow_est_pose"] = flow_vis_torch.flow_to_color(make_grid(model_output["flow_from_pose"].clip(-.1,.1).flatten(0,1).permute(0,2,1).unflatten(-1,imsl),nrow=nrow))/255

    if "flow_from_pose_uncomp" in model_output:
        n_uncomp = len(model_output["flow_from_pose_uncomp"])
        wandb_out["est/flow_est_pose_uncomp"] = flow_vis_torch.flow_to_color(make_grid(rearrange(model_output["flow_from_pose_uncomp"],"m bt c x y -> (bt m) c x y").detach(),
                                                                                nrow=n_uncomp))/255
        # Per-pixel L1 error against the GT backward flow (requires "bwd_flow" in model_input).
        wandb_out["est/flow_est_pose_uncomp_err"] = make_grid(rearrange((model_input["bwd_flow"]-model_output["flow_from_pose_uncomp"]).abs().sum(dim=2), "m bt x y -> (bt m) 1 x y").detach(),
                nrow=n_uncomp)

    if "rig_poses" in model_output:
        # One random color per rig; rig_poses[0] iterates over rigs.
        rig_colors = torch.rand(model_output["rig_poses"].size(1), 3)
        fig = plt.figure(figsize=(12,6))
        ax = fig.add_subplot(111, projection='3d')
        for color, rig_pos in zip(rig_colors, model_output["rig_poses"][0,:,:,:3,-1].detach().cpu()):
            ax.plot(*rig_pos.T.numpy(), c=color.numpy())
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.view_init(elev=55., azim=-95)
        fig.tight_layout()
        wandb_out["est/rig_poses"] = _fig_to_tensor(fig)

        # 3D point-cloud renders (frame/rig/rgb coloring x three viewpoints).
        # Disabled; set to True to re-enable. Requires "pts_canon" in model_output.
        plot_point_clouds = False
        if plot_point_clouds:
            view_names = ["bev", "front", "front_nopose"]
            color_names = ["frame", "rig", "rgb"]
            # Camera frustum template: apex + 4 far-plane corners (homogeneous), scaled down.
            frustum_points = torch.tensor([[0, 0, 0, 1], [-1, -1, 2, 1], [1, -1, 2, 1], [1, 1, 2, 1], [-1, 1, 2, 1]]).float()
            frustum_points[:, :3] *= .05
            rig_poses_all = model_output["rig_poses"].cpu()[0]
            for color_i in range(3):
                for view_i in range(3):
                    pts_all = model_output["pts_canon"][0].cpu()
                    rig_selection = model_output["rig_masks"][0].max(dim=1)[1].cpu()
                    # For every pixel, keep the point predicted by its most likely rig.
                    pts = torch.gather(pts_all.permute(1,2,3,0), -1, rig_selection.expand(-1,-1,3).unsqueeze(-1)).squeeze(-1)
                    mask = filter_points(pts.flatten(0,1), .4)

                    fig = plt.figure()
                    ax = fig.add_subplot(111, projection='3d')
                    cmap = plt.get_cmap('viridis', rig_poses_all.size(1))

                    if view_i <= 1:  # the "front_nopose" view omits camera frustums
                        for rig_i, rig_poses in enumerate(rig_poses_all):
                            color = rig_colors[rig_i].numpy()
                            for pose in rig_poses:
                                # Transform the frustum into world space using the pose matrix.
                                frustum_world = torch.mm(pose, frustum_points.T).T
                                camera_center, bottom_left, bottom_right, top_right, top_left = frustum_world[:, :3]
                                frustum_lines = [[camera_center, bottom_left], [camera_center, bottom_right], [camera_center, top_right], [camera_center, top_left],
                                                 [bottom_left, bottom_right], [bottom_right, top_right], [top_right, top_left], [top_left, bottom_left]]
                                for line in frustum_lines:
                                    ax.plot(*zip(*line), color=color, zorder=10, alpha=.9)

                    rgb = model_input["rgb"].cpu()[0].flatten(-2,-1).permute(0,2,1)
                    stride = 100
                    pt_colors = (torch.from_numpy(cmap(list(range(len(pts)))))[:,None].expand(-1,pts.size(1),-1).flatten(0,1),  # frame color
                                 rig_colors[rig_selection.flatten(0,1).squeeze(-1)],  # rig color
                                 rgb.flatten(0,1).clip(-1,1)*.5+.5,  # rgb color
                                 )[color_i]
                    ax.scatter(*pts.flatten(0,1)[mask][::stride].T, c=pt_colors[mask][::stride], marker='o', s=1)
                    ax.view_init(elev=-15 if view_i == 0 else -90, azim=-90)
                    fig.tight_layout()
                    wandb_out["vis3d/view_%s_color_%s" % (view_names[view_i], color_names[color_i])] = _fig_to_tensor(fig)

    if "c2w" in model_input:  # plot estimated poses against GT
        c2w = model_input["c2w"]
        # GT camera positions expressed relative to each sequence's first frame.
        gt_rel_pos = (c2w[:,[0]].inverse() @ c2w)[:,:,:3,-1].cpu()
        poses = model_output["poses"].unsqueeze(1) if "intermed_poses_" not in model_output else model_output["intermed_poses_"]
        # NOTE: the loop variable must not be called `suffix`; that would overwrite the
        # caller's `suffix` used for the final wandb keys.
        for align_tag in ["_aligned", ""]:
            pose_imgs = []
            for i in range(len(poses)):
                for j in range(len(poses[0])):
                    our_pos = poses[i,j,:,:3,-1].detach().cpu()
                    if align_tag:
                        gt_pos = geometry.numpy_procrustes(our_pos, gt_rel_pos[i])[1]
                    else:
                        gt_pos = gt_rel_pos[i]

                    # ATE: RMS position error after similarity (Procrustes) alignment.
                    pos_gt_, pos_est_ = [torch.from_numpy(x) for x in spatial.procrustes(c2w[i,:,:3,-1].cpu().numpy(), our_pos.numpy())[:2]]
                    ate = (pos_gt_ - pos_est_).square().mean().sqrt().item()
                    print("ATE:", ate)
                    wandb.log({"metrics/ATE": ate}, step=step)

                    fig = plt.figure()
                    ax = fig.add_subplot(111, projection='3d')
                    ax.plot(*our_pos.unbind(1), c="red", label="Estimated Trajectory")
                    ax.plot(*gt_pos.cpu().unbind(1), c="black", label="GT Trajectory")
                    # Same limits on all axes so the trajectories are not visually distorted.
                    min_, max_ = min(our_pos.min(), gt_pos.min()), max(our_pos.max(), gt_pos.max())
                    ax.set_xlim(min_, max_)
                    ax.set_ylim(min_, max_)
                    ax.set_zlim(min_, max_)
                    plt.legend()
                    plt.tight_layout()
                    pose_imgs.append(_fig_to_tensor(fig))
            if pose_imgs:
                # Concatenate once (side by side along width) instead of on every iteration.
                wandb_out["est/pose_est" + align_tag] = torch.cat(pose_imgs, 2)

        # Pretty plot for figures: first sequence, Procrustes-aligned, no tick labels.
        our_pos = model_output["poses"][0,:,:3,-1].detach().cpu()
        gt_pos = c2w[0,:,:3,-1].detach().cpu()
        gt_pos, our_pos = [torch.from_numpy(x) for x in spatial.procrustes(gt_pos.numpy(), our_pos.numpy())[:2]]

        fig = plt.figure(figsize=(12,6))
        ax = fig.add_subplot(111, projection='3d')
        ax.plot(*our_pos.T.numpy(), c="red", label="Estimated Trajectory")
        ax.plot(*gt_pos.T.numpy(), c="black", label="GT Trajectory")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.view_init(elev=55., azim=-95)
        fig.tight_layout()
        wandb_out["est/pose_est_pretty"] = _fig_to_tensor(fig)

    # Debug: dump every collected image to output/img/ and stop the run.
    dump_images_locally = False
    if dump_images_locally:
        for k, v in wandb_out.items():
            print(k, v.max(), v.min())
        for k, v in wandb_out.items():
            print(k, v.shape)
            plt.imsave("output/img/%s.png"%k, v.float().permute(1,2,0).detach().cpu().numpy().clip(0,1))
        print("saving locally")
        raise SystemExit("stopped after dumping images locally")

    wandb.log({prefix+k+suffix: wandb.Image(v.permute(1, 2, 0).float().detach().clip(0,1).cpu().numpy()) for k, v in wandb_out.items()})

def pose_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
    """Log estimated camera positions to W&B as a 3D point cloud.

    Only ``model_output["poses"]`` is read. ``loss``, ``model_input``,
    ``ground_truth``, ``resolution`` and ``prefix`` are currently unused and kept
    for interface compatibility.

    Positions are the translation column (``[..., :3, -1]``) of the pose matrices.
    Any leading batch/time dimensions are flattened, so both ``(T, 4, 4)`` and
    ``(B, T, 4, 4)`` pose tensors yield an ``(N, 3)`` point array. Previously
    ``[:, :3, -1]`` was used, which is only correct for 3-D input.
    """
    # detach() so .numpy() also works on poses that still require grad.
    positions = model_output["poses"][..., :3, -1].detach().reshape(-1, 3).cpu().numpy()
    point_scene = wandb.Object3D({
        "type": "lidar/beta",
        "points": positions,
    })
    wandb.log({"camera positions": point_scene})


    
