import torch
import viser
import matplotlib.pyplot as plt 
import viser.transforms as tf
import time
from tqdm import tqdm

#sys.path.append(os.getcwd())
#from models import hom
def hom(x):
    """Append a trailing coordinate of ones along the last dimension.

    Turns ``(..., D)`` coordinates into homogeneous ``(..., D + 1)`` ones; the
    added column has the same dtype and device as ``x``.
    """
    ones_column = torch.ones_like(x[..., [0]])
    return torch.cat((x, ones_column), dim=-1)

from sklearn.neighbors import LocalOutlierFactor

# sklearn outlier filter
def filter_outliers(point_cloud, k=20, contamination=0.05):
    """Return a boolean inlier mask for ``point_cloud`` using Local Outlier Factor.

    Args:
        point_cloud: tensor with one row per point (presumably (N, 3) world
            coordinates — TODO confirm against callers).
        k: number of neighbours LOF uses (``n_neighbors``).
        contamination: expected fraction of outliers, passed straight to LOF.

    Returns:
        Bool tensor with one entry per point; True where LOF labels the point an
        inlier (LOF marks outliers with -1).
    """
    lof = LocalOutlierFactor(n_neighbors=k, contamination=contamination)
    # LOF needs a plain CPU ndarray. Detaching first keeps `.numpy()` from raising
    # on tensors that carry autograd history or live on a GPU.
    labels = lof.fit_predict(point_cloud.detach().cpu().numpy())
    return torch.from_numpy(labels != -1)

# Saved pose-estimation results, mapped onto the CPU regardless of the device
# they were saved from. NOTE(review): torch.load unpickles, so only load trusted files.
scene = torch.load(
    "output/pose_exps/poses_bear.pt",
    map_location="cpu",
)

# Viser server hosting both the GUI controls and the 3D scene.
server = viser.ViserServer()

# Point-cloud display controls.
with server.gui.add_folder("Pointcloud"):
    gui_point_size = server.gui.add_slider(
        "Point size", min=0.001, max=0.05, step=0.001, initial_value=0.01
    )
    # TODO: figure out storing the filter mask per timestep.
    gui_do_filter = server.gui.add_checkbox("Filter pointcloud", initial_value=False)
    # Used as LOF's `contamination`; kept below 0.5 because LOF only accepts
    # contamination in (0, 0.5].
    gui_outlier_thresh = server.gui.add_slider(
        "Outlier thresh", min=0.001, max=0.49, step=0.01, initial_value=0.01
    )

# Number of timesteps exposed in the playback UI. Capped at 5 for now (the full
# count would be scene["poses"].size(1)), but never more than the scene actually
# has, so scene["poses"][0, i] stays in range for every selectable timestep.
num_frames = min(5, scene["poses"].size(1))
with server.gui.add_folder("Playback"):
    # Manual controls start disabled because playback starts enabled.
    gui_timestep = server.gui.add_slider(
        "Timestep", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True
    )
    gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
    gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
    gui_playing = server.gui.add_checkbox("Playing", True)
    gui_framerate = server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=1)
    gui_framerate_options = server.gui.add_button_group("FPS options", ("10", "20", "30", "60"))
# Previously shown timestep; stored on the server object so callbacks and the
# playback loop can share it.
server.prev_timestep = gui_timestep.value
@gui_next_frame.on_click
def _(_) -> None:
    """Advance the timestep slider by one, wrapping to 0 after the last frame."""
    following = gui_timestep.value + 1
    gui_timestep.value = following % num_frames
@gui_prev_frame.on_click
def _(_) -> None:
    """Step the timestep slider back by one, wrapping from 0 to the last frame."""
    preceding = gui_timestep.value - 1
    # Python's % is non-negative for a positive modulus, so -1 wraps to num_frames - 1.
    gui_timestep.value = preceding % num_frames
@gui_playing.on_update
def _(_) -> None:
    """Lock the manual timestep controls while playing; unlock them when paused."""
    is_playing = gui_playing.value
    for control in (gui_timestep, gui_next_frame, gui_prev_frame):
        control.disabled = is_playing
@gui_framerate_options.on_click
def _(_) -> None:
    """Copy the clicked FPS preset (a numeric string such as "30") into the FPS slider."""
    preset = gui_framerate_options.value
    gui_framerate.value = int(preset)
# Runs whenever the timestep slider's value is updated.
@gui_timestep.on_update
def _(_) -> None:
    """Log timestep changes; per-timestep visibility toggling is still a TODO."""
    print("gui update")
    current_timestep = gui_timestep.value
    with server.atomic():
        # TODO: toggle visibility, e.g.
        #   frame_nodes[server.prev_timestep].visible = False
        #   frame_nodes[current_timestep].visible = True
        # The previous timestep is stored on `server` (set at startup and by the
        # playback loop); the bare name `prev_timestep` is not defined anywhere.
        print(current_timestep, server.prev_timestep)
    server.flush()  # Optional!

