"""Visualize dense dataset with sliders for episode and timestep selection."""
import argparse
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import viser
from viser.extras import ViserUrdf
import yourdfpy
import torch.nn.functional as F
from pathlib import Path
import trimesh
from scipy.spatial.transform import Rotation as R

# ---- CLI arguments & configuration -------------------------------------
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--max_episodes", type=int, default=30, help="Maximum number of episodes to load")
arg_parser.add_argument(
    "--processed_dir",
    type=str,
    default="scratch/parsed_episodes_cup",
    help="Directory with processed dense episodes",
)
cli_args = arg_parser.parse_args()

processed_dir = Path(cli_args.processed_dir)  # root of per-episode subdirectories
max_episodes = cli_args.max_episodes          # cap on how many episodes we load
DOWNSAMPLE_FACTOR = 9  # Downsample pointclouds by this factor

banner = "=" * 60
print(banner)
print(f"Finding episodes in {processed_dir}")
print(banner)

# Discover all episode subdirectories (named "episode_*") and keep the first max_episodes.
if not processed_dir.exists():
    raise FileNotFoundError(f"Processed directory not found: {processed_dir}")

all_episodes = sorted(
    entry for entry in processed_dir.iterdir()
    if entry.is_dir() and entry.name.startswith("episode_")
)
episodes = all_episodes[:max_episodes]

if not episodes:
    raise ValueError(f"No episodes found in {processed_dir}")

print(f"Found {len(all_episodes)} total episodes")
print(f"Loading first {len(episodes)} episodes:")
for idx, ep_dir in enumerate(episodes):
    print(f"  {idx}: {ep_dir.name}")

print("\n" + "=" * 60)
print("Loading episode data")
print("=" * 60)

# Load data for all episodes and timesteps
episodes_data = []


def _load_timestep(ep_dir, timestep_str):
    """Load joint states, pointcloud, and optional DINO colors for one timestep.

    Returns a dict with keys 'timestep_str', 'joint_states', 'points',
    'colors', 'dino_colors', or None when a required file is missing.
    Unexpected load errors propagate to the caller, which logs and skips.
    """
    # Joint states recorded alongside the image (same stem, .npy extension).
    joint_states = np.load(ep_dir / f"{timestep_str}.npy")

    # Pointcloud in the MoGe/camera frame.
    pointmap_path = ep_dir / f"pointmap_{timestep_str}_raw.pt"
    if not pointmap_path.exists():
        print(f"    ⚠ Skipping timestep {timestep_str}: missing pointmap")
        return None
    pointmap = torch.load(pointmap_path)
    points_cam_flat = pointmap["points"].numpy()  # (N, 3) in camera frame
    colors_flat = pointmap["colors"].numpy()  # (N, 3)
    mask_flat = pointmap["mask"].numpy()  # (N,)

    # Procrustes transformation: MoGe/camera frame -> robot frame.
    moge_to_robot_frame_path = ep_dir / f"moge_to_robot_frame_{timestep_str}.npy"
    if not moge_to_robot_frame_path.exists():
        print(f"    ⚠ Skipping timestep {timestep_str}: missing Procrustes transformation")
        return None
    T_procrustes = np.load(moge_to_robot_frame_path)  # (4, 4)

    # Filter to valid points.
    valid_mask = mask_flat.astype(bool)
    points_cam_valid = points_cam_flat[valid_mask]
    colors_valid = colors_flat[valid_mask]

    # Transform camera-frame points to robot frame via homogeneous coordinates.
    if len(points_cam_valid) > 0:
        points_cam_h = np.hstack([points_cam_valid, np.ones((len(points_cam_valid), 1))])  # (N, 4)
        points_robot = (T_procrustes @ points_cam_h.T).T[:, :3]  # (N, 3)
    else:
        points_robot = np.zeros((0, 3))

    # Uniform downsample by DOWNSAMPLE_FACTOR. `indices` is always defined
    # (it is simply empty for an empty cloud), which also makes the DINO
    # branch below safe; indexing an empty array with empty indices is a no-op.
    indices = np.arange(0, len(points_robot), DOWNSAMPLE_FACTOR)
    points_downsampled = points_robot[indices]
    colors_downsampled = colors_valid[indices]

    # Optional per-point DINO features; first 3 PCA components mapped to RGB.
    # Assumes feature rows are aligned with the valid points — TODO confirm
    # against the exporter.
    dino_colors = None
    dino_features_path = ep_dir / f"dino_features_{timestep_str}.pt"
    if dino_features_path.exists():
        dino_features = torch.load(dino_features_path).numpy()  # (N, 32)
        # Bug fix: use torch.sigmoid (F.sigmoid is deprecated, same math).
        dino_rgb = torch.sigmoid(torch.from_numpy(dino_features[:, :3]).mul(2.0)).numpy()
        dino_colors = (dino_rgb * 255).astype(np.uint8)
        # Bug fix: previously `indices` was undefined (or stale from an earlier
        # timestep) when the pointcloud was empty; it is now always defined.
        if len(dino_colors) > 0:
            dino_colors = dino_colors[indices]

    return {
        'timestep_str': timestep_str,
        'joint_states': joint_states,
        'points': points_downsampled,
        'colors': colors_downsampled,
        'dino_colors': dino_colors,
    }


