"""Visualize multiple keyboard grasp episodes with a slider to switch between them."""
import argparse
import os
import sys
import time
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import viser
from viser.extras import ViserUrdf
import yourdfpy
import torch.nn.functional as F

# Command-line interface: how many episodes to show and where they live.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_episodes",
    type=int,
    default=30,
    help="Maximum number of episodes to load",
)
parser.add_argument(
    "--processed_dir",
    type=str,
    default="scratch/processed_grasp_dataset_keyboard",
    help="Directory with processed episodes",
)
args = parser.parse_args()

# Configuration unpacked from the parsed arguments for use below.
processed_dir, max_episodes = args.processed_dir, args.max_episodes

print("=" * 60)
print(f"Finding episodes in {processed_dir}")
print("=" * 60)

# Every sub-directory of `processed_dir` is treated as one episode; they are
# taken in sorted name order and truncated to `max_episodes`.
if not os.path.exists(processed_dir):
    raise FileNotFoundError(f"Processed directory not found: {processed_dir}")

all_sequences = sorted(
    entry
    for entry in os.listdir(processed_dir)
    if os.path.isdir(os.path.join(processed_dir, entry))
)
sequences = all_sequences[:max_episodes]

if not sequences:
    raise ValueError(f"No sequences found in {processed_dir}")

print(f"Found {len(all_sequences)} total sequences")
print(f"Loading first {len(sequences)} sequences:")
print("\n".join(f"  {idx}: {name}" for idx, name in enumerate(sequences)))

print("\n" + "=" * 60)
print("Loading episode data")
print("=" * 60)


def _load_pointmap(path):
    """Return the ``(points, colors)`` numpy arrays stored in a pointmap file.

    The file is read with ``torch.load`` and must hold a mapping with
    ``"points"`` and ``"colors"`` tensors. Storages are mapped to the CPU so a
    file saved from a GPU process can still be opened here (``.numpy()`` only
    works on CPU tensors).
    """
    pointmap = torch.load(path, map_location="cpu")
    return pointmap["points"].numpy(), pointmap["colors"].numpy()


def _load_optional_pointmap(path):
    """Like `_load_pointmap`, but return ``(None, None)`` if `path` does not exist."""
    if not os.path.exists(path):
        return None, None
    return _load_pointmap(path)


def _load_dino_colors(path):
    """Return uint8 RGB colors derived from saved DINO features, or None if absent.

    The first three feature channels (noted upstream as PCA components) are
    scaled by 2, squashed into (0, 1) with a sigmoid, and mapped to 0-255.
    """
    if not os.path.exists(path):
        return None
    dino_features = torch.load(path, map_location="cpu")  # upstream note: (N, 32)
    dino_rgb = torch.sigmoid(dino_features[:, :3] * 2.0).numpy()
    return (dino_rgb * 255).astype(np.uint8)


# Load data for all episodes. Loading is best-effort: an episode that is
# missing required files or fails to load is reported and skipped so the
# remaining episodes can still be viewed.
episodes_data = []

for seq_id in sequences:
    sequence_dir = os.path.join(processed_dir, seq_id)

    try:
        # Joint states (required).
        joint_states_path = os.path.join(sequence_dir, "joint_states_grasp.npy")
        if not os.path.exists(joint_states_path):
            print(f"  ⚠ Skipping {seq_id}: missing joint_states_grasp.npy")
            continue
        joint_states = np.load(joint_states_path)

        # Full pointcloud (required).
        pointmap_full_path = os.path.join(sequence_dir, "pointmap_start.pt")
        if not os.path.exists(pointmap_full_path):
            print(f"  ⚠ Skipping {seq_id}: missing pointmap_start.pt")
            continue
        points_full, colors_full = _load_pointmap(pointmap_full_path)

        # Cropped and cropped+FPS pointclouds (optional).
        points_cropped, colors_cropped = _load_optional_pointmap(
            os.path.join(sequence_dir, "pointmap_start_cropped.pt")
        )
        points_cropped_fps, colors_cropped_fps = _load_optional_pointmap(
            os.path.join(sequence_dir, "pointmap_start_cropped_fps.pt")
        )

        # DINO feature colors (optional). They are drawn on the cropped+FPS
        # points, one color per point, so a row-count mismatch is reported and
        # the colors dropped here rather than failing later when drawn.
        dino_colors_fps = _load_dino_colors(
            os.path.join(sequence_dir, "dino_features_fps.pt")
        )
        if (
            dino_colors_fps is not None
            and points_cropped_fps is not None
            and len(dino_colors_fps) != len(points_cropped_fps)
        ):
            print(
                f"  ⚠ {seq_id}: ignoring DINO features ({len(dino_colors_fps)} rows) "
                f"that do not match the {len(points_cropped_fps)} cropped+FPS points"
            )
            dino_colors_fps = None

        episodes_data.append({
            'sequence_id': seq_id,
            'joint_states': joint_states,
            'points_full': points_full,
            'colors_full': colors_full,
            'points_cropped': points_cropped,
            'colors_cropped': colors_cropped,
            'points_cropped_fps': points_cropped_fps,
            'colors_cropped_fps': colors_cropped_fps,
            'dino_colors_fps': dino_colors_fps,
        })

        print(f"  ✓ Loaded {seq_id}: {len(points_full)} full, "
              f"{len(points_cropped) if points_cropped is not None else 0} cropped, "
              f"{len(points_cropped_fps) if points_cropped_fps is not None else 0} cropped+FPS points")

    except Exception as e:
        print(f"  ✗ Failed to load {seq_id}: {e}")
        continue

