"""Visualize MoGe pointmaps in robot frame with a slider to switch between episodes."""
import argparse
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import viser
import cv2
import matplotlib.pyplot as plt
import trimesh
from scipy.spatial.transform import Rotation as R
import torch.nn.functional as F

# Keypoints in gripper local frame (mm, converted to meters)
KEYPOINTS_LOCAL_MM = np.array([
    [13.25, -91.42, 15.9],
    [10.77, -99.6, 0],
    [13.25, -91.42, -15.9],
    [17.96, -83.96, 0],
    [22.86, -70.46, 0]
])
KEYPOINTS_LOCAL_M = KEYPOINTS_LOCAL_MM / 1000.0

# ---- Command-line interface ------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_episodes",
    type=int,
    default=30,
    help="Maximum number of episodes to load",
)
parser.add_argument(
    "--processed_dir",
    type=str,
    default="scratch/processed_grasp_dataset_keyboard",
    help="Directory with processed episodes",
)
args = parser.parse_args()

# Pull the parsed values into short module-level names used below.
processed_dir = args.processed_dir
max_episodes = args.max_episodes

print("=" * 60)
print(f"Finding episodes in {processed_dir}")
print("=" * 60)

# Every sub-directory of processed_dir is treated as one episode; sort for a
# deterministic order, then keep at most max_episodes of them.
all_sequences = sorted(
    entry
    for entry in os.listdir(processed_dir)
    if os.path.isdir(os.path.join(processed_dir, entry))
)
sequences = all_sequences[:max_episodes]

print(f"Found {len(all_sequences)} total sequences")
print(f"Loading first {len(sequences)} sequences:")
for i, seq_id in enumerate(sequences):
    print(f"  {i}: {seq_id}")

print("\n" + "=" * 60)
print("Loading episode data")
print("=" * 60)

# Load data for all episodes
episodes_data = []


def _load_mask(path):
    """Load a mask image as a 2-D uint8 array in [0, 255].

    plt.imread may return float [0, 1] data (PNG) or uint8, possibly with
    extra channels; this normalizes both cases (first channel, scaled to
    uint8 when the values look normalized).
    """
    mask = plt.imread(path)
    if len(mask.shape) == 3:
        mask = mask[:, :, 0]
    if mask.max() <= 1.0:
        mask = (mask * 255).astype(np.uint8)
    return mask