for ep_dir in episodes:
    try:
        # A timestep exists when a PNG image has a matching joint-state .npy.
        timesteps = [
            img_path.stem
            for img_path in sorted(ep_dir.glob("*.png"))
            if (ep_dir / f"{img_path.stem}.npy").exists()
        ]

        if len(timesteps) == 0:
            print(f"  ⚠ Skipping {ep_dir.name}: no timesteps found")
            continue

        # Load data for each timestep; skip individual failures.
        timestep_data = []
        for timestep_str in timesteps:
            try:
                data = _load_timestep(ep_dir, timestep_str)
                if data is not None:
                    timestep_data.append(data)
            except Exception as e:
                print(f"    ✗ Failed to load timestep {timestep_str}: {e}")
                continue

        if len(timestep_data) == 0:
            print(f"  ⚠ Skipping {ep_dir.name}: no valid timesteps")
            continue

        episodes_data.append({
            'episode_id': ep_dir.name,
            'timesteps': timestep_data,
        })

        print(f"  ✓ Loaded {ep_dir.name}: {len(timestep_data)} timesteps")

    except Exception as e:
        print(f"  ✗ Failed to load {ep_dir.name}: {e}")
        continue

if len(episodes_data) == 0:
    raise ValueError("No episodes could be loaded!")

print(f"\nSuccessfully loaded {len(episodes_data)} episodes")
total_timesteps = sum(len(ep['timesteps']) for ep in episodes_data)
print(f"Total timesteps: {total_timesteps}")

# Launch viser visualization
print("\n" + "=" * 60)
print("Launching viser dataset viewer")
print("=" * 60)
# Viser serves its web UI on port 8080 by default.
server = viser.ViserServer()

# Load robot URDF once.
# NOTE(review): hard-coded absolute path — only works on this machine; consider a CLI flag.
urdf_path = "/Users/cameronsmith/Projects/robotics_testing/calibration_testing/so_100_arm/urdf/so_100_arm.urdf"
urdf = yourdfpy.URDF.load(urdf_path)
viser_urdf = ViserUrdf(
    server,
    urdf_or_path=urdf,
    load_meshes=True,
    load_collision_meshes=False,  # collision geometry not shown in the viewer
    collision_mesh_color_override=(1.0, 0.0, 0.0, 0.5),
)

# Joint-angle offset (radians) subtracted from recorded joint states before
# posing the URDF — presumably converts from the MuJoCo SO-100 zero convention
# to this URDF's zero convention; verify against the data recorder.
mujoco_so100_offset = np.array([0, -1.57, 1.57, 1.57, -1.57, 0])

