import os
import geometry
import wandb
from matplotlib import cm
import cv2
import torchvision
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import kornia
import numpy as np
import torch
import flow_vis
import flow_vis_torch
# Debug helper: imsave(img, tag=0) writes a tensor image (via .cpu().numpy() and plt.imsave)
# to a hard-coded, user-specific path: /nobackup/users/camsmith/img/tmp<tag>.png.
import matplotlib.pyplot as plt; imsave = lambda x,y=0: plt.imsave("/nobackup/users/camsmith/img/tmp%s.png"%y,x.cpu().numpy()); 
from einops import rearrange, repeat
import models
import piqa
import imageio
#import splines.quaternion
#from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
from scipy import spatial
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import viser.transforms as tf
from sklearn.ensemble import IsolationForest

def ch_fst(src, x=None):
    """Reshape pixel-major features ``(..., x*y, c)`` into channel-first maps ``(..., c, x, y)``.

    When ``x`` is omitted the pixel grid is assumed square, i.e. ``x = int(sqrt(src.size(-2)))``.
    """
    side = int(src.size(-2) ** 0.5) if x is None else x
    return rearrange(src, "... (x y) c -> ... c x y", x=side)


def ch_sec(x):
    """Flatten channel-first maps ``(..., c, x, y)`` into pixel-major features ``(..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")

def filter_points(points, iqr_factor=.3, random_state=None):
    """Return a boolean inlier mask for a point cloud using an Isolation Forest.

    Despite its historical name, ``iqr_factor`` is passed to ``IsolationForest`` as the
    ``contamination`` (expected fraction of outliers). It is clamped to ``[0.01, 0.49]``
    because scikit-learn requires ``0 < contamination <= 0.5``.

    Args:
        points: ``(N, D)`` array-like of point coordinates. Both call sites in this file pass a
            CPU torch tensor, which scikit-learn converts with ``np.asarray``.
        iqr_factor: Expected outlier fraction; clamped as described above.
        random_state: Optional seed forwarded to ``IsolationForest`` for reproducible masks.
            ``None`` (the default) keeps the previous, unseeded behavior.

    Returns:
        numpy bool array of shape ``(N,)``; ``True`` marks an inlier.
    """
    # TODO: a nearest-neighbour-distance criterion may suit sparse 3D clouds better. The earlier
    # per-axis IQR box filter was dropped: per the author's note, it only contracted the cloud
    # toward its central point instead of rejecting isolated points.
    contamination = max(.01, min(iqr_factor, .49))
    labels = IsolationForest(contamination=contamination, random_state=random_state).fit_predict(points)
    # IsolationForest labels outliers -1 and inliers +1.
    return labels != -1

# TODO (viser server): add a GUI callback for point (outlier) filtering.
def viser_update(server,loss_, model_output, model_input, ground_truth, resolution,prefix="",suffix="",step=0,wandb_imgs=None):
    """Push the current training state to a viser server for interactive inspection.

    On ``step == 0`` this builds the GUI and stores the handles as attributes on ``server``:
    - loss plots;
    - timestep, color, outlier, max-points and point-size controls;
    - the initial subsampled point cloud.

    On later steps it does the following:
    - refreshes the loss plots every 10th step;
    - re-samples the point cloud, optionally after outlier filtering;
    - re-colors it according to the current GUI values.

    Args:
        server: viser server. GUI/scene handles and state (``loss_plots``, ``point_cloud``,
            ``point_mask``, ``valid_points``, ``color_choices``, ...) are stored on it.
        loss_: iterable of dicts mapping loss name -> value. Each name becomes a line plot,
            so the values are presumably scalars.
        model_output: dict of model tensors. Reads ``pose_perpix``, ``eye_surf``,
            ``rig_masks``, ``lie_perpix`` and ``flow_inp_`` (and ``poses_all`` in the currently
            unused frame drawer). NOTE: it is also mutated — ``world_crds``, ``rgb`` and
            ``intrinsics`` are added because ``scene`` initially aliases it.
        model_input: dict with at least ``rgb`` and ``intrinsics``.
        ground_truth, resolution, prefix, suffix: unused here (presumably kept to mirror
            ``wandb_summary``'s signature).
        step: training step; ``0`` triggers GUI construction.
        wandb_imgs: optional dict name -> ``(C, H, W)`` tensor image. Each image is scaled by
            255 to uint8 (so presumably in [0, 1]) and shown as a plotly figure.
    """
    print("doing viser")

    # Losses: regroup the per-step dicts into {loss name: [values...]}.
    loss=defaultdict(list)
    for x in loss_:
        for k,v in x.items():
            loss[k].append(v)

    def make_loss_fig(k,loss):
        fig = px.line( y=loss, x=list(range(len(loss))), labels={"x": "x", "y": k}, title=k)
        fig.layout.title.automargin = True
        fig.update_layout( margin=dict(l=20, r=20, t=20, b=20),)
        return fig

    # Figures are only rebuilt every 10th step; they are only consumed on those steps below.
    if step%10==0:
        loss_figs = {k:make_loss_fig(k,v) for k,v in loss.items()}

    # Point cloud.
    # NOTE(review): `scene` aliases `model_output` until it is rebuilt below, so the next three
    # assignments also add "world_crds" / "rgb" / "intrinsics" to the caller's dict.
    scene=model_output
    # Applies inverse(pose_perpix[:, t]) @ pose_perpix to the homogeneous `eye_surf` points,
    # i.e. (presumably) expresses every pixel's surface point in the frame of timestep t.
    get_world_pix = lambda t,s: torch.einsum("btpij,btpj->btpi",(s["pose_perpix"][:,[t]].inverse()@s["pose_perpix"]).flatten(2,3),models.hom(s["eye_surf"]))[...,:3]
    scene["world_crds"]=get_world_pix(0,model_output)
    scene["rgb"] = model_input["rgb"].flatten(0,1).flatten(-2,-1).permute(0,2,1)[None]
    scene["intrinsics"] = model_input["intrinsics"]
    # Keep batch element 0 of every tensor with rank > 1, detached and moved to the CPU.
    scene={k:v[0].detach().cpu() for k,v in scene.items() if type(v)==torch.Tensor and len(v.shape)>1}

    # One viridis color per rig; each pixel's color is the rig-mask-weighted sum of those colors.
    n_rig = model_output["rig_masks"].size(2)
    cmap = plt.get_cmap('viridis', n_rig)
    colors_per_rig = [cmap(i) for i in range(n_rig)]
    colors_per_rig = torch.from_numpy(np.concatenate(colors_per_rig).reshape(-1,4))[...,:3]
    # rig_masks are unflattened with x=64 (assumes a 64-wide mask grid) and resized to the input image size.
    rig_masks = rearrange(F.interpolate(rearrange(model_output["rig_masks"],"b t o (x y) 1 -> (b t o) 1 x y",x=64),model_input["rgb"].shape[-2:]),"(b t o) 1 x y -> b t o (x y) 1",b=model_input["rgb"].size(0),t=model_input["rgb"].size(1))
    colors_per_rig_perpix = (colors_per_rig[None,None,:,None].cuda()*rig_masks).sum(2).flatten(0,1)

    rot_vis=kornia.geometry.conversions.quaternion_to_axis_angle(model_output["lie_perpix"][...,:4]).flatten(0,1).flatten(1,2)*.5+.5
    trans_vis = model_output["lie_perpix"][...,-3:].flatten(0,1).flatten(1,2)/5+.5
    # Selectable per-point color schemes, scaled by 255 and cast to int numpy arrays for viser.
    color_choices = {"rgb":scene["rgb"].flatten(0,1)*.5+.5,"rig_group":colors_per_rig_perpix.flatten(0,1),"rot_lie":rot_vis.flatten(0,1),"trans":trans_vis.flatten(0,1)}
    color_choices = {k:(v*255).detach().cpu().int().numpy() for k,v in color_choices.items()}
    # Stored on the server so GUI callbacks use this step's colors, not the ones captured at step 0.
    server.color_choices = color_choices

    points = points3d = scene["world_crds"].flatten(0,1).numpy()
    colors = color_choices["rgb"]
    images = ((scene["rgb"].unflatten(1,scene["flow_inp_"].shape[-2:])*.5+.5)*255).int().numpy()

    def visualize_frames() -> None:
        """Draw a coordinate frame and camera frustum for every pose in ``poses_all`` (currently not called)."""
        # Remove existing image frames.
        for frame in server.frames: frame.remove()
        server.frames.clear()

        def attach_callback( frustum, frame) -> None:
            # Clicking a frustum moves every connected client's camera to that frame.
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        for src_i,(pose_color,poses) in enumerate(zip(colors_per_rig.cpu().unbind(0), model_output["poses_all"].detach().cpu().unbind(0))):
            for frame_i,pose in enumerate(poses):

                T_world_camera = tf.SE3.from_matrix(pose.inverse().numpy())
                frame = server.scene.add_frame( f"/flowmap/{src_i}_frame_{frame_i}", wxyz=T_world_camera.rotation().wxyz, position=T_world_camera.translation(),
                    axes_length=0.1, axes_radius=0.005, show_axes=src_i==0,)
                server.frames.append(frame)

                # Vertical FOV from intrinsics[0, 1, 1] scaled by the image height (presumably a normalized fy).
                H, W = images.shape[1],images.shape[2]
                fy = scene["intrinsics"].numpy()[0,1,1]*H
                frustum = server.scene.add_camera_frustum(
                    f"/flowmap/{src_i}_frame_{frame_i}/frustum", fov=2 * np.arctan2(H / 2, fy), aspect=W / H, scale=0.05,
                    color=(pose_color*255).int(),
                    )
                attach_callback(frustum, frame)

            if src_i==0 and not server.cam_set and len(server.get_clients()):
                # get_clients() is a mapping (see `.values()` above): take the first connected
                # client instead of assuming one has key 0.
                client=next(iter(server.get_clients().values()))
                T_world_current = tf.SE3.from_rotation_and_translation( tf.SO3(client.camera.wxyz), client.camera.position)
                T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5]))
                T_current_target = T_world_current.inverse() @ T_world_target
                T_world_set = T_world_current @ tf.SE3.exp( T_current_target.log())
                with client.atomic():
                    client.camera.wxyz = T_world_set.rotation().wxyz
                    client.camera.position = T_world_set.translation()
                client.flush()  # Optional!
                client.camera.look_at = frame.position
                server.cam_set=True

    if step==0:
        # First call: build the GUI and the initial point cloud.
        server.loss_plots=[]
        server.cam_set=False
        for k,v in loss_figs.items():
            server.loss_plots.append(server.gui.add_plotly(figure=v))
        server.gui_timestep_slider = server.gui.add_slider( "Timestep", min=0, max=len(images)-1, step=1, initial_value=0)
        server.gui_color_choice    = server.gui.add_dropdown( "Color Vis", initial_value="rgb",options=color_choices.keys())
        server.gui_outlier_filter  = server.gui.add_checkbox( "Outlier Filtering", initial_value=False,)
        # Used as the IsolationForest contamination by filter_points (clamped to at most .49 there).
        server.gui_outlier_slider  = server.gui.add_slider( "Outlier filter threshold", min=.00001, max=2, step=.01, initial_value=1,)
        server.gui_points          = server.gui.add_slider( "Max points", min=1, max=len(points3d), step=1, initial_value=len(points3d),)
        server.gui_frames          = server.gui.add_slider( "Max frames", min=1, max=len(images), step=1, initial_value=min(len(images), 100),)
        server.gui_point_size      = server.gui.add_slider( "Point size", min=0.001, max=0.05, step=0.001, initial_value=0.01)
        # Sample without replacement so the initial cloud contains no duplicated points.
        server.point_mask = np.random.choice(points.shape[0], server.gui_points.value, replace=False)
        server.points=points
        server.colors=colors
        server.point_cloud = server.scene.add_point_cloud( name="/flowmap/pcd", points=points[server.point_mask], colors=colors[server.point_mask], point_size=server.gui_point_size.value,)
        server.frames = []

        @server.gui_color_choice.on_update
        def _(_) -> None:
            # Index the latest colors with the current subsample so they stay aligned with the displayed points.
            server.point_cloud.colors = server.color_choices[server.gui_color_choice.value][server.point_mask]

        @server.gui_point_size.on_update
        def _(_) -> None:
            server.point_cloud.point_size = server.gui_point_size.value
    else:
        if step%10==0:
            for i,(k,v) in enumerate(loss_figs.items()):
                server.loss_plots[i].figure=v

        # NOTE: visualize_frames() is not called (camera-frame drawing is currently disabled).

        if server.gui_outlier_filter.value:
            print("filtering points for vis")
            server.valid_points = np.arange(points.shape[0])[filter_points(torch.from_numpy(points),server.gui_outlier_slider.value)]
        else: server.valid_points = np.arange(points.shape[0])
        # Cap the sample count at the number of valid points. An empty valid set yields an empty
        # mask instead of a negative sample size.
        n_samples = min(len(server.valid_points), server.gui_points.value)
        server.point_mask = np.random.choice(server.valid_points, n_samples, replace=False)
        print("len pointmask",len(server.point_mask))

        server.point_cloud.points = get_world_pix(server.gui_timestep_slider.value,model_output).flatten(0,2).detach().cpu().numpy()[server.point_mask]
        server.point_cloud.colors = color_choices[server.gui_color_choice.value][server.point_mask]
        server.point_cloud.point_size = server.gui_point_size.value

    # Mirror the wandb images into the GUI as plotly figures (created once, then updated in place).
    if wandb_imgs is not None:
        if step==0:server.wandb_plots={}
        for k,img in wandb_imgs.items():
            fig = px.imshow((img.detach().cpu()*255).permute(1,2,0).to(torch.uint8).numpy())
            if k not in server.wandb_plots: server.wandb_plots[k]=server.gui.add_plotly(figure=fig)
            else: server.wandb_plots[k].figure=fig

    print("done viser")

def _figure_to_image_tensor(fig):
    """Rasterize a matplotlib figure into a float ``(3, H, W)`` tensor in ``[0, 1]``, then close it.

    Reads the Agg canvas through ``buffer_rgba()`` and drops alpha. The previous
    ``tostring_rgb()`` was deprecated in Matplotlib 3.8 and removed in 3.10.
    """
    fig.canvas.draw()
    rgb = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()  # copy: decouple from the canvas buffer
    plt.close(fig)
    return torch.from_numpy(rgb).permute(2,0,1)/255


def wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix="",suffix="",step=0):
    """Build visualization images from whichever keys are present and log them to W&B.

    Every value of the returned dict is a ``(C, H, W)`` tensor image. Each is logged as a
    ``wandb.Image`` under ``prefix + key + suffix`` after clipping to [0, 1].

    If ``model_input`` contains ``c2w``, estimated trajectories are also plotted against GT,
    and a Procrustes-normalized ATE is logged under ``metrics/ATE`` at ``step``.

    NOTE: ``model_output`` is mutated. For every key containing "depth" a clipped copy
    ``<key>_raw`` is added, and, unless the key contains "raw", a magma-colormapped ``<key>vis``.

    Args:
        loss, resolution: unused.
        model_output: dict of model outputs (mostly tensors, some lists of tensors).
        model_input: dict with at least ``rgb``; optional entries add more panels.
        ground_truth: dict with ``rgb`` (reference image grid).
        prefix, suffix: strings wrapped around every logged image key.
        step: W&B step used when logging the ATE metric.

    Returns:
        The dict of images that was logged.
    """
    nrow=model_input["rgb"].size(1)
    imsl=model_input["rgb"].shape[-2:]

    # Depth maps -> magma-colormapped 3-channel images, normalized as min(depth) / depth.
    for k,v in list(model_output.items()):
        if "depth" not in k: continue
        if type(v)!=list and len(v.shape): v=v.clip(min=.3)
        model_output[k+"_raw"] = v
        if "raw" not in k:
            # detach(): .numpy() raises on tensors that require grad.
            model_output[k+"vis"] = torch.from_numpy(plt.get_cmap('magma')(v.min().item()/v.detach().cpu().numpy())).squeeze(-2)[...,:3]

    wandb_out = {}

    wandb_out["ref/rgb_gt"]= make_grid(ground_truth["rgb"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)

    # Lie-algebra poses: quaternion part -> axis-angle, translation part rescaled for display.
    if "poses_lie" in model_output:
        rot_vis=kornia.geometry.conversions.quaternion_to_axis_angle(model_output["poses_lie"][...,:4]).flatten(0,1).permute(0,3,1,2)*.5+.5
        trans_vis = model_output["poses_lie"][...,-3:].flatten(0,1).permute(0,3,1,2)/5+.5
        wandb_out["est/poses_lie_rot"]= make_grid(rot_vis.detach(),nrow=nrow,normalize=False)
        wandb_out["est/poses_lie_trans"]= make_grid(trans_vis.detach(),nrow=nrow,normalize=False)
    if "lie_perpix" in model_output:
        rot_vis=kornia.geometry.conversions.quaternion_to_axis_angle(model_output["lie_perpix"][...,:4]).flatten(0,1).permute(0,3,1,2)*.5+.5
        trans_vis = model_output["lie_perpix"][...,-3:].flatten(0,1).permute(0,3,1,2)/5+.5
        wandb_out["est/poses_lie_rot_perpix"]= make_grid(rot_vis.detach(),nrow=nrow,normalize=False)
        wandb_out["est/poses_lie_trans_perpix"]= make_grid(trans_vis.detach(),nrow=nrow,normalize=False)

    # Rig masks are unflattened with x=64 (assumes a 64-wide mask grid).
    if "rig_masks" in model_output:
        wandb_out["est/rig_masks"]= make_grid(rearrange(model_output["rig_masks"],"b t o (x y) 1 -> (b t o) 1 x y",x=64).detach(),nrow=model_output["rig_masks"].size(2))
        wandb_out["est/rig_masks_rgb"]= make_grid(rearrange(model_output["rig_masks"].flatten(0,1)*(F.interpolate(model_input["rgb"].flatten(0,1),(64,64)).flatten(-2,-1).permute(0,2,1).unsqueeze(1)*.5+.5),
                                                "bt o (x y) c -> (bt o) c x y",x=64).detach(),nrow=model_output["rig_masks"].size(2))
    if "rig_masks_unnorm" in model_output:
        wandb_out["est/rig_masks_unnorm"]= make_grid(rearrange(model_output["rig_masks_unnorm"],"b t o (x y) 1 -> (b t o) 1 x y",x=64).detach(),nrow=model_output["rig_masks"].size(2))

    if "level_scores" in model_output:
        wandb_out["est/level_scores"]= make_grid(rearrange(model_output["level_scores"],"b t l x y -> (b t l) 1 x y",).detach(), nrow=3)
    if "composite_weights" in model_output:
        for i,x in enumerate(model_output["composite_weights"]):
            wandb_out["est/composite_weights_%i"%i]= make_grid(rearrange(x,"b t o (x y) -> (b t o) 1 x y",x=model_input["rgb"].size(-2)).detach(), nrow=x.size(2))
    if "mask_spatial_scores" in model_output:
        for i,x in enumerate(model_output["mask_spatial_scores"]):
            wandb_out["est/mask_spatial_scores_%i"%i]= make_grid(rearrange(x,"b t o x y -> (b t o) 1 x y").detach(), nrow=x.size(2))
    if "is_static" in model_output:
        wandb_out["est/is_static"]= make_grid(rearrange(model_output["is_static"],"b t 1 x y -> (b t) 1 x y").detach(), nrow=model_output["is_static"].size(1))

    if "traj_top_comps" in model_output:
        wandb_out["est/top_comps1"]= make_grid(model_output["traj_top_comps"][:,:3].abs().detach(),normalize=False)
        wandb_out["est/top_comps2"]= make_grid(model_output["traj_top_comps"][:,3:].abs().detach(),normalize=False)

    if "affinity_sim" in model_output:
        wandb_out["est/affinity_sim"] = make_grid(rearrange(model_output["affinity_sim"][0,0],"x1 y1 x y -> (x1 y1) 1 x y").detach(), nrow=model_output["affinity_sim"].size(3))
    if "affinity_emb" in model_output:
        # PCA visualization: project the embeddings onto their top-3 principal components.
        features=rearrange(model_output["affinity_emb"].flatten(0,1),"bt c x y -> 1 c (bt x) y")
        B, C, H, W = features.shape
        features = features.reshape(B, C, -1)
        features = features - features.mean(dim=2, keepdim=True)
        covariance = torch.bmm(features, features.transpose(1, 2)) / (H * W - 1)
        # torch.svd is deprecated; torch.linalg.svd returns (U, S, Vh) with the same U.
        U = torch.linalg.svd(covariance).U
        num_components=3
        transformed_features = torch.bmm(U[:, :, :num_components].transpose(1, 2), features)
        transformed_features = transformed_features.view(B, num_components, H, W)
        wandb_out["est/affinity_emb"]= make_grid(transformed_features.detach(), nrow=model_output["affinity_emb"].size(1))

    for k,v in model_input.items():
        if "all_mask" not in k:continue
        wandb_out["ref/%s"%k]= make_grid(rearrange(v,"t m x y -> (t m) 1 x y").detach(),nrow=v.size(1))
    for k,v in model_input.items():
        if "masks_vis" not in k:continue
        wandb_out["ref/%s"%k]= make_grid(rearrange(v/255,"t m 1 x y c -> (t m) c x y").detach(),nrow=v.size(1))

    if "zoe_depthvis" in model_output: wandb_out["est/zoe_depth"]=make_grid(model_output["zoe_depthvis"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
    if "res_depthvis" in model_output: wandb_out["est/res_depth"]=make_grid(model_output["res_depthvis"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
    if "depthvis" in model_output: wandb_out["est/depth"]=make_grid(model_output["depthvis"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)

    if "not_sky" in model_input: wandb_out["ref/not_sky"]=make_grid(model_input["not_sky"].cpu().flatten(0,1).detach().float(),nrow=nrow,normalize=False)
    if "corr_weights" in model_output: wandb_out["est/corr_weights"] = make_grid(model_output["corr_weights"].flatten(0,1).cpu().detach(),normalize=True,nrow=nrow)
    if "bwd_flow" in model_input: wandb_out["ref/flow_gt_bwd"]= flow_vis_torch.flow_to_color(make_grid(model_input["bwd_flow"].flatten(0,1),nrow=nrow))/255

    if "warp_rgbs_" in model_output: wandb_out["est/warp_rgb"]= make_grid(model_output["warp_rgbs_"].flatten(0,2).permute(0,2,1).unflatten(-1,imsl),nrow=nrow)

    if "flow_from_pose" in model_output and not torch.isnan(model_output["flow_from_pose"]).any():
        wandb_out["est/flow_est_pose"] = flow_vis_torch.flow_to_color(make_grid(model_output["flow_from_pose"].clip(-.1,.1).flatten(0,1).permute(0,2,1).unflatten(-1,imsl),nrow=nrow))/255

    if "flow_from_pose_uncomp" in model_output:
        wandb_out["est/flow_est_pose_uncomp"] = flow_vis_torch.flow_to_color(make_grid(rearrange(model_output["flow_from_pose_uncomp"],"m bt c x y -> (bt m) c x y").detach(),
                                                                                nrow=len(model_output["flow_from_pose_uncomp"])))/255
        wandb_out["est/flow_est_pose_uncomp_err"] = make_grid(rearrange((model_input["bwd_flow"]-model_output["flow_from_pose_uncomp"]).abs().sum(dim=2), "m bt x y -> (bt m) 1 x y").detach(),
                nrow=len(model_output["flow_from_pose_uncomp"]))

    if "rig_poses" in model_output:
        # Per-rig trajectories (translation column of each pose of batch 0), one random color per rig.
        rig_colors=torch.rand(model_output["rig_poses"].size(1),3)
        fig=plt.figure(figsize=(12,6));ax = fig.add_subplot(111, projection='3d')
        for color,our_pos in zip(rig_colors,model_output["rig_poses"][0,:,:,:3,-1].detach().cpu()):
            ax.plot(*our_pos.cpu().T.numpy(),c=color.numpy())
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.view_init(elev=55., azim=-95)
        fig.tight_layout()
        wandb_out["est/rig_poses"] = _figure_to_image_tensor(fig)

        # 3D renders of the "pts_canon" points plus rig-pose frustums, colored by frame / rig / RGB.
        # Disabled: the original loop skipped every iteration via `if 1: continue`.
        # Set the flag to True to re-enable.
        plot_point_clouds = False
        if plot_point_clouds:
            view_name=["bev","front","front_nopose"]
            color_name=["frame","rig","rgb"]
            for color_i in range(3):
                for view_i in range(3):
                    pts_all = model_output["pts_canon"][0].cpu()
                    rig_selection = model_output["rig_masks"][0].max(dim=1)[1].cpu()
                    pts = torch.gather(pts_all.permute(1,2,3,0), -1, rig_selection.expand(-1,-1,3).unsqueeze(-1)).squeeze(-1)
                    mask=filter_points(pts.flatten(0,1),.4)

                    fig = plt.figure()
                    ax = fig.add_subplot(111, projection='3d')

                    # Camera frustums for each rig pose (omitted for the "front_nopose" view).
                    cmap = plt.get_cmap('viridis', model_output["rig_poses"].cpu()[0].size(1))
                    for rig_i,rig_poses in enumerate(model_output["rig_poses"].cpu()[0]):
                        for pose_i,pose in enumerate(rig_poses):
                            if view_i>1:continue
                            # Apex + 4 far-plane corners in homogeneous coordinates, scaled down.
                            frustum_points = torch.tensor([[0, 0, 0, 1], [-1, -1, 2, 1], [1, -1, 2, 1], [1, 1, 2, 1], [-1, 1, 2, 1]]).float()
                            frustum_points[:, :3] *= .05
                            frustum_world = torch.mm(pose, frustum_points.T).T
                            camera_center, bottom_left, bottom_right, top_right, top_left = frustum_world[:, :3]
                            frustum_lines = [[camera_center, bottom_left], [camera_center, bottom_right], [camera_center, top_right], [camera_center, top_left],
                                             [bottom_left, bottom_right], [bottom_right, top_right], [top_right, top_left], [top_left, bottom_left]]
                            for line in frustum_lines:
                                ax.plot(*zip(*line), color=rig_colors[rig_i].numpy(),zorder=10,alpha=.9)

                    rgb = model_input["rgb"].cpu()[0].flatten(-2,-1).permute(0,2,1)
                    stride=100

                    pt_colors = (torch.from_numpy(cmap(list(range(len(pts)))))[:,None].expand(-1,pts.size(1),-1).flatten(0,1), # frame color
                                 rig_colors[rig_selection.flatten(0,1).squeeze(-1)], # rig color
                                 rgb.flatten(0,1).clip(-1,1)*.5+.5, #rgb color
                                )[color_i]
                    ax.scatter(*pts.flatten(0,1)[mask][::stride].T, c=pt_colors[mask][::stride],marker='o', s=1)
                    ax.view_init(elev=-15 if view_i==0 else -90, azim=-90)

                    fig.tight_layout()
                    wandb_out["vis3d/view_%s_color_%s"%(view_name[view_i],color_name[color_i],)] = _figure_to_image_tensor(fig)

    if "c2w" in model_input: # plot estimated poses against GT
        poses = model_output["poses"].unsqueeze(1) if "intermed_poses_" not in model_output else model_output["intermed_poses_"]
        # GT positions relative to the first GT camera of each sequence (shared by both passes).
        gt_rel_pos = (model_input["c2w"][:,[0]].inverse() @ model_input["c2w"])[:,:,:3,-1].cpu()
        # The loop variable must not be called `suffix`: that shadowed the argument and dropped
        # the caller's suffix from every key logged at the end of this function.
        for align_suffix in ["_aligned",""]:
            pose_imgs=[]
            for i in range(len(poses)):
                for j in range(len(poses[0])):
                    our_pos=poses[i,j,:,:3,-1].detach().cpu()
                    if len(align_suffix): gt_pos=geometry.numpy_procrustes(our_pos,gt_rel_pos[i])[1]
                    else: gt_pos=gt_rel_pos[i]

                    # NOTE: scipy.spatial.procrustes standardizes both point sets (centered, unit
                    # Frobenius norm), so this "ATE" is scale-free rather than in scene units.
                    pos_gt_,pos_est_ = [torch.from_numpy(x) for x in spatial.procrustes(model_input["c2w"][i,:,:3,-1].cpu().numpy(),our_pos.numpy())[:2]]
                    ate=(pos_gt_-pos_est_).square().mean().sqrt()
                    print("ATE:", ate)
                    wandb.log({"metrics/ATE": ate},step=step)

                    fig=plt.figure();ax = fig.add_subplot(111, projection='3d')
                    ax.plot(*our_pos.cpu().unbind(1),c="red",label="Estimated Trajectory")
                    ax.plot(*gt_pos.cpu().unbind(1),c="black",label="GT Trajectory")
                    # Same limits on every axis so the trajectories are not visually distorted.
                    min_,max_=min(our_pos.min(),gt_pos.min()),max(our_pos.max(),gt_pos.max())
                    ax.set_xlim(min_, max_);ax.set_ylim(min_, max_);ax.set_zlim(min_, max_)
                    ax.legend();fig.tight_layout()
                    pose_imgs.append(_figure_to_image_tensor(fig))
            if pose_imgs:
                wandb_out["est/pose_est"+align_suffix] = torch.cat(pose_imgs,2)

        # Tick-free trajectory plot of batch element 0, intended for figures.
        our_pos=model_output["poses"][0,:,:3,-1].detach().cpu()
        gt_pos=model_input["c2w"][0,:,:3,-1].detach().cpu()
        gt_pos,our_pos = [torch.from_numpy(x) for x in spatial.procrustes(gt_pos.cpu().numpy(),our_pos.numpy())[:2]]

        fig=plt.figure(figsize=(12,6));ax = fig.add_subplot(111, projection='3d')
        ax.plot(*our_pos.cpu().T.numpy(),c="red",label="Estimated Trajectory")
        ax.plot(*gt_pos.cpu().T.numpy(),c="black",label="GT Trajectory")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.view_init(elev=55., azim=-95)
        fig.tight_layout()
        wandb_out["est/pose_est_pretty"] = _figure_to_image_tensor(fig)

    wandb.log({prefix+k+suffix:wandb.Image(v.permute(1, 2, 0).float().detach().clip(0,1).cpu().numpy()) for k,v in wandb_out.items()})
    return wandb_out

def pose_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
    """Log the estimated camera positions to W&B as a 3D point cloud.

    The translation column ``[..., :3, -1]`` of every pose matrix in ``model_output["poses"]``
    is taken and all leading dimensions are flattened, giving an ``(N, 3)`` array of positions.
    This assumes the pose matrices occupy the last two dimensions.

    The other arguments are unused (presumably kept to mirror the other summary functions).
    """
    # `...` instead of the old `[:, :3, -1]`: for batched (B, T, 4, 4) poses the old index picked
    # the first three *frames* rather than the translation vectors. detach() so .numpy() works on
    # tensors that require grad.
    positions = model_output["poses"][..., :3, -1].detach().reshape(-1, 3).cpu().numpy()
    point_scene = wandb.Object3D({
        "type": "lidar/beta",
        "points":  positions,
    })
    wandb.log({"camera positions": point_scene})


    
