import time

import matplotlib.pyplot as plt
import torch
import viser
import viser.transforms as tf
from sklearn.neighbors import LocalOutlierFactor

# sklearn outlier filter
def filter_outliers(point_cloud, k=20, contamination=0.05):
    """Return a boolean inlier mask for ``point_cloud`` using Local Outlier Factor.

    Args:
        point_cloud: torch tensor with one row per point (presumably (N, 3)
            world coordinates — TODO confirm). May be on any device and may
            carry autograd history.
        k: number of neighbours used by LOF (``n_neighbors``).
        contamination: expected outlier fraction, forwarded to
            ``LocalOutlierFactor``; sklearn accepts a float in (0, 0.5] or "auto".

    Returns:
        Bool tensor with one entry per row, True where LOF labels the point
        an inlier.
    """
    # detach/cpu so .numpy() works even when the tensor requires grad or lives on a GPU.
    points = point_cloud.detach().cpu().numpy()
    labels = LocalOutlierFactor(n_neighbors=k, contamination=contamination).fit_predict(points)
    # fit_predict labels outliers -1 and inliers 1.
    return torch.from_numpy(labels != -1)

# Saved tensors read below via the keys "poses", "world_crds" and "rgb_crds".
# Loaded onto the CPU whatever device they were saved from. NOTE: torch.load
# unpickles the file, so only point it at trusted outputs.
POSES_PATH = "output/pose_exps/poses_bear.pt"
scene = torch.load(POSES_PATH, map_location="cpu")

server = viser.ViserServer()

with server.gui.add_folder("Pointcloud"):
    # Point size slider, read by the point-size callback and the render loop.
    gui_point_size = server.gui.add_slider(
        "Point size",
        min=0.001,
        max=0.05,
        step=0.001,
        initial_value=0.01,
    )
    # Enables/disables Local Outlier Factor filtering of the point cloud.
    gui_do_filter = server.gui.add_checkbox(
        "Filter pointcloud",
        initial_value=True,
    )
    # Contamination fraction passed to filter_outliers.
    # NOTE(review): initial_value=0.01 is not on the min + k*step grid
    # (0.001, 0.011, ...); confirm how viser snaps it.
    gui_outlier_thresh = server.gui.add_slider(
        "Outlier thresh",
        min=0.001,
        max=0.49,
        step=0.01,
        initial_value=0.01,
    )

@gui_point_size.on_update
def _(_) -> None:
    """Resize the displayed point cloud when the "Point size" slider moves.

    The render loop adds the cloud with ``point_size=gui_point_size.value*.5``,
    so apply the same 0.5 factor here; otherwise the size jumps whenever the
    cloud is re-added.
    """
    point_cloud = getattr(server, "point_cloud", None)
    # The slider can be moved before the first point cloud has been added.
    if point_cloud is not None:
        point_cloud.point_size = gui_point_size.value * .5
@gui_do_filter.on_update
@gui_outlier_thresh.on_update
def _(_) -> None:
    """Recompute ``server.mask`` when the filter toggle or the threshold changes.

    With filtering disabled, every point is kept (an all-True mask shaped like
    the current one). Otherwise the LOF filter runs on ``pts`` with the slider
    value as the contamination fraction.
    """
    if not gui_do_filter.value:
        server.mask = torch.ones_like(server.mask)
        return
    print("doing filtering")
    server.mask = filter_outliers(pts, contamination=gui_outlier_thresh.value)

# One coordinate frame plus a red camera frustum per pose in scene["poses"][0].
for img_id, pose in enumerate(scene["poses"][0]):
    # The stored matrix is inverted to get T_world_camera (presumably the file
    # stores camera-from-world — TODO confirm).
    T_world_camera = tf.SE3.from_matrix(pose.detach().cpu().inverse().numpy())
    server.scene.add_frame(
        f"/flowmap/frame_{img_id}",
        wxyz=T_world_camera.rotation().wxyz,
        position=T_world_camera.translation(),
        axes_length=0.1,
        axes_radius=0.005,
    )
    # Added under the frame's scene path so it is placed relative to that frame.
    frustum = server.scene.add_camera_frustum(
        f"/flowmap/frame_{img_id}/frustum",
        fov=1,
        aspect=1,
        scale=0.05,
        color=[255, 0, 0],
    )

# Collapse the first three axes into one point axis — presumably
# (frames, H, W, 3) -> (frames*H*W, 3); TODO confirm layout.
# detach() so the later .numpy() conversions don't raise if the saved tensors
# carry autograd history (matching the .detach() applied to the pose matrices).
pts = scene["world_crds"].detach().flatten(0, 2)
# Map colours to [0, 1] — assumes the stored RGB values lie in [-1, 1].
rgb_crds = scene["rgb_crds"].detach().flatten(0, 2) * .5 + .5
# Initial inlier mask; honour the checkbox so the first render matches the GUI state.
server.mask = (
    filter_outliers(pts, contamination=gui_outlier_thresh.value)
    if gui_do_filter.value
    else torch.ones(len(pts), dtype=torch.bool)
)

# Render loop: (re)upload the point cloud only when the filter callback has
# swapped in a new mask. Re-adding it on every spin with no sleep pinned a CPU
# core, flooded connected clients and overwrote point-size changes.
shown_mask = None
while True:
    mask = server.mask  # read once; the GUI callback may replace it at any time
    if mask is not shown_mask:
        shown_mask = mask
        server.point_cloud = server.scene.add_point_cloud(
            name="/flowmap/pcd",
            points=pts[mask].numpy(),
            colors=rgb_crds[mask].numpy(),
            point_size=gui_point_size.value * .5,
        )
        # Fraction of points kept by the filter (vectorised, not a Python-level sum).
        print(mask.float().mean().item())
    time.sleep(0.05)