# Load gripper STL mesh once
print("\n" + "=" * 60)
print("Loading gripper STL mesh")
print("=" * 60)
# NOTE(review): relative path — depends on the current working directory at launch.
fixed_gripper_stl_path = "robot_models/so100_model/assets/Fixed_Jaw.stl"
fixed_gripper_mesh = trimesh.load(fixed_gripper_stl_path)
# trimesh may return a Scene for multi-geometry files; take the first geometry.
if isinstance(fixed_gripper_mesh, trimesh.Scene):
    fixed_gripper_mesh = list(fixed_gripper_mesh.geometry.values())[0]

# Heuristic unit check: if the mesh is larger than 1 unit it is assumed to be
# authored in millimeters and is rescaled to meters.
bounds = fixed_gripper_mesh.bounds
max_extent = np.max(bounds[1] - bounds[0])
if max_extent > 1.0:
    fixed_gripper_mesh.apply_scale(0.001)
print(f"  ✓ Loaded gripper mesh: {len(fixed_gripper_mesh.vertices)} vertices")

# Store current episode and timestep indices.
current_episode_idx = [0]  # Use list to allow modification in closure
current_timestep_idx = [0]

def update_visualization(episode_idx, timestep_idx):
    """Render robot pose, pointclouds, and gripper mesh for one (episode, timestep).

    Out-of-range indices are silently ignored. On success, updates the
    module-level current_episode_idx / current_timestep_idx trackers.
    Re-adding a scene node under the same name replaces the previous one.
    """
    if not (0 <= episode_idx < len(episodes_data)):
        return

    episode = episodes_data[episode_idx]
    timesteps = episode['timesteps']

    if not (0 <= timestep_idx < len(timesteps)):
        return

    current_episode_idx[0] = episode_idx
    current_timestep_idx[0] = timestep_idx

    timestep = timesteps[timestep_idx]

    # Pose the robot URDF; recorded joints are offset from this URDF's zero
    # convention, so subtract the offset first.
    joint_states = timestep['joint_states']
    viser_urdf.update_cfg(np.array(joint_states - mujoco_so100_offset))

    # MoGe pointcloud (robot frame).
    server.scene.add_point_cloud(
        name="/moge_pointcloud",
        points=timestep['points'].astype(np.float32),
        colors=timestep['colors'].astype(np.uint8),
        point_size=0.002,
    )

    # Optional DINO-feature pointcloud (same points, feature-derived colors).
    if timestep['dino_colors'] is not None:
        server.scene.add_point_cloud(
            name="/dino_features",
            points=timestep['points'].astype(np.float32),
            colors=timestep['dino_colors'].astype(np.uint8),
            point_size=0.002,
        )

    # Optional gripper pose (4x4 robot-frame transform) stored per timestep.
    # (Fix: removed a redundant re-lookup of `episode` here.)
    timestep_str = timestep['timestep_str']
    episode_dir = processed_dir / episode['episode_id']
    gripper_pose_path = episode_dir / f"{timestep_str}_gripper_pose.npy"

    if gripper_pose_path.exists():
        gripper_pose = np.load(gripper_pose_path)  # 4x4 transformation matrix

        # Split the pose into position + orientation for viser.
        pos = gripper_pose[:3, 3]
        rot = gripper_pose[:3, :3]
        quat = R.from_matrix(rot).as_quat()  # scipy order: (x, y, z, w)
        wxyz = quat[[3, 0, 1, 2]]  # viser expects (w, x, y, z)

        try:
            server.scene.add_mesh_trimesh(
                name="/gripper",
                mesh=fixed_gripper_mesh,
                wxyz=wxyz,
                position=pos.astype(np.float32),
            )
        # Bug fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        except Exception:
            # Fallback: bake the transform into the vertices and use add_mesh.
            vertices_homogeneous = np.hstack(
                [fixed_gripper_mesh.vertices,
                 np.ones((fixed_gripper_mesh.vertices.shape[0], 1))]
            )
            transformed_vertices = (gripper_pose @ vertices_homogeneous.T).T[:, :3]
            server.scene.add_mesh(
                name="/gripper",
                vertices=transformed_vertices.astype(np.float32),
                faces=fixed_gripper_mesh.faces.astype(np.int32),
                color=(100, 150, 200, 255),
            )

# Initialize with first episode and timestep
update_visualization(0, 0)

