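"""Pose-experiment point-cloud visualizer.

Load a saved scene (a ``torch.save``'d dict providing ``poses_all``,
``worldcrds_pertrack``, ``rgb_pertrack`` and ``aff_emb_pertrack``) and play back
one point cloud per frame in a viser server. Points can be colored by RGB, by the
rotation or translation of each track's last-frame pose, by SE(3) pose cluster, or
by a PCA of the affinity embedding.

Adapted from a Record3D visualizer example; that example's demo-data script
(`./assets/download_record3d_dance.sh`) does not apply to this scene format.
"""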

import time
from pathlib import Path
from typing import Optional

import kornia
import numpy as np
import torch
import tyro
from tqdm.auto import tqdm

import viser
import viser.extras
import viser.transforms as tf
def hom(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` in homogeneous coordinates by appending a 1 along the last dim.

    E.g. a (..., 3) tensor of points becomes (..., 4) whose last entry is 1.
    The appended ones share ``x``'s dtype and device (via ``ones_like``).
    """
    return torch.cat((x, torch.ones_like(x[..., [0]])), -1)

# Scene loaded by default. Other captures previously viewed with this script:
#   /data/cameron/pose_exps/poses_30frameexp518_74410_144191.pt
#   /data/cameron/pose_exps/poses_30frameexp1037_30321_22281.pt
#   output/pose_exps/poses_walking_tours_vis.pt
#   output/pose_exps/poses_morerealrobot_testing_overfit.pt
#   /data/cameron/pose_exps/poses_swan_redo.pt
#   /data/cameron/pose_exps/poses_realdepthtesting2.pt
#   /data/cameron/pose_exps/poses_30frameexp1037_30392_22995.pt
#   /data/cameron/pose_exps/poses_30frameexp1037_30381_22807.pt
#   /data/cameron/pose_exps/poses_robotics_redo.pt
#   /home/cameronsmith/repos/point_track_sfm/output/pose_exps/poses_horns_rig_mask_testing.pt
DEFAULT_SCENE_PATH = Path("/data/cameron/pose_exps/poses_metric_video_depthtrain_flamingo2.pt")

# Checkout that provides the project-local `geometry` module.
GEOMETRY_REPO = "/home/cameronsmith/repos/multivid_point_track_sfm"

# 0-255 RGB palette for the "pose_clusters" coloring; cluster j uses entry j % len(CLUSTER_COLORS).
CLUSTER_COLORS = (
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
)

# Options shown in the "Color vis options" button group.
COLOR_MODES = ("rgb", "trans", "pose_clusters", "aff_pca", "rot")


def _normalize01(x: torch.Tensor) -> torch.Tensor:
    """Min-max normalize ``x`` over all of its elements into [0, 1].

    A constant input maps to all zeros instead of NaN (which 0 / 0 would give).
    """
    x = x - x.min()
    return x / x.max().clamp_min(1e-12)


def _pca_colors(emb: torch.Tensor) -> torch.Tensor:
    """Color each row of ``emb`` by its projection onto the top-3 principal axes.

    Args:
        emb: 2D tensor with one embedding row per track.

    Returns:
        (rows, 3) tensor equal to ``0.5 * projection + 0.5``. When ``emb`` has fewer
        than 3 columns, the missing channels are filled with 0.5 so the result is
        always a valid RGB array.
    """
    centered = emb - emb.mean(dim=0, keepdim=True)
    covariance = centered.T @ centered / max(centered.size(0) - 1, 1)
    # For a symmetric PSD covariance, the left singular vectors are its principal axes.
    axes = torch.linalg.svd(covariance).U
    num_components = min(3, centered.size(1))
    proj = (centered @ axes[:, :num_components]) * 0.5 + 0.5
    if num_components < 3:
        proj = torch.nn.functional.pad(proj, (0, 3 - num_components), value=0.5)
    return proj


def _compute_color_modes(scene, pose_labels):
    """Build a per-track color tensor (values nominally in [0, 1]) for each of COLOR_MODES.

    Args:
        scene: Loaded scene dict (see ``main``); only batch index 0 is used.
        pose_labels: Per-track cluster labels from ``geometry.cluster_and_represent``.

    Returns:
        Dict mapping every name in COLOR_MODES to a (tracks, 3) tensor.
    """
    conv = kornia.geometry.conversions
    # Each track's pose at the last stored frame (index -1 of the frame axis).
    last_poses = scene["poses_all"][0, :, -1]
    quats = conv.rotation_matrix_to_quaternion(last_poses[..., :3, :3], eps=1e-5)
    palette = torch.tensor(CLUSTER_COLORS, dtype=torch.float32) / 255
    labels = torch.as_tensor(pose_labels).long() % len(CLUSTER_COLORS)
    return {
        "rgb": scene["rgb_pertrack"][0] * 0.5 + 0.5,
        # Min-max normalization absorbs any affine rescaling, so normalize the raw values.
        "trans": _normalize01(last_poses[..., :3, -1]),
        "rot": _normalize01(conv.quaternion_to_axis_angle(quats)),
        "pose_clusters": palette[labels],
        "aff_pca": _pca_colors(scene["aff_emb_pertrack"][0]),
    }


def _to_uint8(colors: torch.Tensor) -> np.ndarray:
    """Convert float colors in [0, 1] to a uint8 numpy array, clamping out-of-range/NaN values."""
    colors = torch.nan_to_num(colors.detach().float().cpu(), nan=0.0).clamp(0.0, 1.0)
    return (colors * 255).to(torch.uint8).numpy()


def _import_geometry():
    """Import and return the project-local ``geometry`` module from GEOMETRY_REPO."""
    import sys

    if GEOMETRY_REPO not in sys.path:
        sys.path.append(GEOMETRY_REPO)
    import geometry

    return geometry


def _save_pose_cluster_plot(pose_clusters, geometry, out_path, gt_poses_path=None) -> None:
    """Save a 3D plot of each pose cluster's translation trajectory to ``out_path``.

    Cluster j is drawn in CLUSTER_COLORS[j], matching the viewer's "pose_clusters" mode.
    If ``gt_poses_path`` is given, those reference poses (labelled "COLMAP") are aligned
    to cluster 0 with ``geometry.numpy_procrustes`` and overlaid in gray.
    """
    import matplotlib.pyplot as plt

    centers = pose_clusters[:, :, :3, -1].detach().cpu()  # translation part per cluster, per frame
    lo, hi = float(centers.min()), float(centers.max())
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_zlim(lo, hi)
    for j, traj in enumerate(centers):
        color = tuple(c / 255 for c in CLUSTER_COLORS[j % len(CLUSTER_COLORS)])
        ax.plot(*traj.numpy().T, c=color, label="SE3 Cluster " + str(j))
    if gt_poses_path is not None:
        gt_pos = torch.load(gt_poses_path, map_location="cpu")[: centers.size(1), :3, -1]
        gt_pos = geometry.numpy_procrustes(centers[0], gt_pos)[1]
        ax.plot(*gt_pos.cpu().unbind(1), c="gray", label="COLMAP")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.get_zaxis().set_visible(False)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)


def main(
    share: bool = False,
    scene_path: Path = DEFAULT_SCENE_PATH,
    max_frames: int = 500,
    n_clusters: int = 3,
    cluster_plot_path: Optional[Path] = None,
    gt_poses_path: Optional[Path] = None,
) -> None:
    """Serve an interactive playback of a saved pose-experiment scene.

    Args:
        share: Request a public share URL for the viser server.
        scene_path: ``torch.save``'d scene dict with ``poses_all``,
            ``worldcrds_pertrack``, ``rgb_pertrack`` and ``aff_emb_pertrack`` entries.
        max_frames: Upper bound on the number of frames loaded into the viewer.
        n_clusters: Number of SE(3) pose clusters for the "pose_clusters" coloring.
        cluster_plot_path: If set, save a 3D plot of the cluster trajectories here.
        gt_poses_path: Optional reference poses overlaid on that plot.

    Raises:
        ValueError: If the scene has no frames.
    """
    scene = torch.load(scene_path, map_location="cpu")
    # Frame axis is dim 2 of poses_all (indexed as [batch][track, frame] below).
    num_frames = min(max_frames, scene["poses_all"].size(2))
    if num_frames < 1:
        raise ValueError(f"{scene_path} contains no frames to visualize")

    server = viser.ViserServer()
    if share:
        server.request_share_url()

    # Playback UI.
    with server.gui.add_folder("Playback"):
        gui_point_size = server.gui.add_slider("Point size", min=0.001, max=0.1, step=1e-3, initial_value=0.03)
        gui_timestep = server.gui.add_slider(
            "Timestep", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True
        )
        gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
        gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
        gui_playing = server.gui.add_checkbox("Playing", True)
        gui_framerate = server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=20)
        gui_framerate_options = server.gui.add_button_group("FPS options", ("10", "20", "30", "60"))
        color_options = server.gui.add_button_group("Color vis options", COLOR_MODES)

    # Callbacks that only touch GUI handles are safe to register before the scene is built.
    @gui_next_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value + 1) % num_frames

    @gui_prev_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value - 1) % num_frames

    @gui_playing.on_update
    def _(_) -> None:
        # Manual stepping is only enabled while playback is paused.
        gui_timestep.disabled = gui_playing.value
        gui_next_frame.disabled = gui_playing.value
        gui_prev_frame.disabled = gui_playing.value

    @gui_framerate_options.on_click
    def _(_) -> None:
        gui_framerate.value = int(gui_framerate_options.value)

    # Cluster per-track poses and precompute every color mode once, as uint8 arrays.
    geometry = _import_geometry()
    pose_clusters, pose_labels = geometry.cluster_and_represent(
        scene["poses_all"][0], n_clusters=n_clusters, return_labels=True
    )
    if cluster_plot_path is not None:
        _save_pose_cluster_plot(pose_clusters, geometry, cluster_plot_path, gt_poses_path)

    stride = 1  # track subsampling step (1 = every track)
    color_arrays = {
        mode: _to_uint8(colors[::stride])
        for mode, colors in _compute_color_modes(scene, pose_labels).items()
    }

    # Load frames: one frame node holding one point cloud per timestep.
    pts_h = hom(scene["worldcrds_pertrack"][0][::stride])  # homogeneous world coords per track
    track_poses = scene["poses_all"][0][::stride]
    frame_nodes: list[viser.FrameHandle] = []
    point_nodes: list[viser.PointCloudHandle] = []
    initial_colors = color_arrays[color_options.value]
    with torch.no_grad():
        for i in tqdm(range(num_frames)):
            # Apply inv(track pose at frame i) @ (cluster-0 pose at frame i) to each track's point.
            # The original comment calls cluster 0 the largest; that depends on the ordering
            # returned by geometry.cluster_and_represent — TODO confirm.
            transform = track_poses[:, i].inverse() @ pose_clusters[[0], i]
            pts_i = torch.einsum("kij,kj->ki", transform, pts_h)[..., :3].cpu()
            # TODO: normalize relative to the first pose once dynamic content is added.
            frame_nodes.append(server.scene.add_frame(f"/frames/t{i}", show_axes=False))
            point_nodes.append(
                server.scene.add_point_cloud(
                    name=f"/frames/t{i}/point_cloud",
                    points=pts_i.numpy(),
                    colors=initial_colors,
                    point_size=gui_point_size.value,
                    point_shape="rounded",
                )
            )

    # Show only the current frame.
    prev_timestep = gui_timestep.value
    for i, frame_node in enumerate(frame_nodes):
        frame_node.visible = i == prev_timestep

    def _apply_colors(mode: str) -> None:
        """Recolor every point cloud with the precomputed colors for ``mode``."""
        colors = color_arrays[mode]
        for node in point_nodes:
            node.colors = colors

    # Pick up any color mode chosen while frames were loading.
    _apply_colors(color_options.value)

    # State-dependent callbacks are registered only now that the nodes and colors exist.
    @color_options.on_click
    def _(_) -> None:
        _apply_colors(color_options.value)

    @gui_timestep.on_update
    def _(_) -> None:
        nonlocal prev_timestep
        current_timestep = gui_timestep.value
        with server.atomic():
            # Hide first, then show: if current == prev (e.g. a single-frame scene),
            # the frame must end up visible.
            frame_nodes[prev_timestep].visible = False
            frame_nodes[current_timestep].visible = True
        prev_timestep = current_timestep
        server.flush()  # Optional!

    @server.on_client_connect
    def _(client: viser.ClientHandle) -> None:
        # Per-client hook; currently only flushes (no camera initialization).
        with client.atomic():
            client.flush()  # Optional!

    # Playback update loop.
    while True:
        if gui_playing.value:
            gui_timestep.value = (gui_timestep.value + 1) % num_frames
        # Keep the visible frame and the next one in sync with the point-size slider.
        t = gui_timestep.value
        point_nodes[t].point_size = gui_point_size.value
        point_nodes[(t + 1) % num_frames].point_size = gui_point_size.value
        time.sleep(1.0 / gui_framerate.value)

if __name__ == "__main__":
    # Script entry point: tyro derives the command-line flags from main()'s
    # signature (e.g. --share) and invokes it with the parsed values.
    tyro.cli(main)
