"""Visualize episode with MoGE pointcloud in robot frame using viser."""
import argparse
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'Demos'))

import cv2
import torch
import matplotlib.pyplot as plt
import numpy as np
import mujoco
import viser
from viser.extras import ViserUrdf
import yourdfpy
from scipy.spatial.transform import Rotation as R
import trimesh

from demo_utils import procrustes_alignment
from ExoConfigs import EXOSKELETON_CONFIGS
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG
from exo_utils import (detect_and_set_link_poses, detect_and_position_alignment_board, 
                       combine_xmls)
import fpsample
import utils3d

parser = argparse.ArgumentParser(
    description="Visualize an episode's MoGE pointcloud in the robot frame using viser."
)
parser.add_argument("--episode_num", type=int, default=1)
parser.add_argument(
    "--dataset_dir",
    default="scratch/dataset",
    help="Directory containing episode_<N> subdirectories.",
)
parser.add_argument(
    "--target_points",
    type=int,
    default=4000,
    help="Target number of points after FPS downsampling.",
)
args = parser.parse_args()

# Configuration
episode_num = args.episode_num
dataset_dir = args.dataset_dir
episode_dir = os.path.join(dataset_dir, f"episode_{episode_num}")
target_points = args.target_points  # Target number of points after FPS downsampling

# Fail fast (before opening the viser port) if the episode does not exist.
if not os.path.isdir(episode_dir):
    parser.error(f"episode directory not found: {episode_dir}")

# Create viser server
server = viser.ViserServer()

# Optimized volume bounds: axis-aligned crop box applied to the aligned cloud
# (units presumably metres, matching the robot frame — confirm).
volume_bounds = {
    "x_min": -0.33,
    "x_max": 0.08,
    "y_min": -0.5,
    "y_max": -0.13,
    "z_min": -0.01,
    "z_max": 0.08,
}

print("=" * 60)
print("Loading Episode Data")
print("=" * 60)

# Load saved data
start_img_path = os.path.join(episode_dir, "start.png")
pointmap_path = os.path.join(episode_dir, "pointmap_start.pt")
joint_states_path = os.path.join(episode_dir, "joint_states_start.npy")
camera_pose_path = os.path.join(episode_dir, "robot_camera_pose.npy")

# Load image. cv2.imread reports failure by returning None instead of raising,
# so check explicitly. With the default IMREAD_COLOR flag the result is always
# 8-bit, so no [0, 1] -> [0, 255] rescaling is required.
rgb = cv2.imread(start_img_path)
if rgb is None:
    raise FileNotFoundError(f"Could not read start image: {start_img_path}")
rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)

# Load pointmap ("points" + "mask" tensors). map_location keeps this working on
# CPU-only machines even if the tensors were saved from a GPU.
pointmap_data = torch.load(pointmap_path, map_location="cpu")
points = pointmap_data["points"].numpy()  # (H, W, 3)
mask_orig = pointmap_data["mask"].numpy()
# Apply depth neighbor mask (filter out depth edges)
mask = mask_orig & ~utils3d.np.depth_map_edge(points[:, :, 2], rtol=0.005)

# Load joint states
joint_states = np.load(joint_states_path)

# Load camera pose
camera_pose_world = np.load(camera_pose_path)

# Optional robot and human segmentation masks; their pixels are removed from the cloud.
robot_mask_path = os.path.join(episode_dir, "start_robot_mask.png")
human_mask_path = os.path.join(episode_dir, "human_mask_start.png")

H, W = points.shape[:2]


def _load_exclusion_mask(path, label):
    """Load a mask image as a boolean (H, W) array matching the pointmap.

    Returns None when *path* does not exist. The first channel is used for
    multi-channel images, values stored in [0, 1] are rescaled to [0, 255],
    the image is resized to (W, H) with Lanczos filtering and thresholded at 127.
    """
    if not os.path.exists(path):
        return None
    # Imported here so it is available for every mask, not only the robot one.
    from PIL import Image as PILImage

    img = plt.imread(path)
    if img.ndim == 3:
        img = img[:, :, 0]  # Take first channel if RGB(A)
    if img.max() <= 1.0:
        img = img * 255
    resized = PILImage.fromarray(img.astype(np.uint8)).resize((W, H), PILImage.Resampling.LANCZOS)
    binary = np.array(resized) > 127  # Threshold to boolean
    print(f"  Loaded {label} mask: {binary.sum()} pixels")
    return binary


robot_mask = _load_exclusion_mask(robot_mask_path, "robot")
human_mask = _load_exclusion_mask(human_mask_path, "human")