for seq_id in sequences:
    sequence_dir = os.path.join(processed_dir, seq_id)

    # Load start image (stored BGR on disk; convert to RGB for display).
    start_img_path = os.path.join(sequence_dir, "start.png")
    start_image = cv2.imread(start_img_path)
    if start_image is None:
        # cv2.imread returns None on a missing/unreadable file; fail loudly
        # here instead of crashing with a cryptic error inside cvtColor.
        raise FileNotFoundError(f"Could not read start image: {start_img_path}")
    start_image = cv2.cvtColor(start_image, cv2.COLOR_BGR2RGB)
    if start_image.max() <= 1.0:
        start_image = (start_image * 255).astype(np.uint8)

    # Load robot-aligned pointmap: already (N, 3) with filtered points.
    pointmap_path = os.path.join(sequence_dir, "pointmap_start.pt")
    pointmap = torch.load(pointmap_path)
    points = pointmap["points"].cpu().numpy()  # (N, 3) in robot frame
    colors = pointmap["colors"].cpu().numpy()  # (N, 3) RGB colors

    # Ensure colors are uint8 [0-255].
    if colors.dtype != np.uint8:
        if colors.max() <= 1.0:
            colors = (colors * 255).astype(np.uint8)
        else:
            colors = colors.astype(np.uint8)

    # Load DINO features and map the first 3 PCA components to RGB.
    dino_features_path = os.path.join(sequence_dir, "dino_features.pt")
    dino_features = torch.load(dino_features_path).numpy()  # (N, 32)
    dino_pca_rgb = dino_features[:, :3]  # (N, 3)
    # Squash to (0, 1) with a gain-2 sigmoid so outliers don't saturate.
    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
    dino_pca_rgb_normalized = torch.sigmoid(torch.from_numpy(dino_pca_rgb).mul(2.0)).numpy()
    dino_colors = (dino_pca_rgb_normalized * 255).astype(np.uint8)

    # Load segmentation / edge masks (all normalized by _load_mask).
    robot_mask = _load_mask(os.path.join(sequence_dir, "robot_mask.png"))
    human_mask = _load_mask(os.path.join(sequence_dir, "human_mask.png"))
    moge_edge_mask = _load_mask(os.path.join(sequence_dir, "moge_edge_mask.png"))

    # Load gripper pose (4x4 homogeneous transform in the robot frame).
    gripper_pose_path = os.path.join(sequence_dir, "gripper_pose_grasp.npy")
    gripper_pose = np.load(gripper_pose_path)

    # Transform keypoints from gripper local frame to robot frame:
    # p_robot = R @ p_local + t
    gripper_rot = gripper_pose[:3, :3]
    gripper_pos = gripper_pose[:3, 3]
    keypoints_robot = (gripper_rot @ KEYPOINTS_LOCAL_M.T).T + gripper_pos.reshape(1, 3)

    episodes_data.append({
        'sequence_id': seq_id,
        'points': points,
        'colors': colors,
        'dino_colors': dino_colors,
        'start_image': start_image,
        'robot_mask': robot_mask,
        'human_mask': human_mask,
        'moge_edge_mask': moge_edge_mask,
        'gripper_pose': gripper_pose,
        'keypoints_robot': keypoints_robot,
    })

    print(f"  ✓ Loaded {seq_id}: {len(points)} points (robot frame)")

print(f"\nSuccessfully loaded {len(episodes_data)} episodes")

