import os
import geometry
import wandb
from matplotlib import cm
import cv2
import torchvision
import time
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import kornia
import numpy as np
import torch
import flow_vis
import flow_vis_torch
# `imsave(x, y=0)` is a debugging shortcut: it moves `x` to the CPU, converts it to
# NumPy and writes it with `plt.imsave` to ".../tmp<y>.png".
# NOTE(review): the output directory is a hard-coded, user-specific absolute path;
# the call fails on machines where that directory does not exist.
import matplotlib.pyplot as plt; imsave = lambda x,y=0: plt.imsave("/nobackup/users/camsmith/img/tmp%s.png"%y,x.cpu().numpy()); 
from einops import rearrange, repeat
import models
import piqa
import imageio
#import splines.quaternion
#from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
from scipy import spatial
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import viser.transforms as tf
from sklearn.ensemble import IsolationForest

def ch_fst(src, x=None):
    """Unflatten the pixel axis and move channels in front of it.

    Rearranges ``(..., x*y, c)`` into ``(..., c, x, y)``.

    Args:
        src: tensor whose second-to-last axis is the flattened pixel axis.
        x: size of the first spatial axis. When omitted it is taken as the
            truncated square root of the flattened length, i.e. a square
            grid is assumed.
    """
    if x is None:
        x = int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