@gui_point_size.on_update
def _(_) -> None:
    """Apply the point-size slider to the currently displayed point cloud, if any."""
    # `server.point_cloud` only exists once a cloud has been registered there.
    # Until then there is nothing to resize; clouds created by the render loop
    # already read gui_point_size.value.
    point_cloud = getattr(server, "point_cloud", None)
    if point_cloud is not None:
        point_cloud.point_size = gui_point_size.value
@gui_do_filter.on_update
@gui_outlier_thresh.on_update
def _(_) -> None:
    """Recompute ``server.mask`` when the filter checkbox or threshold changes.

    With filtering off, every point is kept. With filtering on, the mask comes
    from LOF, using the threshold slider as ``contamination``.
    NOTE(review): the render loop does not apply ``server.mask`` yet.
    """
    if not gui_do_filter.value:
        server.mask = torch.ones_like(server.mask)
        return
    print("doing filtering")
    server.mask = filter_outliers(pts, contamination=gui_outlier_thresh.value)

# Draw every estimated camera: a coordinate frame at its pose plus a small red frustum.
for img_id, camera_pose in enumerate(scene["poses"][0]):
    # The stored matrix is inverted before placement (presumably stored poses are
    # world-to-camera, so the inverse is camera-to-world — TODO confirm).
    T_world_camera = tf.SE3.from_matrix(camera_pose.detach().cpu().inverse().numpy())
    server.scene.add_frame(
        f"/flowmap/frame_{img_id}",
        wxyz=T_world_camera.rotation().wxyz,
        position=T_world_camera.translation(),
        axes_length=0.1,
        axes_radius=0.005,
    )
    frustum = server.scene.add_camera_frustum(
        f"/flowmap/frame_{img_id}/frustum",
        fov=1,
        aspect=1,
        scale=0.05,
        color=[255, 0, 0],
    )

# All world-space points merged into one list (dims 0..2 are flattened together;
# presumably frame x H x W — TODO confirm), plus matching colors.
pts = scene["world_crds"].flatten(0, 2)
# x * .5 + .5 remaps colors from [-1, 1] to [0, 1] (assumes they are stored in
# [-1, 1] — confirm).
rgb_crds = scene["rgb_crds"].flatten(0, 2) * .5 + .5
# Initial outlier mask. It must match the checkbox's initial state: the filter
# callback uses an all-True mask when filtering is off. This also avoids an
# expensive LOF fit at startup when filtering is disabled.
if gui_do_filter.value:
    server.mask = filter_outliers(pts, contamination=gui_outlier_thresh.value)
else:
    server.mask = torch.ones(pts.shape[0], dtype=torch.bool)

# One axes-hidden scene frame per displayed timestep; the render loop attaches
# that timestep's point cloud under "/frames/t{i}". (The per-timestep transformed
# points were previously computed here and then discarded; the render loop
# computes them itself.)
frame_nodes: list[viser.FrameHandle] = [
    server.scene.add_frame(f"/frames/t{i}", show_axes=False)
    for i in tqdm(range(num_frames))
]
point_nodes: list[viser.PointCloudHandle] = []  # not populated yet

#server.point_cloud = server.point_clouds[t]
#server.point_cloud.visible=True
# Simulate a 4D (time-varying) point cloud by transforming the points with each timestep's SE(3) pose; later, normalize relative to the first frame.

# Render loop: rebuild the point cloud only when the displayed timestep changes.
# While "Playing" is checked, advance the timestep at the selected FPS; while
# paused, poll the GUI gently instead of busy-spinning.
shown_timestep = None
while True:
    i = gui_timestep.value
    if i != shown_timestep:
        # Apply the inverse of timestep i's pose to every point (via homogeneous
        # coordinates, dropping the trailing 1 afterwards).
        pts_i = torch.einsum("ij,kj->ki", scene["poses"][0, i].inverse(), hom(pts))[..., :3]
        # Show one timestep at a time: remove the previous cloud rather than
        # letting clouds from every visited timestep pile up in the scene.
        previous_cloud = getattr(server, "point_cloud", None)
        if previous_cloud is not None:
            previous_cloud.remove()
        # Stored on `server` so the point-size callback can resize it live.
        server.point_cloud = server.scene.add_point_cloud(
            name=f"/frames/t{i}/point_cloud",
            points=pts_i.detach().numpy(),
            colors=rgb_crds.detach().numpy(),
            point_size=gui_point_size.value,
            point_shape="rounded",
        )
        shown_timestep = i
    if gui_playing.value:
        server.prev_timestep = i
        gui_timestep.value = (i + 1) % num_frames
        print(gui_timestep.value, server.prev_timestep)
        time.sleep(1.0 / gui_framerate.value)
    else:
        # Paused: short sleep keeps manual slider changes responsive without
        # pegging a CPU core.
        time.sleep(0.05)