# Combine masks: exclude robot and human pixels from MoGE mask
for excluded in (robot_mask, human_mask):
    if excluded is not None:
        mask = mask & ~excluded

print(f"Loaded episode {episode_num}")
print(f"  Pointmap shape: {points.shape}")
print(f"  Valid points after masking: {mask.sum()}")

print("\n" + "=" * 60)
print("Setting up robot and detecting ArUco markers")
print("=" * 60)

# MuJoCo model = robot XML merged with the alignment-board XML addition.
robot_config = EXOSKELETON_CONFIGS["so100_adhesive"]
combined_xml = combine_xmls(robot_config.xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())
mj_model = mujoco.MjModel.from_xml_string(combined_xml)
mj_data = mujoco.MjData(mj_model)

# Put the model into the recorded start configuration, then run forward kinematics.
n_ctrl = len(mj_data.ctrl)
mj_data.qpos[:] = joint_states
mj_data.ctrl[:] = joint_states[:n_ctrl]
mujoco.mj_forward(mj_model, mj_data)

# ArUco detection on the robot links; obj_img_pts (per-object 3D/pixel pairs)
# drives the alignment step below.
(
    link_poses,
    camera_pose_world_detected,
    cam_K,
    corners_cache,
    corners_vis,
    obj_img_pts,
) = detect_and_set_link_poses(rgb, mj_model, mj_data, robot_config)

# Board detection is given the camera pose loaded from disk, not the detected one.
board_result = detect_and_position_alignment_board(
    rgb,
    mj_model,
    mj_data,
    ALIGNMENT_BOARD_CONFIG,
    cam_K,
    camera_pose_world,
    corners_cache,
    visualize=False,
)
board_pose, board_pts = board_result
obj_img_pts["alignment_board"] = board_pts
print(f"Alignment board detected ({len(board_pts[1])} points)")

# Only the larger base and the alignment board are used as alignment references.
alignment_keys = ("larger_base", "alignment_board")
obj_img_pts = {key: obj_img_pts[key] for key in alignment_keys}

print("\n" + "=" * 60)
print("Aligning MoGE pointcloud to robot frame")
print("=" * 60)

# Extract valid points (pixels that survived all masking above)
points_flat = points.reshape(-1, 3)
mask_flat = mask.reshape(-1)
valid_pixels = mask_flat > 0.5
valid_points = points_flat[valid_pixels]

# Build 3D-3D correspondences for Procrustes. For every marker corner, its known
# 3D point (mapped through inv(camera_pose_world)) is paired with the MoGE point
# at its pixel. Each pair is kept or dropped as a unit, so row i of both arrays
# always describes the same corner.
# NOTE(review): confirm inv(camera_pose_world) yields robot-frame coordinates
# for the convention used when robot_camera_pose.npy was saved.
world_to_target = np.linalg.inv(camera_pose_world)
img_h, img_w = points.shape[:2]
moge_aruco_corners = []
aruco_corners_robot_frame = []

for obj_img_pts_3d, img_pts in obj_img_pts.values():
    if len(obj_img_pts_3d) != len(img_pts):
        raise ValueError(
            f"Mismatched ArUco correspondences: {len(obj_img_pts_3d)} 3D points "
            f"vs {len(img_pts)} pixels"
        )
    homog = np.hstack([obj_img_pts_3d, np.ones((obj_img_pts_3d.shape[0], 1))])
    aruco_3d = (world_to_target @ homog.T).T[:, :3]

    # Sample from MOGE pointmap - direct pixel coordinates
    for corner_3d, pt in zip(aruco_3d, img_pts):
        x, y = int(pt[0]), int(pt[1])
        if not (0 <= y < img_h and 0 <= x < img_w):
            continue  # pixel outside the pointmap: drop the whole pair
        moge_pt = points[y, x]
        if not np.all(np.isfinite(moge_pt)):
            continue  # no usable MoGE point at this pixel: drop the whole pair
        aruco_corners_robot_frame.append(corner_3d)
        moge_aruco_corners.append(moge_pt)

aruco_corners_robot_frame = np.array(aruco_corners_robot_frame)
moge_aruco_corners = np.array(moge_aruco_corners)

print(f"ArUco points for alignment: {len(aruco_corners_robot_frame)} robot frame, {len(moge_aruco_corners)} MOGE")
if len(moge_aruco_corners) < 3:
    raise RuntimeError(
        f"Procrustes alignment needs at least 3 ArUco correspondences, got {len(moge_aruco_corners)}"
    )