def ch_sec(x):
    """Flatten the two spatial axes and move channels last: ``(..., c, x, y)`` -> ``(..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")

# Outlier rejection for point clouds. A per-axis IQR box filter was tried before
# but it only contracted the cloud towards its mean point, so an Isolation Forest
# is used instead. TODO: consider filtering on nearest-neighbour distances.
def filter_points(points, iqr_factor=.3):
    """Return a boolean inlier mask for ``points`` using an Isolation Forest.

    Args:
        points: array-like of shape (n_points, n_features) accepted by
            scikit-learn (presumably a CPU tensor or ndarray of 3D points --
            confirm against callers).
        iqr_factor: expected fraction of outliers. The name is historical; the
            value is passed as the forest's ``contamination`` after being
            clamped to [0.01, 0.49], inside the range scikit-learn accepts.

    Returns:
        NumPy boolean array of length n_points, True for points kept as
        inliers.

    Note:
        No ``random_state`` is set, so the mask may differ between calls.
    """
    contamination = max(.01, min(iqr_factor, .49))
    iso_forest = IsolationForest(contamination=contamination).fit(points)
    # predict() labels outliers as -1 and inliers as +1.
    return iso_forest.predict(points) != -1

# serve todos: add point filtering callback 
def viser_update(server,loss_, model_output, model_input, ground_truth, resolution,prefix="",suffix="",step=0,wandb_imgs=None):
    """Create (at step 0) and refresh the interactive GUI attached to ``server``.

    The call with ``step == 0`` builds the loss plots and the "Playback"
    controls and stores them as attributes on ``server``; later calls only
    update them. The function therefore has to be called with ``step == 0``
    before any other step.

    Args:
        server: GUI server exposing ``gui``, ``get_clients()``, ``atomic()``
            and ``flush()`` (presumably a viser server -- confirm).
        loss_: iterable of dicts mapping a loss name to its value; values are
            grouped per name in iteration order to form the plotted curves.
        model_output: unused.
        model_input: dict; ``len(model_input["rgb"])`` is the number of
            timesteps the playback slider cycles through.
        ground_truth: unused.
        resolution: unused.
        prefix: unused.
        suffix: unused.
        step: training step. Loss plots are refreshed every 10th step and a
            client re-render is forced every 1000th step.
        wandb_imgs: optional dict name -> image tensor. Each image is scaled
            by 255, permuted with (1, 2, 0) and cast to uint8 before display
            (assumes CHW tensors in [0, 1] -- TODO confirm).
    """
    def make_loss_fig(k,loss):
        fig = px.line( y=loss, x=list(range(len(loss))), labels={"x": "x", "y": k}, title=k)
        fig.layout.title.automargin = True
        fig.update_layout( margin=dict(l=20, r=20, t=20, b=20),)
        return fig

    # The aggregated loss history is only consumed when the plots are refreshed,
    # so it is only built on those steps.
    refresh_loss_plots = step%10==0
    if refresh_loss_plots:
        loss=defaultdict(list)
        for entry in loss_:
            for k,v in entry.items():
                loss[k].append(v)
        loss_figs = {k:make_loss_fig(k,v) for k,v in loss.items()}

    if step==0:
        server.loss_plots=[]
        server.cam_set=False
        for fig in loss_figs.values():
            server.loss_plots.append(server.gui.add_plotly(figure=fig))
        with server.gui.add_folder("Playback"):
            server.gui_timestep = server.gui.add_slider( "Timestep", min=0, max=len(model_input["rgb"])- 1, step=1, initial_value=0, disabled=True,)
            server.gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
            server.gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
            server.gui_playing = server.gui.add_checkbox("Playing", True)
            server.gui_color_choice = server.gui.add_dropdown(label="Color Vis", options=["rgb","lie_rot","lie_trans"], initial_value="rgb")
            server.gui_framerate = server.gui.add_slider( "FPS", min=.1, max=10, step=0.1, initial_value=.5)
            server.gui_framerate_options = server.gui.add_button_group( "FPS options", ("10", "20", "30", "60"))

        @server.gui_next_frame.on_click
        def _(_) -> None: server.gui_timestep.value = (server.gui_timestep.value + 1) % len(model_input["rgb"])
        @server.gui_prev_frame.on_click
        def _(_) -> None: server.gui_timestep.value = (server.gui_timestep.value - 1) % len(model_input["rgb"])
        @server.gui_playing.on_update
        def _(_) -> None:
            # Manual stepping controls are only usable while playback is paused.
            server.gui_timestep.disabled = server.gui_playing.value
            server.gui_next_frame.disabled = server.gui_playing.value
            server.gui_prev_frame.disabled = server.gui_playing.value
        @server.gui_framerate_options.on_click
        def _(_) -> None: server.gui_framerate.value = int(server.gui_framerate_options.value)
        @server.gui_timestep.on_update
        def _(_) -> None:
            with server.atomic():
                for client in server.get_clients().values(): client.camera.position = client.camera.position+(0., 0., 0.) # trigger rerender
            server.flush()  # Optional!
        server.time_until_next_render=time.time()
    elif refresh_loss_plots:
        for i,fig in enumerate(loss_figs.values()):
            if i < len(server.loss_plots):
                server.loss_plots[i].figure=fig
            else:
                # A loss term that first appears after step 0 gets its own plot
                # instead of raising IndexError.
                server.loss_plots.append(server.gui.add_plotly(figure=fig))

    # add wandb imgs
    if wandb_imgs is not None:
        # Reset at step 0, and also create the registry lazily so that images
        # first supplied on a later step do not hit a missing attribute.
        if step==0 or not hasattr(server,"wandb_plots"):
            server.wandb_plots={}
        for k,img in wandb_imgs.items():
            img = (img.detach().cpu()*255).permute(1,2,0).to(torch.uint8).numpy()
            fig = px.imshow(img)
            if k not in server.wandb_plots: server.wandb_plots[k]=server.gui.add_plotly(figure=fig)
            else: server.wandb_plots[k].figure=fig

    # Move to timestep i and rerender for vis
    trigger_rerender=False
    if server.time_until_next_render<time.time():
        if server.gui_playing.value: server.gui_timestep.value = (server.gui_timestep.value + 1) % len(model_input["rgb"])
        server.time_until_next_render=time.time()+ 1.0 / server.gui_framerate.value
        trigger_rerender=True
    if step%1000==0 or trigger_rerender:
        for client in server.get_clients().values(): client.camera.position = client.camera.position+(0., 0., 0.) # trigger rerender

def wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix="",suffix="",step=0,losses_agg=None):
    """Assemble image summaries of a training step, log them to wandb and return them.

    Each optional key found in ``model_output`` / ``model_input`` contributes
    one or more image grids. All images are logged with a single
    ``wandb.log`` call under the key ``prefix + name + suffix``.

    Args:
        loss: unused.
        model_output: dict of model predictions. It is mutated: for every key
            containing "depth" a "<key>_raw" entry (the depth clipped to a
            minimum of 0.3) is added, and a "<key>vis" entry (magma-coloured
            image) is added for depth keys that do not contain "raw".
        model_input: dict of inputs; "rgb" is required. ``size(1)`` of "rgb"
            is used as the grid row length and its last two dimensions as the
            image size (presumably (batch, time, channel, H, W) -- confirm).
        ground_truth: dict with an "rgb" entry.
        resolution: unused.
        prefix: string prepended to every logged image key.
        suffix: string appended to every logged image key.
        step: step passed to ``wandb.log`` for the "metrics/ATE" scalar.
        losses_agg: unused in this function.

    Returns:
        Dict mapping image name (without prefix/suffix) to an image tensor.
    """
    nrow=model_input["rgb"].size(1)
    imsl=model_input["rgb"].shape[-2:]

    def fig_to_chw(fig):
        """Rasterise ``fig`` into a float (3, H, W) tensor in [0, 1], then close it."""
        fig.canvas.draw()
        # NOTE(review): tostring_rgb is deprecated in recent Matplotlib releases --
        # confirm against the pinned version before upgrading.
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        img = torch.from_numpy(image_from_plot).permute(2,0,1)/255
        plt.close(fig)
        return img

    # Convert depths to colormapped 3-channel images:
    # magma colormap for depth -- todo change to depth_colored instead of depth to avoid ambiguity
    for k,v in list(model_output.items()):
        if "depth" not in k: continue
        if not isinstance(v,list) and len(v.shape): v=v.clip(min=.3)
        model_output[k+"_raw"] = v
        # min/depth maps the nearest point to 1 and far points towards 0 before colouring.
        if "raw" not in k: model_output[k+"vis"] = torch.from_numpy(cm.get_cmap('magma')(v.min().item()/v.detach().cpu().numpy())).squeeze(-2)[...,:3]

    wandb_out = {}

    wandb_out["ref/rgb_gt"]= make_grid(ground_truth["rgb"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)

    # Per-frame and per-pixel pose parameters share one layout: the first four
    # channels are read as a quaternion, the last three as a translation.
    for lie_key,rot_name,trans_name in (("poses_lie","est/poses_lie_rot","est/poses_lie_trans"),
                                        ("lie_perpix","est/poses_lie_rot_perpix","est/poses_lie_trans_perpix")):
        if lie_key not in model_output: continue
        rot_vis=kornia.geometry.conversions.quaternion_to_axis_angle(model_output[lie_key][...,:4]).flatten(0,1).permute(0,3,1,2)*.5+.5
        trans_vis = model_output[lie_key][...,-3:].flatten(0,1).permute(0,3,1,2)/5+.5
        wandb_out[rot_name]= make_grid(rot_vis.detach(),nrow=nrow,normalize=False)
        wandb_out[trans_name]= make_grid(trans_vis.detach(),nrow=nrow,normalize=False)

    if "rig_masks" in model_output:
        low_res=(64,64)
        n_masks=model_output["rig_masks"].size(2)
        corr_w=F.interpolate(model_output["corr_weights"][0],low_res).unsqueeze(2)
        masks_img=ch_fst(model_output["rig_masks"][0,1:],low_res[0])
        wandb_out["est/rig_masks"]= make_grid(rearrange(model_output["rig_masks"],"b t o (x y) 1 -> (b t o) 1 x y",x=low_res[0]).detach(),nrow=n_masks)
        wandb_out["est/rig_masks_corr_weighted"]= make_grid((corr_w*masks_img).flatten(0,1).detach(),nrow=n_masks)
        wandb_out["est/rig_masks_corr_weighted_rgb"]= make_grid((F.interpolate(model_input["rgb"][0,1:],low_res).unsqueeze(1)*masks_img*corr_w
                                                                    ).flatten(0,1).detach(),nrow=n_masks)
        wandb_out["est/rig_masks_rgb"]= make_grid(rearrange(model_output["rig_masks"].flatten(0,1)*(F.interpolate(model_input["rgb"].flatten(0,1),low_res).flatten(-2,-1).permute(0,2,1).unsqueeze(1)*.5+.5),
                                                "bt o (x y) c -> (bt o) c x y",x=low_res[0]).detach(),nrow=n_masks)
    if "rig_masks_unnorm" in model_output:
        # The row length is taken from "rig_masks", so that key must be present as well.
        wandb_out["est/rig_masks_unnorm"]= make_grid(rearrange(model_output["rig_masks_unnorm"],"b t o (x y) 1 -> (b t o) 1 x y",x=64).detach(),nrow=model_output["rig_masks"].size(2))

    if "level_scores" in model_output:
        wandb_out["est/level_scores"]= make_grid(rearrange(model_output["level_scores"],"b t l x y -> (b t l) 1 x y",).detach(), nrow=3)
    if "composite_weights" in model_output:
        for i,x in enumerate(model_output["composite_weights"]):
            wandb_out["est/composite_weights_%i"%i]= make_grid(rearrange(x,"b t o (x y) -> (b t o) 1 x y",x=model_input["rgb"].size(-2)).detach(), nrow=x.size(2))
    if "mask_spatial_scores" in model_output:
        for i,x in enumerate(model_output["mask_spatial_scores"]):
            wandb_out["est/mask_spatial_scores_%i"%i]= make_grid(rearrange(x,"b t o x y -> (b t o) 1 x y").detach(), nrow=x.size(2))
    if "is_static" in model_output:
        wandb_out["est/is_static"]= make_grid(rearrange(model_output["is_static"],"b t 1 x y -> (b t) 1 x y").detach(), nrow=model_output["is_static"].size(1))

    if "traj_top_comps" in model_output:
        wandb_out["est/top_comps1"]= make_grid(model_output["traj_top_comps"][:,:3].abs().detach(),normalize=False)
        wandb_out["est/top_comps2"]= make_grid(model_output["traj_top_comps"][:,3:].abs().detach(),normalize=False)

    if "affinity_sim" in model_output:
        wandb_out["est/affinity_sim"] = make_grid(rearrange(model_output["affinity_sim"][0,0],"x1 y1 x y -> (x1 y1) 1 x y").detach(), nrow=model_output["affinity_sim"].size(3))
    if "affinity_emb" in model_output:
        # pca vis
        from featup.util import pca
        transformed_features = pca(model_output["affinity_emb"].flatten(0,1)[None])[0][0]
        wandb_out["est/affinity_emb"]= make_grid(transformed_features.detach(), nrow=model_output["affinity_emb"].size(1))

    for k,v in model_input.items():
        if "all_mask" not in k:continue
        wandb_out["ref/%s"%k]= make_grid(rearrange(v,"t m x y -> (t m) 1 x y").detach(),nrow=v.size(1))
    for k,v in model_input.items():
        if "masks_vis" not in k:continue
        wandb_out["ref/%s"%k]= make_grid(rearrange(v/255,"t m 1 x y c -> (t m) c x y").detach(),nrow=v.size(1))

    for vis_key,out_name in (("depth_inpvis","est/depth_inp"),("res_depthvis","est/res_depth"),("depthvis","est/depth")):
        if vis_key in model_output:
            wandb_out[out_name]=make_grid(model_output[vis_key].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)

    if "not_sky" in model_input: wandb_out["ref/not_sky"]=make_grid(model_input["not_sky"].cpu().flatten(0,1).detach().float(),nrow=nrow,normalize=False)
    if "corr_weights" in model_output: wandb_out["est/corr_weights"] = make_grid(model_output["corr_weights"].flatten(0,1).cpu().detach(),normalize=True,nrow=nrow)
    if "bwd_flow" in model_input: wandb_out["ref/flow_gt_bwd"]= flow_vis_torch.flow_to_color(make_grid(model_input["bwd_flow"].flatten(0,1),nrow=nrow))/255

    if "warp_rgbs_" in model_output: wandb_out["est/warp_rgb"]= make_grid(model_output["warp_rgbs_"].flatten(0,2).permute(0,2,1).unflatten(-1,imsl),nrow=nrow)

    if "flow_from_pose" in model_output and not torch.isnan(model_output["flow_from_pose"]).any():
        wandb_out["est/flow_est_pose"] = flow_vis_torch.flow_to_color(make_grid(model_output["flow_from_pose"].clip(-.1,.1).flatten(0,1).permute(0,2,1).unflatten(-1,imsl),nrow=nrow))/255

    if "flow_from_pose_uncomp" in model_output:
        wandb_out["est/flow_est_pose_uncomp"] = flow_vis_torch.flow_to_color(make_grid(rearrange(model_output["flow_from_pose_uncomp"],"m bt c x y -> (bt m) c x y").detach(),
                                                                                nrow=len(model_output["flow_from_pose_uncomp"])))/255
        wandb_out["est/flow_est_pose_uncomp_err"] = make_grid(rearrange((model_input["bwd_flow"]-model_output["flow_from_pose_uncomp"]).abs().sum(dim=2), "m bt x y -> (bt m) 1 x y").detach(),
                nrow=len(model_output["flow_from_pose_uncomp"]))

    if "rig_poses" in model_output:
        # One random colour per entry along dim 1 of "rig_poses".
        rig_colors=torch.rand(model_output["rig_poses"].shape[1],3)
        fig=plt.figure(figsize=(12,6));ax = fig.add_subplot(111, projection='3d')
        for color,our_pos in zip(rig_colors,model_output["rig_poses"][0,:,:,:3,-1].detach().cpu()):
            ax.plot(*our_pos.cpu().T.numpy(),c=color.numpy())
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.view_init(elev=55., azim=-95)
        fig.tight_layout()
        wandb_out["est/rig_poses"] = fig_to_chw(fig)

        # point cloud plotting -- disabled; set the flag to True to render the
        # nine view/colour combinations into "vis3d/..." images.
        plot_point_clouds=False
        if plot_point_clouds:
            view_name=["bev","front","front_nopose"]
            color_name=["frame","rig","rgb"]
            for color_i in range(3):
                for view_i in range(3):
                    pts_all = model_output["pts_canon"][0].cpu()
                    rig_selection = model_output["rig_masks"][0].max(dim=1)[1].cpu()
                    pts = torch.gather(pts_all.permute(1,2,3,0), -1, rig_selection.expand(-1,-1,3).unsqueeze(-1)).squeeze(-1)
                    mask=filter_points(pts.flatten(0,1),.4)

                    fig = plt.figure()
                    ax = fig.add_subplot(111, projection='3d')

                    # Plot camera frustums for each pose
                    cmap = cm.get_cmap('viridis', model_output["rig_poses"].cpu()[0].size(1))
                    for rig_i,rig_poses in enumerate(model_output["rig_poses"].cpu()[0]):
                        for pose_i,pose in enumerate(rig_poses):
                            if view_i>1:continue # the "front_nopose" view omits frustums
                            frustum_points = torch.tensor([[0, 0, 0, 1], [-1, -1, 2, 1], [1, -1, 2, 1], [1, 1, 2, 1], [-1, 1, 2, 1]]).float()
                            # Scale frustum size
                            frustum_points[:, :3] *= .05
                            # Transform frustum points to world space using the pose matrix
                            frustum_world = torch.mm(pose, frustum_points.T).T
                            # Extract the corners of the frustum
                            camera_center = frustum_world[0, :3]
                            bottom_left = frustum_world[1, :3]
                            bottom_right = frustum_world[2, :3]
                            top_right = frustum_world[3, :3]
                            top_left = frustum_world[4, :3]

                            # Plot frustum lines
                            frustum_lines = [[camera_center, bottom_left], [camera_center, bottom_right], [camera_center, top_right], [camera_center, top_left],
                                             [bottom_left, bottom_right], [bottom_right, top_right], [top_right, top_left], [top_left, bottom_left]]

                            color=rig_colors[rig_i]

                            for line in frustum_lines:
                                ax.plot(*zip(*line), color=color.numpy(),zorder=10,alpha=.9)

                    rgb = model_input["rgb"].cpu()[0].flatten(-2,-1).permute(0,2,1)
                    stride=100

                    pt_colors = (torch.from_numpy(cmap(list(range(len(pts)))))[:,None].expand(-1,pts.size(1),-1).flatten(0,1), # frame color
                                 rig_colors[rig_selection.flatten(0,1).squeeze(-1)], # rig color
                                 rgb.flatten(0,1).clip(-1,1)*.5+.5, #rgb color
                                )[color_i]
                    sc = ax.scatter(*pts.flatten(0,1)[mask][::stride].T, c=pt_colors[mask][::stride],marker='o', s=1)
                    ax.view_init(elev=-15 if view_i==0 else -90 if view_i==1 else -90, azim=-90)

                    fig.tight_layout()
                    wandb_out["vis3d/view_%s_color_%s"%(view_name[view_i],color_name[color_i],)] = fig_to_chw(fig)

    if "c2w" in model_input: # plot estimated poses against GT
        poses = model_output["poses"].unsqueeze(1) if "intermed_poses_" not in model_output else model_output["intermed_poses_"]
        # Ground-truth camera positions expressed relative to the first camera of each sequence.
        gt_rel = (model_input["c2w"][:,[0]].inverse() @ model_input["c2w"])[:,:,:3,-1].cpu()
        # The loop variable must not be called `suffix`: that would overwrite the
        # caller-supplied suffix used for the final wandb.log keys.
        for align_suffix in ["_aligned",""]:
            pose_imgs=[]
            for i in range(len(poses)):
                for j in range(len(poses[0])):
                    our_pos=poses[i,j,:,:3,-1].detach().cpu()
                    if align_suffix: gt_pos=geometry.numpy_procrustes(our_pos,gt_rel[i])[1]
                    else: gt_pos=gt_rel[i]

                    # ATE: RMSE between the two trajectories after scipy's Procrustes alignment.
                    pos_gt_,pos_est_ = [torch.from_numpy(x) for x in spatial.procrustes(model_input["c2w"][i,:,:3,-1].cpu().numpy(),our_pos.numpy())[:2]]
                    ate=(pos_gt_-pos_est_).square().mean().sqrt()
                    print("ATE:", ate)
                    wandb.log({"metrics/ATE": ate},step=step)

                    fig=plt.figure();ax = fig.add_subplot(111, projection='3d')
                    ax.plot(*our_pos.cpu().unbind(1),c="red",label="Estimated Trajectory")
                    ax.plot(*gt_pos.cpu().unbind(1),c="black",label="GT Trajectory")
                    # Set the same limits for all axes
                    min_,max_=min(our_pos.min(),gt_pos.min()),max(our_pos.max(),gt_pos.max())
                    ax.set_xlim(min_, max_);ax.set_ylim(min_, max_);ax.set_zlim(min_, max_)
                    plt.legend();plt.tight_layout()
                    pose_imgs.append(fig_to_chw(fig))
                    # Plots are concatenated side by side along the width axis.
                    wandb_out["est/pose_est"+align_suffix] = torch.cat(pose_imgs,2)

        # pretty plot for figures
        our_pos=model_output["poses"][0,:,:3,-1].detach().cpu()
        gt_pos=model_input["c2w"][0,:,:3,-1].detach().cpu()
        gt_pos,our_pos = [torch.from_numpy(x) for x in spatial.procrustes(gt_pos.cpu().numpy(),our_pos.numpy())[:2]]

        fig=plt.figure(figsize=(12,6));ax = fig.add_subplot(111, projection='3d')
        ax.plot(*our_pos.cpu().T.numpy(),c="red",label="Estimated Trajectory")
        ax.plot(*gt_pos.cpu().T.numpy(),c="black",label="GT Trajectory")
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.view_init(elev=55., azim=-95)
        fig.tight_layout()
        wandb_out["est/pose_est_pretty"] = fig_to_chw(fig)

    # Debug aid: set to True to print value ranges and save every image locally.
    dump_locally=False
    if dump_locally:
        for k,v in wandb_out.items(): print(k,v.max(),v.min())
        for k,v in wandb_out.items():
            print(k,v.shape)
            plt.imsave("output/img/%s.png"%k,v.float().permute(1,2,0).detach().cpu().numpy().clip(0,1))
        print("saving locally")

    wandb.log({prefix+k+suffix:wandb.Image(v.permute(1, 2, 0).float().detach().clip(0,1).cpu().numpy()) for k,v in wandb_out.items()})
    return wandb_out

def wandb_summary_splat(loss, model_output, model_input, ground_truth, resolution,prefix="",suffix="",step=0,losses_agg=None,view_i=0):
    """Assemble image summaries for a single rendered view, log them to wandb and return them.

    All images are logged with one ``wandb.log`` call under the key
    ``prefix + name + suffix``.

    Args:
        loss: unused.
        model_output: dict of predictions; "gt_rgb", "rgb" and "depth" are
            required ("depth" is turned into "depthvis" here). It is mutated:
            for every key containing "depth" a "<key>_raw" entry (depth
            clipped to a minimum of 0.3) is added, and a "<key>vis" entry
            (magma-coloured image) for depth keys that do not contain "raw".
            "gt_rgb", "rgb" and "depthvis" are permuted with (2, 0, 1)
            (presumably HWC images -- confirm).
        model_input: dict of inputs; "flow_inp_" is read when "render_flow"
            is present and "c2w" enables the trajectory plots.
        ground_truth: unused.
        resolution: unused.
        prefix: string prepended to every logged image key.
        suffix: string appended to every logged image key.
        step: step passed to ``wandb.log`` for the "metrics/ATE" scalar.
        losses_agg: optional list; a ``{"ATE": value}`` dict is appended for
            every trajectory evaluated. When omitted, a fresh list is used.
        view_i: index of the rendered view; ``view_i - 1`` selects the
            reference flow from ``model_input["flow_inp_"]``.

    Returns:
        Dict mapping image name (without prefix/suffix) to an image tensor.
    """
    # A fresh list per call; a literal [] default would be shared between calls.
    if losses_agg is None: losses_agg=[]

    def fig_to_chw(fig):
        """Rasterise ``fig`` into a float (3, H, W) tensor in [0, 1], then close it."""
        fig.canvas.draw()
        # NOTE(review): tostring_rgb is deprecated in recent Matplotlib releases --
        # confirm against the pinned version before upgrading.
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        img = torch.from_numpy(image_from_plot).permute(2,0,1)/255
        plt.close(fig)
        return img

    # Convert depths to colormapped 3-channel images:
    # magma colormap for depth -- todo change to depth_colored instead of depth to avoid ambiguity
    for k,v in list(model_output.items()):
        if "depth" not in k: continue
        if not isinstance(v,list) and len(v.shape): v=v.clip(min=.3)
        model_output[k+"_raw"] = v
        # min/depth maps the nearest point to 1 and far points towards 0 before colouring.
        if "raw" not in k: model_output[k+"vis"] = torch.from_numpy(cm.get_cmap('magma')(v.min().item()/v.detach().cpu().numpy())).squeeze(-2)[...,:3]

    wandb_out = {}

    wandb_out["img/rgb_gt"]= model_output["gt_rgb"].permute(2,0,1)
    wandb_out["img/rgb"]= model_output["rgb"].permute(2,0,1)
    wandb_out["img/depth"]= model_output["depthvis"].permute(2,0,1)
    nrow=1

    # Per-frame and per-pixel pose parameters share one layout: the first four
    # channels are read as a quaternion, the last three as a translation.
    for lie_key,rot_name,trans_name in (("poses_lie","est/poses_lie_rot","est/poses_lie_trans"),
                                        ("lie_perpix","est/poses_lie_rot_perpix","est/poses_lie_trans_perpix")):
        if lie_key not in model_output: continue
        rot_vis=kornia.geometry.conversions.quaternion_to_axis_angle(model_output[lie_key][...,:4]).flatten(0,1).permute(0,3,1,2)*.5+.5
        trans_vis = model_output[lie_key][...,-3:].flatten(0,1).permute(0,3,1,2)/5+.5
        wandb_out[rot_name]= make_grid(rot_vis.detach(),nrow=nrow,normalize=False)
        wandb_out[trans_name]= make_grid(trans_vis.detach(),nrow=nrow,normalize=False)
    if "is_static" in model_output:
        wandb_out["est/is_static"]= make_grid(rearrange(model_output["is_static"],"b t 1 x y -> (b t) 1 x y").detach(), nrow=model_output["is_static"].size(1))
    if "render_flow" in model_output:
        wandb_out["est/render_flow"]= flow_vis_torch.flow_to_color(make_grid(model_output["render_flow"].permute(0,3,1,2),nrow=nrow))/255
        wandb_out["ref/flow_gt_bwd"]= flow_vis_torch.flow_to_color(make_grid(model_input["flow_inp_"][[view_i-1]],nrow=nrow))/255

    if "c2w" in model_input: # plot estimated poses against GT
        poses = model_output["poses"].unsqueeze(1) if "intermed_poses_" not in model_output else model_output["intermed_poses_"]
        # Ground-truth camera positions expressed relative to the first camera of each sequence.
        gt_rel = (model_input["c2w"][:,[0]].inverse() @ model_input["c2w"])[:,:,:3,-1].cpu()
        # The loop variable must not be called `suffix`: that would overwrite the
        # caller-supplied suffix used for the final wandb.log keys.
        for align_suffix in ["_aligned",""]:
            pose_imgs=[]
            for i in range(len(poses)):
                for j in range(len(poses[0])):
                    our_pos=poses[i,j,:,:3,-1].detach().cpu()
                    if align_suffix: gt_pos=geometry.numpy_procrustes(our_pos,gt_rel[i])[1]
                    else: gt_pos=gt_rel[i]

                    # ATE: RMSE between the two trajectories after scipy's Procrustes alignment.
                    pos_gt_,pos_est_ = [torch.from_numpy(x) for x in spatial.procrustes(model_input["c2w"][i,:,:3,-1].cpu().numpy(),our_pos.numpy())[:2]]
                    ate=(pos_gt_-pos_est_).square().mean().sqrt()
                    print("ATE:", ate)
                    wandb.log({"metrics/ATE": ate},step=step)
                    losses_agg.append({"ATE":ate.detach().item()})

                    fig=plt.figure();ax = fig.add_subplot(111, projection='3d')
                    ax.plot(*our_pos.cpu().unbind(1),c="red",label="Estimated Trajectory")
                    ax.plot(*gt_pos.cpu().unbind(1),c="black",label="GT Trajectory")
                    # Set the same limits for all axes
                    min_,max_=min(our_pos.min(),gt_pos.min()),max(our_pos.max(),gt_pos.max())
                    ax.set_xlim(min_, max_);ax.set_ylim(min_, max_);ax.set_zlim(min_, max_)
                    plt.legend();plt.tight_layout()
                    pose_imgs.append(fig_to_chw(fig))
                    # Plots are concatenated side by side along the width axis.
                    wandb_out["est/pose_est"+align_suffix] = torch.cat(pose_imgs,2)

    # Debug aid: set to True to print value ranges, save every image locally and
    # then stop the run.
    dump_locally=False
    if dump_locally:
        for k,v in wandb_out.items(): print(k,v.max(),v.min())
        for k,v in wandb_out.items():
            print(k,v.shape)
            plt.imsave("output/img/%s.png"%k,v.float().permute(1,2,0).detach().cpu().numpy().clip(0,1))
        print("saving locally")
        raise RuntimeError("debug dump finished; aborting run")

    wandb.log({prefix+k+suffix:wandb.Image(v.permute(1, 2, 0).float().detach().clip(0,1).cpu().numpy()) for k,v in wandb_out.items()})
    return wandb_out

def pose_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
    """Log the estimated camera positions to wandb as a 3D point cloud.

    Args:
        loss: unused.
        model_output: dict with a "poses" tensor; the first three rows of the
            last column of each pose are logged as one 3D point (presumably
            (N, 4, 4) camera-to-world matrices -- confirm).
        model_input: unused.
        ground_truth: unused.
        resolution: unused.
        prefix: unused.
    """
    # detach() first: .numpy() raises on tensors that still require grad.
    positions = model_output["poses"][:,:3,-1].detach().cpu().numpy()
    # Log points and boxes in W&B
    point_scene = wandb.Object3D({
        "type": "lidar/beta",
        "points":  positions,
    })
    wandb.log({"camera positions": point_scene})
