import os,io,shutil
import geometry
import wandb
from matplotlib import cm
import cv2
from tqdm import tqdm
import torchvision
import time
from torchvision.utils import make_grid,draw_keypoints
import torch.nn.functional as F
import kornia
import numpy as np
import torch
import flow_vis
import flow_vis_torch
import matplotlib.pyplot as plt
from einops import rearrange, repeat
import models
import piqa
import imageio
from PIL import Image
#import splines.quaternion
#from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
from scipy import spatial
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import viser.transforms as tf

def ch_fst(src, x=None):
    """Put channels first and unfold the flattened spatial axis.

    Rearranges ``src`` from ``(..., x*y, c)`` to ``(..., c, x, y)``. When ``x``
    is None the spatial grid is assumed square, so the row count is
    ``int(sqrt(src.size(-2)))``.
    """
    rows = int(src.size(-2) ** 0.5) if x is None else x
    return rearrange(src, "... (x y) c -> ... c x y", x=rows)


def ch_sec(x):
    """Inverse of ``ch_fst``: ``(..., c, x, y)`` -> ``(..., x*y, c)``."""
    return rearrange(x, "... c x y -> ... (x y) c")

def viser_update(server, loss_, model_output, model_input, ground_truth, resolution, prefix="", suffix="", step=0, wandb_imgs=None):
    """Refresh the viser GUI (loss plots, camera frames, wandb images) for one step.

    When ``wandb_imgs`` is given and ``server.run.args.save_opt_vis`` is set,
    also writes an optimization-summary PNG (splat render | loss curves |
    optional pose plot) to ``output/img/opt_vis_<flowmap|splatting>/``, plus
    an interpolation video's frames every 200 steps.

    Args:
        server: presumably a viser server carrying a ``run`` attribute; GUI
            handles and state are stored on it (``loss_plots``, ``frames``,
            ``wandb_plots``, ``gui_*``, ...).
        loss_: list of per-step dicts mapping loss name -> value.
        model_output: dict; ``"poses"`` (if present) is drawn as camera frames.
        model_input: dict; ``"rgb"`` sets the timestep range, and the summary
            image reads ``"bwd_flow"`` (size) and ``"intrinsics"``.
        ground_truth, resolution, prefix, suffix: unused here; kept for
            signature compatibility.
        step: current optimization step; step 0 builds the GUI.
        wandb_imgs: optional dict name -> (C, H, W) image tensor, presumably
            in [0, 1] (values are scaled by 255 and clamped for display).
    """
    # Regroup the list of per-step loss dicts into one history list per loss name.
    loss = defaultdict(list)
    for step_losses in loss_:
        for k, v in step_losses.items():
            loss[k].append(v)

    def make_loss_fig(k, loss):
        """Return a plotly line chart of one loss history."""
        fig = px.line(y=loss, x=list(range(len(loss))), labels={"x": "x", "y": k}, title=k)
        fig.layout.title.automargin = True
        fig.update_layout(margin=dict(l=20, r=20, t=20, b=20))
        return fig

    save_opt_vis = wandb_imgs is not None and server.run.args.save_opt_vis

    # Loss figures are rebuilt every 10 steps for the GUI and whenever the
    # summary image below needs them (that path used to hit an unbound
    # `loss_figs` when step % 10 != 0).
    loss_figs = {}
    if step % 10 == 0 or save_opt_vis:
        loss_figs = {k: make_loss_fig(k, v) for k, v in loss.items() if "depth" not in k}

    # Make optimization summary rendering images and videos.
    if save_opt_vis:
        savedir = "output/img/opt_vis_%s/" % ["flowmap", "splatting"][server.run.args.splat_src is not None]
        if step == 0 and os.path.exists(savedir):
            shutil.rmtree(savedir)  # start each run from an empty directory
        os.makedirs(savedir, exist_ok=True)  # also covers a first call with step != 0

        imsize = torch.tensor(model_input["bwd_flow"].shape[-2:]) * 2
        n_plots = len(loss_figs)
        # Make the height divisible by n_plots so the stacked loss plots fill it exactly.
        imsize[0] = imsize[0] // n_plots * n_plots
        # Scale the (presumably normalized) intrinsics to the render's pixel size.
        Ks = model_input["intrinsics"][0, :1, :3, :3].detach().clone()
        Ks[:, 0] *= imsize[1]
        Ks[:, 1] *= imsize[0]

        with torch.no_grad():
            # Fixed viewpoint translated 2 units along +z.
            view_pose = torch.eye(4).cuda()
            view_pose[:3, -1] = torch.tensor([0, 0, 2]).cuda()

            # Every 200 steps (not at step 0) render a video: lie_perpix is
            # linearly resampled along its second axis to 50 samples.
            if step % 200 == 0 and step:
                # TODO: smooth the lie poses path (e.g. with a spline).
                interp_poses = F.interpolate(
                    server.run.splat_vars["lie_perpix"].permute(0, 2, 1), 50, mode="linear"
                ).permute(2, 0, 1)
                for pose_i, alpha_pose in enumerate(tqdm(interp_poses, desc="rendering interpolation video", leave=False)):
                    # Quaternion (first 4 values) + translation (rest) -> 4x4 matrices.
                    interp_pose_perpix = torch.eye(4)[None].expand(len(alpha_pose), -1, -1).cuda()
                    interp_pose_perpix[..., :3, :3] = kornia.geometry.conversions.quaternion_to_rotation_matrix(alpha_pose[..., :4])
                    interp_pose_perpix[..., :3, -1] = alpha_pose[..., 4:]
                    render_img = geometry.do_render(view_pose[None] @ interp_pose_perpix.inverse(), 0, imsize, Ks, server.run.splat_vars)[0][0, ..., :3]
                    plt.imsave(savedir + "render_vid_%04d.png" % pose_i, render_img.cpu().numpy())

            # Render the splat from the fixed viewpoint for the summary image.
            render_img = geometry.do_render(view_pose, 0, imsize, Ks, server.run.splat_vars)[0][0, ..., :3]

        # Loss curves stacked vertically into one half-width panel.
        loss_imgs = torch.cat([
            torch.from_numpy(np.array(Image.open(io.BytesIO(
                fig.to_image(format="png", width=imsize[1] // 2, height=imsize[0] // n_plots, scale=1)
            )))) / 255
            for fig in loss_figs.values()
        ])[..., :3]
        if server.run.args.splat_src is None and "est/pose_est_aligned" in wandb_imgs:
            pose_img = F.interpolate(wandb_imgs["est/pose_est_aligned"][None], (imsize[0], imsize[1] // 2), mode="bilinear")[0].permute(1, 2, 0)
            vis_img = torch.cat((loss_imgs, pose_img), 1)
        else:
            vis_img = loss_imgs
        vis_img = torch.cat((render_img.cpu(), vis_img), 1)
        plt.imsave(savedir + "%d_%d.png" % (int(time.time()), step), vis_img.numpy())

    if step == 0:
        # Build the GUI once.
        server.loss_plots = []
        server.cam_set = False
        for fig in loss_figs.values():
            server.loss_plots.append(server.gui.add_plotly(figure=fig))
        n_frames = len(model_input["rgb"][0])
        with server.gui.add_folder("Playback"):
            server.gui_timestep = server.gui.add_slider("Timestep", min=0, max=n_frames - 1, step=1, initial_value=0, disabled=True,)
            server.gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
            server.gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
            server.gui_playing = server.gui.add_checkbox("Playing", True)
            server.gui_color_choice = server.gui.add_dropdown(label="Color Vis", options=["rgb", "lie_rot", "lie_trans"], initial_value="rgb")
            server.gui_framerate = server.gui.add_slider("FPS", min=.1, max=10, step=0.1, initial_value=.5)
            server.gui_framerate_options = server.gui.add_button_group("FPS options", ("10", "20", "30", "60"))
            server.gui_point_size = server.gui.add_slider("Point size", min=0.001, max=0.05, step=0.001, initial_value=0.01)
            server.points = server.run.splat_vars["means"].detach().cpu().numpy()

        @server.gui_point_size.on_update
        def _(_) -> None:
            # No point cloud is added to the scene at the moment, so guard the
            # handle instead of raising AttributeError on every slider change.
            point_cloud = getattr(server, "point_cloud", None)
            if point_cloud is not None:
                point_cloud.point_size = server.gui_point_size.value

        @server.gui_next_frame.on_click
        def _(_) -> None:
            server.gui_timestep.value = (server.gui_timestep.value + 1) % n_frames

        @server.gui_prev_frame.on_click
        def _(_) -> None:
            server.gui_timestep.value = (server.gui_timestep.value - 1) % n_frames

        @server.gui_playing.on_update
        def _(_) -> None:
            # Manual frame controls are only enabled while playback is paused.
            server.gui_timestep.disabled = server.gui_playing.value
            server.gui_next_frame.disabled = server.gui_playing.value
            server.gui_prev_frame.disabled = server.gui_playing.value

        @server.gui_framerate_options.on_click
        def _(_) -> None:
            server.gui_framerate.value = int(server.gui_framerate_options.value)

        @server.gui_timestep.on_update
        def _(_) -> None:
            # Nudge every client's camera by a zero offset to trigger a rerender.
            with server.atomic():
                for client in server.get_clients().values():
                    client.camera.position = client.camera.position + (0., 0., 0.)
            server.flush()  # Optional!
            # TODO: move the point cloud to the selected timestep by applying the
            # per-point poses in splat_vars["lie_perpix"] to splat_vars["means"].

        server.time_until_next_render = time.time()
        server.frames = []

    elif step % 10 == 0:
        # Refresh loss plots; losses that first appear after step 0 get a new plot
        # (previously an IndexError).
        for i, fig in enumerate(loss_figs.values()):
            if i < len(server.loss_plots):
                server.loss_plots[i].figure = fig
            else:
                server.loss_plots.append(server.gui.add_plotly(figure=fig))
        # Replace the camera frames with the current pose estimates.
        for frame in server.frames:
            frame.remove()
        server.frames = []
        if "poses" in model_output:
            for img_id in range(model_output["poses"].size(1)):
                # The stored pose is inverted to obtain T_world_camera.
                T_world_camera = tf.SE3.from_matrix(model_output["poses"][0, img_id].detach().cpu().inverse().numpy())
                server.frames.append(server.scene.add_frame(f"/colmap/frame_{img_id}", wxyz=T_world_camera.rotation().wxyz, position=T_world_camera.translation(), axes_length=0.1, axes_radius=0.005,))
                server.scene.add_camera_frustum(f"/colmap/frame_{img_id}/frustum", fov=1, aspect=1, scale=0.05, color=[255, 0, 0])

    # Mirror the wandb images into GUI plots.
    if wandb_imgs is not None:
        # (Re)create the handle cache at step 0, or on the first call that carries
        # images if that happens later (previously an AttributeError).
        if step == 0 or not hasattr(server, "wandb_plots"):
            server.wandb_plots = {}
        for k, v in wandb_imgs.items():
            # Clamp before the uint8 cast so out-of-range values do not wrap around.
            img = (v.detach().cpu() * 255).clamp(0, 255).permute(1, 2, 0).to(torch.uint8).numpy()
            fig = px.imshow(img)
            if k in server.wandb_plots:
                server.wandb_plots[k].figure = fig
            else:
                server.wandb_plots[k] = server.gui.add_plotly(figure=fig)
    # TODO: drive playback here (advance gui_timestep at gui_framerate while
    # gui_playing is checked, then trigger a client rerender).


def wandb_summary(loss, model_output, model_input, ground_truth, resolution, prefix="", suffix="", step=0, losses_agg=None):
    """Build visualization images for the flow-map model and log them to wandb.

    Side effects on ``model_output``:
    - For every key containing "depth", adds ``<key>_raw``, which holds the
      depth clipped at 0.3.
    - For those keys not containing "raw", also adds ``<key>vis``, a magma
      colormap of ``min / depth``.

    Logs ``metrics/ATE`` (when ``model_input`` has ``"c2w"``) and every image
    via ``wandb.log``.

    Args:
        loss, resolution, suffix, losses_agg: unused; kept for signature
            compatibility (``suffix`` is not applied to logged keys here).
        model_output: dict of model predictions.
        model_input: dict; reads ``"rgb"`` and, when present, flows, masks and
            ``"c2w"`` ground-truth poses.
        ground_truth: dict; ``"rgb"`` is shown every 50 steps.
        prefix: prepended to every logged image key.
        step: current step; most images are only built when ``step % 50 == 0``.

    Returns:
        dict mapping image name -> (C, H, W) tensor.
    """
    nrow = model_input["rgb"].size(1)
    imsl = model_input["rgb"].shape[-2:]
    # Grid resolution for masks / clusters. It used to be defined only inside
    # the rig_masks branch, which made the dino_clusters branch raise NameError.
    low_res = imsl

    # Convert depths to colormapped 3-channel images.
    # TODO: rename to depth_colored instead of depth to avoid ambiguity.
    cmap = plt.get_cmap("magma")  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    for k, v in list(model_output.items()):
        if "depth" not in k:
            continue
        if type(v) != list and len(v.shape):
            v = v.clip(min=.3)
        model_output[k + "_raw"] = v
        if "raw" not in k:
            # detach() so depth tensors that require grad can be converted to numpy.
            depth = v.detach().cpu().numpy()
            model_output[k + "vis"] = torch.from_numpy(cmap(v.min().item() / depth)).squeeze(-2)[..., :3]

    wandb_out = {}

    if step % 50 == 0:
        wandb_out["ref/rgb_gt"] = make_grid(ground_truth["rgb"].cpu().flatten(0, 1).detach() * .5 + .5, nrow=nrow)

        if "lie_perpix" in model_output:
            # Quaternion part -> axis-angle, translation part (last 3 values), both shifted into ~[0, 1].
            lie = model_output["lie_perpix"]
            rot_vis = kornia.geometry.conversions.quaternion_to_axis_angle(lie[..., :4]).flatten(0, 1).permute(0, 3, 1, 2) * .5 + .5
            trans_vis = lie[..., -3:].flatten(0, 1).permute(0, 3, 1, 2) / 5 + .5
            wandb_out["est/poses_lie_rot_perpix"] = make_grid(rot_vis.detach(), nrow=nrow, normalize=False)
            wandb_out["est/poses_lie_trans_perpix"] = make_grid(trans_vis.detach(), nrow=nrow, normalize=False)

        if "rig_masks" in model_output:
            rig_masks = model_output["rig_masks"]
            n_objects = rig_masks.size(2)
            corr_weights = F.interpolate(model_output["corr_weights"][0], low_res).unsqueeze(2)
            masks_img = ch_fst(rig_masks[0, 1:], low_res[0])
            wandb_out["est/rig_masks"] = make_grid(rearrange(rig_masks, "b t o (x y) 1 -> (b t o) 1 x y", x=low_res[0]).detach(), nrow=n_objects)
            wandb_out["est/rig_masks_corr_weighted"] = make_grid((corr_weights * masks_img).flatten(0, 1).detach(), nrow=n_objects)
            wandb_out["est/rig_masks_corr_weighted_rgb"] = make_grid(
                (F.interpolate(model_input["rgb"][0, 1:], low_res).unsqueeze(1) * masks_img * corr_weights).flatten(0, 1).detach(),
                nrow=n_objects)
            rgb_flat = F.interpolate(model_input["rgb"].flatten(0, 1), low_res).flatten(-2, -1).permute(0, 2, 1).unsqueeze(1) * .5 + .5
            wandb_out["est/rig_masks_rgb"] = make_grid(
                rearrange(rig_masks.flatten(0, 1) * rgb_flat, "bt o (x y) c -> (bt o) c x y", x=low_res[0]).detach(),
                nrow=n_objects)

        if "depth_inpvis" in model_output:
            wandb_out["est/depth_inp"] = make_grid(model_output["depth_inpvis"].cpu().flatten(0, 1).permute(0, 2, 1).unflatten(-1, imsl).detach(), nrow=nrow)
        if "res_depthvis" in model_output:
            wandb_out["est/res_depth"] = make_grid(model_output["res_depthvis"].cpu().flatten(0, 1).permute(0, 2, 1).unflatten(-1, imsl).detach(), nrow=nrow)
        if "depthvis" in model_output:
            wandb_out["est/depth"] = make_grid(model_output["depthvis"].cpu().flatten(0, 1).permute(0, 2, 1).unflatten(-1, imsl).detach(), nrow=nrow)

        if "corr_weights" in model_output:
            wandb_out["est/corr_weights"] = make_grid(model_output["corr_weights"].flatten(0, 1).cpu().detach(), normalize=True, nrow=nrow)
        if "bwd_flow" in model_input:
            wandb_out["ref/flow_gt_bwd"] = flow_vis_torch.flow_to_color(make_grid(model_input["bwd_flow"].flatten(0, 1), nrow=nrow)) / 255
        if "rig_flow_masks" in model_input:
            wandb_out["ref/rig_flow_masks"] = make_grid(rearrange(model_input["rig_flow_masks"], "1 t o x y -> (t o) 1 x y"), nrow=model_input["rig_flow_masks"].size(2))
        if "flow_from_pose" in model_output and not torch.isnan(model_output["flow_from_pose"]).any():
            wandb_out["est/flow_est_pose"] = flow_vis_torch.flow_to_color(make_grid(model_output["flow_from_pose"].clip(-.1, .1).flatten(0, 1).permute(0, 2, 1).unflatten(-1, imsl), nrow=nrow)) / 255

        if "dino_pca" in model_output:
            pass  # TODO: add a DINO PCA visualization.
        if "dino_clusters" in model_output:
            # The guard checks model_output but the original read model_input:
            # keep preferring model_input's copy and fall back instead of a KeyError.
            clusters = model_input["dino_clusters"] if "dino_clusters" in model_input else model_output["dino_clusters"]
            n_objects = model_output["rig_masks"].size(2) if "rig_masks" in model_output else clusters.size(0)
            wandb_out["est/dino_clusters"] = make_grid(rearrange(clusters, "o t (x y) 1 -> (t o) 1 x y", x=low_res[0]).detach(), nrow=n_objects)

    if "c2w" in model_input:  # plot estimated poses against GT
        pose_suffix = "_aligned"
        pose_imgs = []
        poses = model_output["poses"].unsqueeze(1)
        i = 0
        for j in range(len(poses[0])):
            our_pos = poses[i, j, :, :3, -1].detach().cpu()
            # GT positions relative to the first camera, aligned onto our estimate.
            rel_c2w = model_input["c2w"][:, [0]].inverse() @ model_input["c2w"]
            gt_pos = geometry.numpy_procrustes(our_pos, rel_c2w[:, :, :3, -1].cpu()[i])[1]

            # scipy's procrustes standardizes both point sets, so this ATE is scale-free.
            pos_gt_, pos_est_ = [torch.from_numpy(x) for x in spatial.procrustes(model_input["c2w"][i, :, :3, -1].cpu().numpy(), our_pos.numpy())[:2]]
            ate = (pos_gt_ - pos_est_).square().mean().sqrt().item()
            print("ATE:", ate)
            wandb.log({"metrics/ATE": ate}, step=step)

            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")
            ax.plot(*our_pos.unbind(1), c="red", label="Estimated Trajectory")
            ax.plot(*gt_pos.cpu().unbind(1), c="black", label="GT Trajectory")
            min_, max_ = min(our_pos.min(), gt_pos.min()), max(our_pos.max(), gt_pos.max())
            ax.set_xlim(min_, max_)
            ax.set_ylim(min_, max_)
            ax.set_zlim(min_, max_)
            ax.legend()
            fig.tight_layout()
            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb(), which was removed in Matplotlib 3.10.
            image_from_plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            pose_imgs.append(torch.from_numpy(image_from_plot).permute(2, 0, 1) / 255)
            plt.close(fig)
        if pose_imgs:
            wandb_out["est/pose_est" + pose_suffix] = torch.cat(pose_imgs, 2)

    wandb.log({prefix + k: wandb.Image(v.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy()) for k, v in wandb_out.items()})
    return wandb_out

def wandb_summary_splat(loss, model_output, model_input, ground_truth, resolution, prefix="", suffix="", step=0, losses_agg=None, view_i=0):
    """Build visualization images for one splat render and log them to wandb.

    Side effects on ``model_output``:
    - For every key containing "depth", adds ``<key>_raw``, which holds the
      depth clipped at 0.3.
    - For those keys not containing "raw", also adds ``<key>vis``, a magma
      colormap of ``min / depth``.

    Logs ``metrics/ATE`` (when ``model_input`` has ``"c2w"``) and every image
    via ``wandb.log`` under ``prefix + name + suffix``.

    Args:
        loss, ground_truth, resolution, losses_agg: unused; kept for signature
            compatibility.
        model_output: dict; requires ``"gt_rgb"``, ``"rgb"`` and a ``"depth"``
            entry (for ``"depthvis"``). ``"rgb"`` is permuted (2, 0, 1), so it is
            presumably (H, W, C).
        model_input: dict; reads ``"flow_inp_"`` (with ``"render_flow"``) and
            ``"c2w"`` ground-truth poses when present.
        prefix, suffix: prepended / appended to every logged image key. The
            suffix is no longer overwritten by the pose-plot code.
        step: step used for the ATE log.
        view_i: rendered view index. The GT flow shown is
            ``flow_inp_[0, view_i - 1]``, or zeros for view 0.

    Returns:
        dict mapping image name -> (C, H, W) tensor.
    """
    # Convert depths to colormapped 3-channel images.
    # TODO: rename to depth_colored instead of depth to avoid ambiguity.
    cmap = plt.get_cmap("magma")  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    for k, v in list(model_output.items()):
        if "depth" not in k:
            continue
        if type(v) != list and len(v.shape):
            v = v.clip(min=.3)
        model_output[k + "_raw"] = v
        if "raw" not in k:
            # detach() so depth tensors that require grad can be converted to numpy.
            depth = v.detach().cpu().numpy()
            model_output[k + "vis"] = torch.from_numpy(cmap(v.min().item() / depth)).squeeze(-2)[..., :3]

    wandb_out = {}
    wandb_out["img/rgb_gt"] = model_output["gt_rgb"] * .5 + .5
    wandb_out["img/rgb"] = model_output["rgb"].permute(2, 0, 1)
    wandb_out["img/depth"] = model_output["depthvis"].permute(2, 0, 1)
    nrow = 1

    # Lie-pose visualizations: quaternion part -> axis-angle, translation part
    # (last 3 values); both shifted into ~[0, 1].
    for src_key, tag in (("poses_lie", ""), ("lie_perpix", "_perpix")):
        if src_key in model_output:
            lie = model_output[src_key]
            rot_vis = kornia.geometry.conversions.quaternion_to_axis_angle(lie[..., :4]).flatten(0, 1).permute(0, 3, 1, 2) * .5 + .5
            trans_vis = lie[..., -3:].flatten(0, 1).permute(0, 3, 1, 2) / 5 + .5
            wandb_out["est/poses_lie_rot" + tag] = make_grid(rot_vis.detach(), nrow=nrow, normalize=False)
            wandb_out["est/poses_lie_trans" + tag] = make_grid(trans_vis.detach(), nrow=nrow, normalize=False)
    if "is_static" in model_output:
        wandb_out["est/is_static"] = make_grid(rearrange(model_output["is_static"], "b t 1 x y -> (b t) 1 x y").detach(), nrow=model_output["is_static"].size(1))
    if "render_flow" in model_output:
        wandb_out["est/render_flow"] = flow_vis_torch.flow_to_color(make_grid(model_output["render_flow"].permute(0, 3, 1, 2), nrow=nrow)) / 255
        gt_flow = model_input["flow_inp_"][0, [view_i - 1]] if view_i else torch.zeros_like(model_input["flow_inp_"][0, [0]])
        wandb_out["ref/flow_gt_bwd"] = flow_vis_torch.flow_to_color(make_grid(gt_flow, nrow=nrow)) / 255

    if "c2w" in model_input:  # plot estimated poses against GT
        # Separate local: the original reassigned `suffix` here, which appended
        # "_aligned" to every logged key (and "_aligned_aligned" to this one).
        pose_suffix = "_aligned"
        pose_imgs = []
        poses = model_output["poses"].unsqueeze(1)
        i = 0
        for _ in range(len(poses[0])):
            # Presumably the translation part of the lie poses (after the 4 quaternion values).
            our_pos = model_output["lie_poses"][..., 4:].detach().cpu()
            # GT positions relative to the first camera, aligned onto our estimate.
            rel_c2w = model_input["c2w"][:, [0]].inverse() @ model_input["c2w"]
            gt_pos = geometry.numpy_procrustes(our_pos, rel_c2w[:, :, :3, -1].cpu()[i])[1]

            # scipy's procrustes standardizes both point sets, so this ATE is scale-free.
            pos_gt_, pos_est_ = [torch.from_numpy(x) for x in spatial.procrustes(model_input["c2w"][i, :, :3, -1].cpu().numpy(), our_pos.numpy())[:2]]
            ate = (pos_gt_ - pos_est_).square().mean().sqrt().item()
            print("ATE:", ate)
            wandb.log({"metrics/ATE": ate}, step=step)

            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")
            ax.plot(*our_pos.unbind(1), c="red", label="Estimated Trajectory")
            ax.plot(*gt_pos.cpu().unbind(1), c="black", label="GT Trajectory")
            min_, max_ = min(our_pos.min(), gt_pos.min()), max(our_pos.max(), gt_pos.max())
            ax.set_xlim(min_, max_)
            ax.set_ylim(min_, max_)
            ax.set_zlim(min_, max_)
            ax.legend()
            fig.tight_layout()
            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb(), which was removed in Matplotlib 3.10.
            image_from_plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            pose_imgs.append(torch.from_numpy(image_from_plot).permute(2, 0, 1) / 255)
            plt.close(fig)
        if pose_imgs:
            wandb_out["est/pose_est" + pose_suffix] = torch.cat(pose_imgs, 2)

    wandb.log({prefix + k + suffix: wandb.Image(v.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy()) for k, v in wandb_out.items()})
    return wandb_out