# Procrustes alignment: align MOGE pointcloud to robot frame
# (T_procrustes combines the returned scale, rotation and translation).
T_procrustes, scale, rotation, translation = procrustes_alignment(aruco_corners_robot_frame, moge_aruco_corners)
points_homogeneous = np.hstack([valid_points, np.ones((len(valid_points), 1))])
moge_aligned = (T_procrustes @ points_homogeneous.T).T[:, :3]

print(f"Aligned pointcloud: {len(moge_aligned)} points")

# Colors for visualization, selected with the same pixel mask as valid_points
# (assumes rgb has the pointmap's resolution).
colors = rgb.reshape(-1, 3)
valid_colors = colors[valid_pixels]

print("\n" + "=" * 60)
print("Downsampling uncropped pointcloud")
print("=" * 60)

# Two-stage thinning of the full aligned cloud for display: a cheap fixed
# stride first, then farthest-point sampling down to at most target_points.
uniform_ds_factor = 21
stride = uniform_ds_factor if len(moge_aligned) > uniform_ds_factor else 1
moge_aligned_uncropped_uniform = moge_aligned[::stride]
valid_colors_uncropped_uniform = valid_colors[::stride]

num_points_uncropped = min(target_points, len(moge_aligned_uncropped_uniform))
needs_fps = len(moge_aligned_uncropped_uniform) > num_points_uncropped
if not needs_fps:
    moge_uncropped_points = moge_aligned_uncropped_uniform
    moge_uncropped_colors = valid_colors_uncropped_uniform
else:
    indices_uncropped = fpsample.fps_npdu_kdtree_sampling(
        moge_aligned_uncropped_uniform, num_points_uncropped
    )
    moge_uncropped_points = moge_aligned_uncropped_uniform[indices_uncropped]
    moge_uncropped_colors = valid_colors_uncropped_uniform[indices_uncropped]

print(f"Uncropped pointcloud downsampled: {len(moge_aligned)} -> {len(moge_uncropped_points)} points")

print("\n" + "=" * 60)
print("Cropping pointcloud with volume bounds (before downsampling)")
print("=" * 60)

# Crop the full-resolution aligned cloud to the volume bounds first
# (inclusive on both ends of every axis).
bounds_lo = np.array([volume_bounds["x_min"], volume_bounds["y_min"], volume_bounds["z_min"]])
bounds_hi = np.array([volume_bounds["x_max"], volume_bounds["y_max"], volume_bounds["z_max"]])
mask_filtered = np.all((moge_aligned >= bounds_lo) & (moge_aligned <= bounds_hi), axis=1)
moge_aligned_cropped = moge_aligned[mask_filtered]
valid_colors_cropped = valid_colors[mask_filtered]

print(f"Cropped pointcloud: {len(moge_aligned)} -> {len(moge_aligned_cropped)} points")
print(f"Volume bounds: x=[{volume_bounds['x_min']:.2f}, {volume_bounds['x_max']:.2f}], "
      f"y=[{volume_bounds['y_min']:.2f}, {volume_bounds['y_max']:.2f}], "
      f"z=[{volume_bounds['z_min']:.2f}, {volume_bounds['z_max']:.2f}]")

print("\n" + "=" * 60)
print("Downsampling with FPS")
print("=" * 60)

# The cropped cloud goes straight to FPS; uniform pre-thinning is applied only
# to the uncropped display cloud.
num_points = min(target_points, len(moge_aligned_cropped))
if len(moge_aligned_cropped) > num_points:
    # FPS sampling returns indices
    indices = fpsample.fps_npdu_kdtree_sampling(moge_aligned_cropped, num_points)
    filtered_points = moge_aligned_cropped[indices]
    filtered_colors = valid_colors_cropped[indices]
    print(f"FPS downsampling: {len(moge_aligned_cropped)} -> {len(filtered_points)} points")
else:
    filtered_points = moge_aligned_cropped
    filtered_colors = valid_colors_cropped
    print(f"Using all {len(filtered_points)} points (no FPS downsampling needed)")

print(f"Total downsampling: {len(moge_aligned_cropped)} -> {len(filtered_points)} points")

# Add uncropped pointcloud, using the downsampled copy prepared for viser rather
# than the full-resolution cloud.
server.scene.add_point_cloud(
    name="/moge_uncropped",
    points=moge_uncropped_points,
    colors=moge_uncropped_colors,
    point_size=0.001,  # Very small to make it less visible
)

# Add cropped pointcloud
server.scene.add_point_cloud(
    name="/moge_cropped",
    points=filtered_points,
    colors=filtered_colors,
    point_size=0.001,
)

print("\n" + "=" * 60)
print("Loading grasp joint states")
print("=" * 60)