if not episodes_data:
    raise ValueError("No episodes could be loaded!")

print(f"\nSuccessfully loaded {len(episodes_data)} episodes")

# Launch viser visualization
print("\n" + "=" * 60, "Launching viser dataset viewer", "=" * 60, sep="\n")
server = viser.ViserServer()

# The robot model is loaded a single time; switching episodes afterwards only
# changes its joint configuration.
urdf_path = "/Users/cameronsmith/Projects/robotics_testing/calibration_testing/so_100_arm/urdf/so_100_arm.urdf"
urdf = yourdfpy.URDF.load(urdf_path)
viser_urdf = ViserUrdf(
    server,
    urdf_or_path=urdf,
    load_meshes=True,
    load_collision_meshes=False,
    collision_mesh_color_override=(1.0, 0.0, 0.0, 0.5),
)

# Per-joint offset subtracted from recorded joint states before they are sent
# to the URDF. Entries are 0 or ±1.57 (≈ ±π/2). Presumably this converts
# between MuJoCo and URDF joint-zero conventions, as the name suggests; confirm.
mujoco_so100_offset = np.array([0, -1.57, 1.57, 1.57, -1.57, 0])

# Index of the episode currently on screen. A one-element list lets the slider
# callback update it without a `global` statement.
current_episode_idx = [0]

# Point-cloud scene handles this script has added, keyed by scene-node name.
# They are kept so that clouds the newly selected episode does not provide can
# be removed, instead of lingering from a previously viewed episode.
_point_cloud_handles = {}


def _set_point_cloud(name, points, colors, point_size):
    """Show `points`/`colors` under scene node `name`, or remove that node.

    If either array is None, any cloud previously added under `name` is
    removed. Otherwise the cloud is (re-)added under `name`, which replaces the
    previous one as the per-episode updates already relied on.

    Args:
        name: Scene-node name, e.g. "/moge_aligned_full".
        points: Point positions, or None if the episode has no such cloud.
        colors: Per-point colors, or None.
        point_size: Rendered point size.
    """
    if points is None or colors is None:
        stale = _point_cloud_handles.pop(name, None)
        if stale is not None:
            stale.remove()
        return
    _point_cloud_handles[name] = server.scene.add_point_cloud(
        name=name,
        points=points.astype(np.float32),
        colors=colors.astype(np.uint8),
        point_size=point_size,
    )


def update_episode(episode_idx):
    """Update visualization to show the selected episode.

    Out-of-range indices are ignored. The robot is posed from the episode's
    joint states and the episode's point clouds are drawn. Optional clouds that
    this episode lacks are removed from the scene, so nothing from a previously
    shown episode remains.

    Args:
        episode_idx: Index into ``episodes_data``.
    """
    if not (0 <= episode_idx < len(episodes_data)):
        return

    current_episode_idx[0] = episode_idx
    episode = episodes_data[episode_idx]

    # Recorded joint states are shifted by a fixed per-joint offset before
    # being applied to the URDF.
    joint_states = episode['joint_states']
    viser_urdf.update_cfg(np.array(joint_states - mujoco_so100_offset))

    _set_point_cloud(
        "/moge_aligned_full", episode['points_full'], episode['colors_full'], 0.001
    )
    _set_point_cloud(
        "/moge_aligned_cropped", episode['points_cropped'], episode['colors_cropped'], 0.002
    )
    # FPS-subsampled cloud, drawn larger for visibility.
    _set_point_cloud(
        "/moge_aligned_cropped_fps",
        episode['points_cropped_fps'],
        episode['colors_cropped_fps'],
        0.003,
    )
    # The same FPS points colored by their DINO features, when both exist.
    _set_point_cloud(
        "/dino_features_fps",
        episode['points_cropped_fps'],
        episode['dino_colors_fps'],
        0.003,
    )

# Draw the first episode before any slider interaction happens.
update_episode(0)

# Integer slider spanning every loaded episode.
slider = server.gui.add_slider(
    "episode",
    0,
    len(episodes_data) - 1,
    step=1,
    initial_value=0,
)


@slider.on_update
def _on_episode_slider_change(event):
    """Switch the scene to the episode chosen on the slider."""
    requested_idx = int(slider.value)
    if requested_idx == current_episode_idx[0]:
        return  # Already showing this episode; nothing to redraw.
    update_episode(requested_idx)

print(f"\nAdded slider to switch between {len(episodes_data)} episodes")
print("Use slider to navigate through episodes")
print("Episode names:")
for i, ep in enumerate(episodes_data):
    print(f"  {i}: {ep['sequence_id']}")

print("\n" + "=" * 60)
# Ask the server which port it is listening on rather than assuming the
# default, so the printed URL is correct if a different port was used.
print(f"Viser server running at http://localhost:{server.get_port()}")
print(f"Dataset viewer with {len(episodes_data)} episodes")
print("Press Ctrl+C to exit")
print("=" * 60)

# Idle the main thread until Ctrl+C so the script (and the server) stays up.
try:
    while True:
        time.sleep(0.1)
except KeyboardInterrupt:
    pass

print("\nDone!")