def _load_mesh_meters(stl_path, label):
    """Load an STL as a single trimesh, scaling mm -> m when needed.

    trimesh.load may return a Scene; in that case take its first geometry.
    A mesh whose largest extent exceeds 1.0 is assumed to be authored in
    millimeters and is scaled by 0.001 so it matches the robot-frame
    pointclouds (which are in meters).
    """
    mesh = trimesh.load(stl_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = list(mesh.geometry.values())[0]
    extents = mesh.bounds[1] - mesh.bounds[0]
    if np.max(extents) > 1.0:
        mesh.apply_scale(0.001)
    print(f"  ✓ Loaded {label} mesh: {len(mesh.vertices)} vertices")
    return mesh


# Load gripper mesh once
print("\n" + "=" * 60)
print("Loading gripper mesh")
print("=" * 60)
# Alternate asset kept for reference:
#fixed_gripper_stl_path = "robot_models/so100_model/assets/Fixed_Jaw.stl"
fixed_gripper_stl_path = "robot_models/so100_blender_testings/kp_testing_fixedgrip.stl"
fixed_gripper_mesh = _load_mesh_meters(fixed_gripper_stl_path, "gripper")

# Load ball mesh for keypoint visualization
print("\n" + "=" * 60)
print("Loading ball mesh for keypoints")
print("=" * 60)
ball_stl_path = "robot_models/so100_blender_testings/ball.stl"
ball_mesh = _load_mesh_meters(ball_stl_path, "ball")

# Launch viser visualization
print("\n" + "=" * 60)
print("Launching viser dataset viewer")
print("=" * 60)
# Starts the web-based 3D viewer (served on port 8080 per the banner below
# — presumably the viser default; confirm if a custom port is ever needed).
server = viser.ViserServer()

# Store current episode index. A one-element list (rather than a bare int)
# lets update_episode and the slider callback mutate the value without
# needing a `global` declaration.
current_episode_idx = [0]  # Use list to allow modification in closure

def update_episode(episode_idx):
    """Render the selected episode into the viser scene.

    Replaces (by scene-node name) the RGB and DINO point clouds, the posed
    gripper mesh, and the keypoint spheres. Reads the module-level
    `server`, `episodes_data`, `fixed_gripper_mesh`, and `ball_mesh`, and
    records the selection in `current_episode_idx`.

    Args:
        episode_idx: Index into `episodes_data`.
    """
    current_episode_idx[0] = episode_idx
    episode = episodes_data[episode_idx]

    # Point cloud with the original RGB colors (robot frame).
    server.scene.add_point_cloud(
        name="/moge_pointmap_robot_rgb",
        points=episode['points'].astype(np.float32),
        colors=episode['colors'].astype(np.uint8),
        point_size=0.002,
    )

    # Same geometry, colored by the first 3 DINO PCA components.
    server.scene.add_point_cloud(
        name="/moge_pointmap_robot_dino",
        points=episode['points'].astype(np.float32),
        colors=episode['dino_colors'].astype(np.uint8),
        point_size=0.002,
    )

    # Gripper mesh posed at the grasp transform.
    if fixed_gripper_mesh is not None:
        gripper_pose = episode['gripper_pose']
        pos = gripper_pose[:3, 3]
        rot = gripper_pose[:3, :3]
        quat = R.from_matrix(rot).as_quat()  # scipy order: (x, y, z, w)

        try:
            server.scene.add_mesh_trimesh(
                name="/gripper",
                mesh=fixed_gripper_mesh,
                wxyz=quat[[3, 0, 1, 2]],  # viser expects (w, x, y, z)
                position=pos,
            )
        except Exception:
            # Fallback (e.g. add_mesh_trimesh unavailable): bake the pose
            # into the vertices and add a flat-colored mesh instead.
            # Narrowed from a bare `except:` so Ctrl+C isn't swallowed.
            vertices_homogeneous = np.hstack(
                [fixed_gripper_mesh.vertices,
                 np.ones((fixed_gripper_mesh.vertices.shape[0], 1))]
            )
            transformed_vertices = (gripper_pose @ vertices_homogeneous.T).T[:, :3]
            server.scene.add_mesh(
                name="/gripper",
                vertices=transformed_vertices.astype(np.float32),
                faces=fixed_gripper_mesh.faces.astype(np.int32),
                color=(100, 150, 200, 255),
            )

    # One small sphere per gripper keypoint (already in robot frame).
    if ball_mesh is not None and episode['keypoints_robot'] is not None:
        for i, kp_pos in enumerate(episode['keypoints_robot']):
            server.scene.add_mesh_trimesh(
                name=f"/keypoint_{i}",
                mesh=ball_mesh,
                wxyz=(1.0, 0.0, 0.0, 0.0),  # identity rotation
                position=kp_pos.astype(np.float32),
            )

# GUI slider for choosing which episode is displayed.
slider = server.gui.add_slider(
    "episode",
    0,
    len(episodes_data) - 1,
    initial_value=0,
    step=1,
)

# Show the first episode immediately.
update_episode(0)

@slider.on_update
def _(_):
    """Re-render the scene whenever the episode slider moves."""
    selected = int(slider.value)
    # Skip the redraw when the slider lands on the episode already shown.
    if selected != current_episode_idx[0]:
        update_episode(selected)


print(f"\nAdded slider to switch between {len(episodes_data)} episodes")
print("Use slider to navigate through episodes")
print("Episode names:")
for idx, episode in enumerate(episodes_data):
    print(f"  {idx}: {episode['sequence_id']}")

print("\n" + "=" * 60)
print(f"Viser server running at http://localhost:8080")
print(f"Dataset viewer with {len(episodes_data)} episodes")
print(f"Press Ctrl+C to exit")
print("=" * 60)

# Keep the process alive so the viser server keeps serving; Ctrl+C exits.
# The import is hoisted out of the loop body (it previously re-ran the
# import machinery on every iteration).
import time

try:
    while True:
        time.sleep(0.1)
except KeyboardInterrupt:
    pass

print("\nDone!")