import re


def _natural_sort_key(filename):
    """Split *filename* into text and integer runs so "grasp2" sorts before "grasp10"."""
    return [int(tok) if tok.isdecimal() else tok for tok in re.split(r"(\d+)", filename)]


# Load all joint states (start + grasps)
joint_states_list = [("start", joint_states)]  # Start state

# Grasp files look like joint_states_grasp<N>.npy. Sort numerically so grasp10
# comes after grasp9; a plain lexicographic sort would put it after grasp1.
grasp_prefix, grasp_suffix = "joint_states_", ".npy"
grasp_files = sorted(
    (f for f in os.listdir(episode_dir)
     if f.startswith("joint_states_grasp") and f.endswith(grasp_suffix)),
    key=_natural_sort_key,
)
for grasp_file in grasp_files:
    grasp_states = np.load(os.path.join(episode_dir, grasp_file))
    # Strip prefix/suffix, e.g. "joint_states_grasp1.npy" -> "grasp1"
    grasp_name = grasp_file[len(grasp_prefix):-len(grasp_suffix)]
    joint_states_list.append((grasp_name, grasp_states))
    print(f"  Loaded {grasp_name}")

print(f"Total robot states: {len(joint_states_list)}")
for name, _ in joint_states_list:
    print(f"  - {name}")

print("\n" + "=" * 60)
print("Loading fixed_gripper exo_link poses")
print("=" * 60)

# One saved pose matrix per robot state (fixed_gripper_exo_pose_<state>.npy);
# the translation is read from column 3. Missing files are reported and skipped.
gripper_poses_list = []
for state_name, _ in joint_states_list:
    pose_file = os.path.join(episode_dir, f"fixed_gripper_exo_pose_{state_name}.npy")
    if not os.path.exists(pose_file):
        print(f"  Warning: {pose_file} not found")
        continue
    pose = np.load(pose_file)
    gripper_poses_list.append((state_name, pose))
    px, py, pz = pose[0, 3], pose[1, 3], pose[2, 3]
    print(f"  Loaded {state_name}: position = [{px:.3f}, {py:.3f}, {pz:.3f}]")

print("\n" + "=" * 60)
print("Loading fixed_gripper STL mesh")
print("=" * 60)

# Fixed-jaw STL. NOTE: this is a relative path, resolved against the current
# working directory.
fixed_gripper_stl_path = "robot_models/so100_model/assets/Fixed_Jaw.stl"

loaded_geometry = trimesh.load(fixed_gripper_stl_path)
# trimesh may return a Scene (several meshes); in that case use its first geometry.
if isinstance(loaded_geometry, trimesh.Scene):
    fixed_gripper_mesh = list(loaded_geometry.geometry.values())[0]
else:
    fixed_gripper_mesh = loaded_geometry

# Unit heuristic: an axis extent above 1.0 is taken to mean the STL is in
# millimetres, so scale it down to metres.
mesh_bounds = fixed_gripper_mesh.bounds
max_extent = np.max(mesh_bounds[1] - mesh_bounds[0])
if max_extent > 1.0:
    print(f"  Converting mesh from mm to meters (scale: {max_extent:.1f}mm)")
    fixed_gripper_mesh.apply_scale(0.001)

for summary_line in (
    f"  Loaded fixed_gripper mesh from: {fixed_gripper_stl_path}",
    f"  Mesh vertices: {len(fixed_gripper_mesh.vertices)}",
    f"  Mesh faces: {len(fixed_gripper_mesh.faces)}",
    f"  Mesh bounds: {fixed_gripper_mesh.bounds}",
    f"  Mesh center: {fixed_gripper_mesh.center_mass}",
):
    print(summary_line)

print("\n" + "=" * 60)
print("Setting up Viser visualization")
print("=" * 60)


