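"""Scene point-cloud visualizer

Load a saved (``torch.save``) scene dict of poses and tracked points and stream its
per-frame point clouds with viser. Adapted from viser's Record3D example; the Record3D
demo data is not used by this script.
"""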

import time
from pathlib import Path

import numpy as np
import tyro
from tqdm.auto import tqdm

import viser
import viser.extras
import viser.transforms as tf

import torch
import kornia
def hom(x: torch.Tensor) -> torch.Tensor:
    """Append a trailing 1 along the last axis, i.e. lift points to homogeneous coordinates.

    ``(..., D)`` -> ``(..., D + 1)``. The appended ones share ``x``'s dtype and device.
    Unlike indexing ``x[..., [0]]``, this also works when ``D == 0``.
    """
    return torch.cat((x, x.new_ones((*x.shape[:-1], 1))), dim=-1)

def main(
    share: bool = False,
    scene_path: Path = Path("/data/cameron/pose_exps/poses_soapbox_rig_mask_testing.pt"),
    max_frames: int = 500,
    stride: int = 2,
) -> None:
    """Serve per-frame point clouds from a saved scene with viser playback controls.

    For frame ``i`` the points ``scene["worldcrds_pertrack"][0]`` are mapped through the
    inverse of their per-point 4x4 poses ``scene["poses_all"][:, i]``. Runs until
    interrupted: the playback loop at the end never returns.

    Args:
        share: If True, request a public share URL for the viser server.
        scene_path: ``torch.save``'d scene dict to visualize.
        max_frames: Upper bound on the number of frames loaded.
        stride: Keep every ``stride``-th point/track to reduce point count.

    Raises:
        ValueError: If ``stride < 1`` or the scene contains no frames to show.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    server = viser.ViserServer()
    if share:
        server.request_share_url()

    # NOTE(review): torch.load unpickles objects; only load trusted scene files.
    scene = torch.load(scene_path, map_location="cpu")
    # Frames are read from dim 1 of "poses_all", so bound by it as well as "poses"
    # to avoid an IndexError in the loading loop if the two disagree.
    num_frames = min(max_frames, scene["poses"].size(1), scene["poses_all"].size(1))
    if num_frames < 1:
        # The slider range and the modulo arithmetic below both need >= 1 frame.
        raise ValueError(f"No frames to display in {scene_path}")

    # Add playback UI.
    with server.gui.add_folder("Playback"):
        gui_point_size = server.gui.add_slider(
            "Point size", min=0.001, max=0.1, step=1e-3, initial_value=0.025
        )
        gui_timestep = server.gui.add_slider(
            "Timestep", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True
        )
        gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
        gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
        gui_playing = server.gui.add_checkbox("Playing", True)
        gui_framerate = server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=20)
        gui_framerate_options = server.gui.add_button_group("FPS options", ("10", "20", "30", "60"))

    # Frame step buttons.
    @gui_next_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value + 1) % num_frames

    @gui_prev_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value - 1) % num_frames

    # Manual stepping controls are only enabled while playback is paused.
    @gui_playing.on_update
    def _(_) -> None:
        gui_timestep.disabled = gui_playing.value
        gui_next_frame.disabled = gui_playing.value
        gui_prev_frame.disabled = gui_playing.value

    @gui_framerate_options.on_click
    def _(_) -> None:
        gui_framerate.value = int(gui_framerate_options.value)

    # Bound before the timestep callback below so its closure never sees an unbound name.
    frame_nodes: list[viser.FrameHandle] = []
    point_nodes: list[viser.PointCloudHandle] = []
    prev_timestep = gui_timestep.value

    # Toggle frame visibility when the timestep slider changes.
    @gui_timestep.on_update
    def _(_) -> None:
        nonlocal prev_timestep
        current_timestep = gui_timestep.value
        # The slider can be moved while frames are still loading (after unchecking
        # "Playing"); skip until both frames exist. Visibility is reset after loading.
        if max(current_timestep, prev_timestep) >= len(frame_nodes):
            return
        with server.atomic():
            frame_nodes[current_timestep].visible = True
            frame_nodes[prev_timestep].visible = False
        prev_timestep = current_timestep
        server.flush()  # Optional!

    # Point positions and colours do not depend on the frame: compute them once.
    pts_h = hom(scene["worldcrds_pertrack"][0][::stride])
    # presumably colours are stored in [-1, 1]; mapped to [0, 1] here — TODO confirm
    colors = (scene["rgb_pertrack"][0][::stride] * 0.5 + 0.5).numpy()

    # Load in frames.
    for i in tqdm(range(num_frames)):
        pose_perpix = scene["poses_all"][::stride, i]
        # Apply the inverse of each point's own 4x4 pose to its homogeneous position.
        pts_i = torch.einsum("kij,kj->ki", pose_perpix.inverse(), pts_h)[..., :3].cpu()
        # TODO: normalize relative to the first pose when dynamic content is added.

        # Add base frame, then place the point cloud in it.
        frame_nodes.append(server.scene.add_frame(f"/frames/t{i}", show_axes=False))
        point_nodes.append(
            server.scene.add_point_cloud(
                name=f"/frames/t{i}/point_cloud",
                points=pts_i.numpy(),
                colors=colors,
                point_size=gui_point_size.value,
                point_shape="rounded",
            )
        )

    # Hide all but the current frame, and resync the callback's notion of which
    # frame is visible with this reset.
    for i, frame_node in enumerate(frame_nodes):
        frame_node.visible = i == gui_timestep.value
    prev_timestep = gui_timestep.value

    @server.on_client_connect
    def _(client: viser.ClientHandle) -> None:
        with client.atomic():
            client.flush()  # Optional!

    # Playback update loop.
    while True:
        if gui_playing.value:
            gui_timestep.value = (gui_timestep.value + 1) % num_frames
        # Apply the point-size slider to the visible frame and pre-apply it to the next.
        current = gui_timestep.value
        point_nodes[current].point_size = gui_point_size.value
        point_nodes[(current + 1) % num_frames].point_size = gui_point_size.value
        time.sleep(1.0 / gui_framerate.value)

if __name__ == "__main__":
    # tyro derives command-line flags from main()'s signature, then calls it.
    tyro.cli(main)