# Add slider for episode selection (positional args: label, min, max).
episode_slider = server.gui.add_slider(
    "episode", 
    0, 
    len(episodes_data) - 1, 
    initial_value=0, 
    step=1
)

# Get initial timestep count for first episode
initial_num_timesteps = len(episodes_data[0]['timesteps']) if len(episodes_data) > 0 else 1

# Add slider for timestep selection; the max is re-adjusted per episode by
# update_timestep_slider_range() below.
timestep_slider = server.gui.add_slider(
    "timestep",
    0,
    max(0, initial_num_timesteps - 1),
    initial_value=0,
    step=1
)

def update_timestep_slider_range():
    """Resize the timestep slider to fit the currently selected episode."""
    ep_idx = current_episode_idx[0]
    if ep_idx >= len(episodes_data):
        return

    num_timesteps = len(episodes_data[ep_idx]['timesteps'])
    new_max = max(0, num_timesteps - 1)

    # Leave the slider untouched when the bound is already correct.
    if timestep_slider.max != new_max:
        timestep_slider.max = new_max

    # Clamp both the tracked index and the slider value into the new range.
    if num_timesteps == 0:
        current_timestep_idx[0] = 0
        timestep_slider.value = 0
        return
    if current_timestep_idx[0] >= num_timesteps:
        current_timestep_idx[0] = num_timesteps - 1
        timestep_slider.value = num_timesteps - 1
        return
    # Guard against a stale/NaN slider value.
    slider_val = 0 if np.isnan(timestep_slider.value) else int(timestep_slider.value)
    if not (0 <= slider_val < num_timesteps):
        timestep_slider.value = 0
        current_timestep_idx[0] = 0

@episode_slider.on_update
def _(_):
    """Callback when episode slider value changes."""
    new_episode_idx = int(episode_slider.value)
    if new_episode_idx != current_episode_idx[0] and 0 <= new_episode_idx < len(episodes_data):
        # Bug fix: record the new episode BEFORE resizing the timestep slider.
        # update_timestep_slider_range() reads current_episode_idx[0], so the
        # old code sized the slider for the episode being LEFT, not the one
        # being selected.
        current_episode_idx[0] = new_episode_idx
        update_timestep_slider_range()
        # Reset to first timestep of new episode
        current_timestep_idx[0] = 0
        timestep_slider.value = 0
        update_visualization(new_episode_idx, 0)

@timestep_slider.on_update
def _(_):
    """Callback when timestep slider value changes."""
    # Reset any NaN/invalid slider value and bail out.
    if np.isnan(timestep_slider.value):
        timestep_slider.value = 0
        return

    if current_episode_idx[0] >= len(episodes_data):
        return

    num_timesteps = len(episodes_data[current_episode_idx[0]]['timesteps'])
    requested = int(timestep_slider.value)

    # Clamp an out-of-range request back into the valid window.
    if not (0 <= requested < num_timesteps):
        requested = max(0, min(requested, num_timesteps - 1))
        timestep_slider.value = requested

    # Only redraw when the timestep actually changed.
    if requested != current_timestep_idx[0]:
        update_visualization(current_episode_idx[0], requested)

# Initialize timestep slider range
update_timestep_slider_range()

print(f"\nAdded sliders:")
print(f"  - Episode: 0 to {len(episodes_data) - 1}")
print(f"  - Timestep: 0 to {timestep_slider.max}")
print(f"\nEpisode names:")
for i, ep in enumerate(episodes_data):
    num_timesteps = len(ep['timesteps'])
    print(f"  Episode {i} ({ep['episode_id']}): {num_timesteps} timesteps")

print("\n" + "=" * 60)
print(f"Viser server running at http://localhost:8080")
print(f"Dataset viewer with {len(episodes_data)} episodes, {total_timesteps} total timesteps")
print(f"Press Ctrl+C to exit")
print("=" * 60)

# Keep the process alive so the viser server keeps serving; exit on Ctrl+C.
# Bug fix: `import time` was executed inside the loop on every iteration.
import time

try:
    while True:
        time.sleep(0.1)
except KeyboardInterrupt:
    pass

print("\nDone!")