def _add_gripper_mesh(mesh_name, pose):
    """Add ``fixed_gripper_mesh`` to the viser scene at the homogeneous *pose*.

    Tries ``add_mesh_trimesh``, then ``add_mesh``, then ``add_mesh_simple``,
    because which one exists (and its signature) depends on the viser version.
    Each failure (AttributeError/TypeError) is printed. Returns True as soon
    as one call succeeds, False if none does.
    """
    pos = pose[:3, 3]
    # scipy's as_quat() is scalar-last (x, y, z, w); viser expects (w, x, y, z).
    wxyz = R.from_matrix(pose[:3, :3]).as_quat()[[3, 0, 1, 2]]

    # Method 1: Try add_mesh_trimesh (newer API); the pose is passed separately.
    try:
        server.scene.add_mesh_trimesh(
            name=mesh_name,
            mesh=fixed_gripper_mesh,
            wxyz=wxyz,
            position=pos,
        )
        print(f"  Added gripper mesh (trimesh): {mesh_name}")
        return True
    except (AttributeError, TypeError) as e:
        print(f"  add_mesh_trimesh failed: {e}")

    # The fallbacks are called without a pose, so bake it into the vertices.
    vertices_h = np.hstack([fixed_gripper_mesh.vertices, np.ones((fixed_gripper_mesh.vertices.shape[0], 1))])
    world_vertices = (pose @ vertices_h.T).T[:, :3].astype(np.float32)
    faces = fixed_gripper_mesh.faces.astype(np.int32)

    # Method 2: Try add_mesh with transformed vertices
    try:
        server.scene.add_mesh(
            name=mesh_name,
            vertices=world_vertices,
            faces=faces,
            color=(100, 150, 200, 255),  # RGBA color
        )
        print(f"  Added gripper mesh (add_mesh): {mesh_name}")
        return True
    except (AttributeError, TypeError) as e:
        print(f"  add_mesh failed: {e}")

    # Method 3: Try add_mesh_simple
    try:
        server.scene.add_mesh_simple(
            name=mesh_name,
            vertices=world_vertices,
            faces=faces,
        )
        print(f"  Added gripper mesh (simple): {mesh_name}")
        return True
    except (AttributeError, TypeError) as e:
        print(f"  add_mesh_simple failed: {e}")

    return False


# Add one fixed_gripper STL mesh per recorded gripper pose
for state_name, gripper_pose in gripper_poses_list:
    mesh_name = f"/fixed_gripper_mesh_{state_name}"
    if not _add_gripper_mesh(mesh_name, gripper_pose):
        print(f"  WARNING: Could not add gripper mesh {mesh_name} - no valid method found")
        # Debug: print available methods
        print(f"  Available scene methods: {[m for m in dir(server.scene) if 'mesh' in m.lower()]}")


# SO-100 arm URDF shown in viser (visual meshes only; collision meshes disabled).
urdf_path = "/Users/cameronsmith/Projects/robotics_testing/calibration_testing/so_100_arm/urdf/so_100_arm.urdf"
urdf = yourdfpy.URDF.load(urdf_path)
urdf_viewer_options = dict(
    load_meshes=True,
    load_collision_meshes=False,
    # Presumably only relevant when collision meshes are loaded — confirm.
    collision_mesh_color_override=(1.0, 0.0, 0.0, 0.5),
)
viser_urdf = ViserUrdf(server, urdf_or_path=urdf, **urdf_viewer_options)

# Per-joint offsets subtracted from the MuJoCo joint values before they are
# applied to the URDF in update_robot_state (entries are 0 or ±1.57 ≈ ±π/2).
mujoco_so100_offset = np.array([0.0, -1.57, 1.57, 1.57, -1.57, 0.0])

def update_robot_state(state_idx):
    """Pose the viser URDF at entry *state_idx* of ``joint_states_list``.

    Indices outside ``[0, len(joint_states_list))`` are silently ignored.
    The stored joint values have ``mujoco_so100_offset`` subtracted before
    being passed to ``viser_urdf.update_cfg``.
    """
    if not 0 <= state_idx < len(joint_states_list):
        return
    label, qpos = joint_states_list[state_idx]
    urdf_cfg = np.array(qpos - mujoco_so100_offset)
    viser_urdf.update_cfg(urdf_cfg)
    print(f"Updated robot to: {label}")

# Set initial state (start)
update_robot_state(0)
print("Loaded robot URDF into viser")

# Add slider to select robot state (only useful with more than one state)
if len(joint_states_list) > 1:
    slider = server.gui.add_slider(
        "robot_state",
        0,
        len(joint_states_list) - 1,
        initial_value=0,
        step=1,
    )
    slider.on_update(lambda _: update_robot_state(int(slider.value)))
    print(f"Added slider to select robot state (0-{len(joint_states_list)-1})")

print("\n" + "=" * 60)
print("Viser server running!")
print(f"Episode {episode_num} - MoGE pointcloud (cropped)")
print("Use slider to change robot state")
print("Press Ctrl+C to exit")
print("=" * 60)

# Keep the main thread (and with it the viser server) alive until Ctrl+C.
# A sleep loop works even when stdin is not interactive, and Ctrl+C exits as
# the banner above promises. A debugger breakpoint does neither.
import time

try:
    while True:
        time.sleep(1.0)
except KeyboardInterrupt:
    print("Exiting")

