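"""FlowMap visualizer

Visualize FlowMap outputs (camera poses, 3D points and per-frame images) stored in a `.pt`
file, using viser. This script was adapted from viser's COLMAP sparse-reconstruction
visualizer; the demo data fetched by `./assets/download_colmap_garden.sh` belongs to that
COLMAP example and is not read by this script.
"""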

import random
import time
from pathlib import Path
from typing import List
import sklearn
from sklearn.ensemble import IsolationForest

import imageio.v3 as iio
import numpy as np
import tyro
from tqdm.auto import tqdm

import viser
import viser.transforms as tf
from viser.extras.colmap import (
    read_cameras_binary,
    read_images_binary,
    read_points3d_binary,
)

import torch

def _to_uint8_rgb(rgb: torch.Tensor) -> np.ndarray:
    """Convert a color tensor to a uint8 numpy array.

    Applies ``(x * 0.5 + 0.5) * 255`` (i.e. assumes inputs roughly in [-1, 1]), clamps the
    result to [0, 255] and truncates toward zero.
    """
    return ((rgb * 0.5 + 0.5) * 255).clamp(0, 255).to(torch.uint8).numpy()


def _evenly_spaced_indices(total: int, count: int) -> np.ndarray:
    """Return up to ``count`` sorted, unique indices spread evenly over ``range(total)``.

    The first and last index are always included when ``count >= 2``; ``count`` is clipped
    to ``[0, total]``.
    """
    count = max(0, min(count, total))
    if count == 0:
        return np.zeros(0, dtype=np.int64)
    return np.unique(np.linspace(0, total - 1, num=count).round().astype(np.int64))


def main(
    flowmap_exp_path: str = "/home/cameronsmith/repos/flowmap++/output/pose_exps/poses_static_test.pt",
) -> None:
    """Visualize flowmap output

    Args:
        flowmap_exp_path: Path to the flowmap .pt exp.
    """
    server = viser.ViserServer()
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # NOTE(review): torch.load unpickles the file, so only open experiment files you trust.
    scene = torch.load(flowmap_exp_path, map_location="cpu")

    # Expected layout, inferred from the reshapes below (TODO confirm against the writer):
    #   "world_crds":    (frames, pixels, ...) point positions, flattened over frames*pixels
    #   "rgb":           (frames, pixels, ...) colors, with pixels == H * W
    #   "flow_inp_":     tensor whose last two dims are (H, W)
    #   "poses":         per-frame matrices whose inverse is treated as T_world_camera
    #   "gt_intrinsics": only element [0, 1, 1] (fy) is read
    points = scene["world_crds"].flatten(0, 1).numpy()
    colors = _to_uint8_rgb(scene["rgb"].flatten(0, 1))
    images = _to_uint8_rgb(scene["rgb"].unflatten(1, scene["flow_inp_"].shape[-2:]))
    num_points = points.shape[0]

    gui_reset_up = server.gui.add_button(
        "Reset up direction",
        hint="Set the camera control 'up' direction to the current camera's 'up'.",
    )

    @gui_reset_up.on_click
    def _(event: viser.GuiEvent) -> None:
        client = event.client
        assert client is not None
        # NOTE(review): the hint says "current camera's 'up'", but this uses the camera's -X
        # axis; in OpenCV-style camera frames "up" is usually -Y. Confirm this is intended.
        client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array(
            [-1.0, 0.0, 0.0]
        )

    gui_outlier_filter = server.gui.add_checkbox(
        "Outlier Filtering",
        initial_value=False,
    )
    gui_outlier_slider = server.gui.add_slider(
        "Outlier filter threshold",
        min=0.00001,
        max=0.49,
        step=0.01,
        initial_value=0.1,
    )
    gui_points = server.gui.add_slider(
        "Max points",
        min=1,
        max=num_points,
        step=1,
        initial_value=num_points,
    )
    gui_frames = server.gui.add_slider(
        "Max frames",
        min=1,
        max=len(images),
        step=1,
        initial_value=min(len(images), 100),
    )
    gui_point_size = server.gui.add_slider(
        "Point size", min=0.001, max=0.05, step=0.001, initial_value=0.01
    )

    # Cached IsolationForest result: boolean mask (True = outlier) plus the contamination
    # value it was computed with. It is computed lazily, only while filtering is enabled.
    outlier_mask = None
    outlier_threshold = None
    rng = np.random.default_rng()

    def sample_point_indices() -> np.ndarray:
        """Choose which point indices to display for the current GUI state.

        When outlier filtering is on, points flagged by an IsolationForest (refit only when
        the contamination threshold changes) are excluded. Returns at most
        ``gui_points.value`` distinct indices, sampled without replacement.
        """
        nonlocal outlier_mask, outlier_threshold
        candidates = np.arange(num_points)
        if gui_outlier_filter.value:
            threshold = gui_outlier_slider.value
            if outlier_mask is None or threshold != outlier_threshold:
                forest = IsolationForest(contamination=threshold, random_state=0)
                outlier_mask = forest.fit_predict(points) == -1
                outlier_threshold = threshold
            candidates = candidates[~outlier_mask]
        count = min(int(gui_points.value), candidates.shape[0])
        # Without replacement, so "Max points" == total shows every point exactly once.
        return rng.choice(candidates, size=count, replace=False)

    initial_idxs = sample_point_indices()
    point_cloud = server.scene.add_point_cloud(
        name="/flowmap/pcd",
        points=points[initial_idxs],
        colors=colors[initial_idxs],
        point_size=gui_point_size.value,
    )

    # A single callback for every control that affects the displayed subset. The outlier
    # mask is therefore refreshed before resampling, without relying on callback order.
    @gui_points.on_update
    @gui_outlier_slider.on_update
    @gui_outlier_filter.on_update
    def _(_) -> None:
        idxs = sample_point_indices()
        point_cloud.points = points[idxs]
        point_cloud.colors = colors[idxs]

    @gui_point_size.on_update
    def _(_) -> None:
        point_cloud.point_size = gui_point_size.value

    frames: List[viser.FrameHandle] = []

    # Every frustum uses the same image size and intrinsics, so compute these once.
    H, W = images.shape[1], images.shape[2]
    # NOTE(review): assumes gt_intrinsics[0] is normalized by image size (hence * H) and
    # applies to every frame. Confirm.
    fy = scene["gt_intrinsics"].numpy()[0, 1, 1] * H
    fov = 2 * np.arctan2(H / 2, fy)
    aspect = W / H
    downsample_factor = 1
    # Slice to [:2] to also draw the "gt_poses" entry (smaller red frustums, no axes).
    pose_sources = ["poses", "gt_poses"][:1]

    def visualize_frames() -> None:
        """Replace all frame/frustum nodes with up to ``gui_frames.value`` poses per source."""
        for frame in frames:
            frame.remove()
        frames.clear()

        def attach_callback(
            frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle
        ) -> None:
            # Clicking a frustum moves every connected client's camera to that frame.
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        for src_i, pose_src in enumerate(pose_sources):
            poses = scene[pose_src]
            # An evenly spaced subset keeps "Max frames" covering the whole trajectory.
            frame_ids = _evenly_spaced_indices(len(poses), int(gui_frames.value))
            for frame_i in tqdm(frame_ids.tolist()):
                T_world_camera = tf.SE3.from_matrix(poses[frame_i].inverse().numpy())
                frame = server.scene.add_frame(
                    f"/flowmap/{pose_src}_frame_{frame_i}",
                    wxyz=T_world_camera.rotation().wxyz,
                    position=T_world_camera.translation(),
                    axes_length=0.1,
                    axes_radius=0.005,
                    show_axes=src_i == 0,
                )
                frames.append(frame)

                image = images[frame_i][::downsample_factor, ::downsample_factor]
                frustum = server.scene.add_camera_frustum(
                    f"/flowmap/{pose_src}_frame_{frame_i}/frustum",
                    fov=fov,
                    aspect=aspect,
                    scale=0.025 if src_i else 0.05,
                    image=image,
                    color=[255, 0, 0] if src_i else [0, 0, 0],
                )
                attach_callback(frustum, frame)

    need_update = True

    @gui_frames.on_update
    def _(_) -> None:
        nonlocal need_update
        need_update = True

    # Frame nodes are rebuilt here on the main thread; the frames callback only sets a flag.
    while True:
        if need_update:
            need_update = False
            visualize_frames()

        time.sleep(1e-3)


if __name__ == "__main__":
    # Script entry point: tyro turns main's parameters into command-line flags.
    tyro.cli(main)
