"""Visualize MoGe pointmaps with DINO features in viser."""
import argparse
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import viser
from torchvision.transforms import functional as TF
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", "-i", type=str, required=True,
                    help="Input directory with parsed episodes")
parser.add_argument("--max_episodes", type=int, default=30, help="Maximum number of episodes to load")
args = parser.parse_args()
# Reject non-positive limits up front: a negative value would silently slice from
# the END of the episode list, and 0 would surface as a misleading "No episodes found".
if args.max_episodes < 1:
    parser.error("--max_episodes must be at least 1")

# Configuration
input_dir = Path(args.input_dir)
max_episodes = args.max_episodes

print("=" * 60)
print(f"Finding episodes in {input_dir}")
print("=" * 60)


def _episode_sort_key(path):
    """Sort key ordering ``episode_<n>`` directories numerically.

    Plain lexicographic sorting puts ``episode_10`` before ``episode_2`` when names
    are not zero-padded, so ``[:max_episodes]`` would pick the wrong episodes.
    Names whose suffix is not a number sort after the numeric ones, by name.
    """
    suffix = path.name[len("episode_"):]
    if suffix.isdecimal():
        return (0, int(suffix), path.name)
    return (1, 0, path.name)


# Find all available episodes
if not input_dir.exists():
    raise FileNotFoundError(f"Input directory not found: {input_dir}")
if not input_dir.is_dir():
    raise NotADirectoryError(f"Input path is not a directory: {input_dir}")

all_episodes = sorted(
    (d for d in input_dir.iterdir() if d.is_dir() and d.name.startswith("episode_")),
    key=_episode_sort_key,
)
episodes = all_episodes[:max_episodes]

if not episodes:
    raise ValueError(f"No episodes found in {input_dir}")

print(f"Found {len(all_episodes)} total episodes")
print(f"Loading first {len(episodes)} episodes:")
for i, ep_dir in enumerate(episodes):
    print(f"  {i}: {ep_dir.name}")

print("\n" + "=" * 60)
print("Loading episode data")
print("=" * 60)


def _timestep_sort_key(img_path):
    """Order frame files numerically by stem, so ``2.png`` precedes ``10.png``.

    Sorting Path objects directly is lexicographic and misorders timesteps whose
    names are not zero-padded. Equal numbers with different padding (``1`` vs
    ``01``) fall back to comparing the stem string.
    """
    return (int(img_path.stem), img_path.stem)


def _dino_features_to_colors(dino_features_hw, H, W):
    """Turn a per-pixel DINO feature map into uint8 RGB colors at (H, W).

    Args:
        dino_features_hw: torch tensor or numpy array, expected channel-last
            ``(H_dino, W_dino, C)``. The first 3 channels are visualized
            (presumably PCA components, per how the features were exported — confirm).
        H, W: target resolution (the MoGe pointmap's).

    Returns:
        ``(H, W, 3)`` uint8 array, or ``None`` if the map is not 3-D or has fewer
        than 3 channels (such colors could not be drawn as RGB).
    """
    if isinstance(dino_features_hw, torch.Tensor):
        dino_features_hw = dino_features_hw.detach().cpu().numpy()

    if dino_features_hw.ndim != 3 or dino_features_hw.shape[2] < 3:
        return None

    H_dino, W_dino = dino_features_hw.shape[:2]
    # Bilinear resizing is per-channel, so resizing just the 3 displayed channels
    # gives the same result as resizing the whole feature map, far more cheaply.
    feats = torch.from_numpy(np.ascontiguousarray(dino_features_hw[:, :, :3])).float()  # (H_dino, W_dino, 3)
    if (H_dino, W_dino) != (H, W):
        feats = TF.resize(
            feats.permute(2, 0, 1),  # (3, H_dino, W_dino)
            (H, W),
            interpolation=TF.InterpolationMode.BILINEAR,
        ).permute(1, 2, 0)  # (H, W, 3)

    # sigmoid(2x) squashes unbounded feature values into (0, 1) for display.
    dino_rgb = torch.sigmoid(feats * 2.0).numpy()
    return (dino_rgb * 255).astype(np.uint8)


def _load_timestep(ep_dir, img_path):
    """Load one timestep's MoGe pointmap, optional MoGe mask and optional DINO colors.

    Returns a dict with keys ``timestep``, ``points_robot``, ``colors_rgb``,
    ``dino_colors`` (or None), ``mask`` and ``moge_mask`` (or None), or returns
    None when ``moge_pointmap_<t>.pt`` does not exist for this frame.
    """
    timestep_str = img_path.stem

    pointmap_path = ep_dir / f"moge_pointmap_{timestep_str}.pt"
    if not pointmap_path.exists():
        return None  # Skip timesteps without MoGe pointmap

    # map_location="cpu" lets files saved from a GPU machine load on a CPU-only one.
    # NOTE(review): weights_only=False unpickles arbitrary objects — only point this
    # script at trusted, locally produced data.
    pointmap = torch.load(pointmap_path, map_location="cpu", weights_only=False)
    points_robot = pointmap["points"].numpy()  # (H, W, 3)
    colors_rgb = pointmap["colors"].numpy()  # (H, W, 3)
    mask = pointmap["mask"].numpy().astype(bool)  # (H, W)

    moge_mask_path = ep_dir / f"moge_mask_{timestep_str}.npy"
    moge_mask = np.load(moge_mask_path).astype(bool) if moge_mask_path.exists() else None

    H, W = points_robot.shape[:2]

    dino_colors = None
    dino_features_hw_path = ep_dir / f"dino_features_hw_{timestep_str}.pt"
    if dino_features_hw_path.exists():
        dino_features_hw = torch.load(dino_features_hw_path, map_location="cpu", weights_only=False)
        dino_colors = _dino_features_to_colors(dino_features_hw, H, W)
        if dino_colors is None:
            print(f"    ⚠ {ep_dir.name}/{timestep_str}: DINO features are not (H, W, C>=3); "
                  f"showing RGB only")

    return {
        'timestep': timestep_str,
        'points_robot': points_robot,
        'colors_rgb': colors_rgb,
        'dino_colors': dino_colors,
        'mask': mask,
        'moge_mask': moge_mask,
    }


# Load data for all episodes
episodes_data = []

for ep_dir in episodes:
    try:
        # Frames are named <digits>.png; visit them in numeric timestep order.
        image_files = sorted(
            (f for f in ep_dir.glob("*.png") if f.stem.isdecimal()),
            key=_timestep_sort_key,
        )

        if not image_files:
            print(f"  ⚠ Skipping {ep_dir.name}: no images found")
            continue

        loaded = (_load_timestep(ep_dir, img_path) for img_path in image_files)
        timestep_data = [ts for ts in loaded if ts is not None]

        if not timestep_data:
            print(f"  ⚠ Skipping {ep_dir.name}: no valid timesteps with MoGe pointmaps")
            continue

        episodes_data.append({
            'episode_id': ep_dir.name,
            'timesteps': timestep_data,
        })

        print(f"  ✓ Loaded {ep_dir.name}: {len(timestep_data)} timesteps")

    except Exception as e:
        # Best-effort: an unreadable episode is reported with its traceback and
        # skipped, rather than aborting the whole viewer.
        print(f"  ✗ Failed to load {ep_dir.name}: {e}")
        import traceback
        traceback.print_exc()
        continue

if len(episodes_data) == 0:
    raise ValueError("No episodes could be loaded!")

print(f"\nSuccessfully loaded {len(episodes_data)} episodes")

# Start the viser server that hosts the point-cloud viewer.
print(f"\n{'=' * 60}")
print("Launching viser pointmap viewer")
print(f"{'=' * 60}")
server = viser.ViserServer()

# Indices of the episode/timestep currently on screen. Single-element lists are
# used so the slider callbacks can mutate them without `global` declarations.
current_episode_idx, current_timestep_idx = [0], [0]

# Handle of the DINO point cloud currently in the scene, or None. It is kept so the
# cloud can be removed when the selected timestep has no DINO colors.
_dino_cloud_handle = [None]


def update_visualization(episode_idx, timestep_idx):
    """Show the selected episode/timestep's MoGe pointmap in the viser scene.

    Adds (or replaces) ``/moge_pointmap_rgb``. If the timestep has DINO colors, it
    also adds or replaces ``/moge_pointmap_dino``. Otherwise, any DINO cloud left
    over from a previous frame is removed, so it is not shown alongside the wrong
    frame. Indices out of range are ignored: the scene and the tracked
    ``current_*_idx`` values are left untouched.
    """
    if not (0 <= episode_idx < len(episodes_data)):
        return

    episode = episodes_data[episode_idx]

    if not (0 <= timestep_idx < len(episode['timesteps'])):
        return

    current_episode_idx[0] = episode_idx
    current_timestep_idx[0] = timestep_idx

    frame = episode['timesteps'][timestep_idx]
    # NOTE(review): frame['moge_mask'] is loaded but not applied here; only 'mask' filters.
    valid = frame['mask'].reshape(-1)  # boolean, (H*W,)

    # Both clouds share these points, so DINO colors line up point-for-point with RGB.
    valid_points = frame['points_robot'].reshape(-1, 3)[valid].astype(np.float32)  # (N, 3)
    # NOTE(review): a plain uint8 cast assumes colors are already 0-255; float colors
    # in [0, 1] would render black — confirm against the pointmap writer.
    valid_colors = frame['colors_rgb'].reshape(-1, 3)[valid].astype(np.uint8)  # (N, 3)

    server.scene.add_point_cloud(
        name="/moge_pointmap_rgb",
        points=valid_points,
        colors=valid_colors,
        point_size=0.002,
    )

    dino_colors = frame['dino_colors']  # (H, W, 3) uint8 or None
    if dino_colors is None:
        # Drop the previous frame's DINO cloud instead of leaving it stale on screen.
        if _dino_cloud_handle[0] is not None:
            _dino_cloud_handle[0].remove()
            _dino_cloud_handle[0] = None
        return

    _dino_cloud_handle[0] = server.scene.add_point_cloud(
        name="/moge_pointmap_dino",
        points=valid_points,
        colors=dino_colors.reshape(-1, 3)[valid].astype(np.uint8),
        point_size=0.002,
    )

# Initialize with first episode and timestep
update_visualization(0, 0)

# Add slider to switch between episodes
episode_slider = server.gui.add_slider(
    "episode",
    0,
    len(episodes_data) - 1,
    initial_value=0,
    step=1
)

# Timestep slider. It is created with the longest episode's range, as before.
# NOTE(review): presumably this keeps viser from clamping `step` when min == max —
# confirm. Its max is then narrowed to episode 0, which is what is drawn initially.
# Otherwise positions past episode 0's last frame would do nothing until the
# episode slider is first moved.
max_timesteps = max(len(ep['timesteps']) for ep in episodes_data)
timestep_slider = server.gui.add_slider(
    "timestep",
    0,
    max(max_timesteps - 1, 0),
    initial_value=0,
    step=1
)
timestep_slider.max = max(len(episodes_data[0]['timesteps']) - 1, 0)

def update_timestep_slider_max(episode_idx):
    """Fit the timestep slider's range to the episode at ``episode_idx``.

    Out-of-range episode indices are ignored. If the tracked timestep lies past the
    new episode's last frame, both the tracked index and the slider reset to 0.
    """
    if not 0 <= episode_idx < len(episodes_data):
        return

    last_t = len(episodes_data[episode_idx]['timesteps']) - 1
    timestep_slider.max = max(last_t, 0)

    if current_timestep_idx[0] <= last_t:
        return
    # Tracked timestep does not exist in this episode; start from its first frame.
    current_timestep_idx[0] = 0
    timestep_slider.value = 0

@episode_slider.on_update
def _(_):
    """Switch episodes: resize the timestep slider, then redraw."""
    new_ep_idx = int(episode_slider.value)
    if new_ep_idx == current_episode_idx[0]:
        return
    # Must run before timestep_slider.value is read below: it may reset the
    # slider to 0 when the current timestep is past the new episode's end.
    update_timestep_slider_max(new_ep_idx)
    update_visualization(new_ep_idx, int(timestep_slider.value))

@timestep_slider.on_update
def _(_):
    """Redraw the current episode at the newly selected timestep."""
    new_t_idx = int(timestep_slider.value)
    if new_t_idx != current_timestep_idx[0]:
        # update_visualization records the index itself, and only when it is valid
        # for the current episode. Pre-assigning it here could leave the tracked
        # index pointing at a frame that was never drawn.
        update_visualization(current_episode_idx[0], new_t_idx)

print(f"\nAdded sliders to switch between {len(episodes_data)} episodes and timesteps")
print("Episode names:")
for i, ep in enumerate(episodes_data):
    print(f"  {i}: {ep['episode_id']} ({len(ep['timesteps'])} timesteps)")

print("\n" + "=" * 60)
# Ask the server for the port it actually bound to instead of assuming the default.
print(f"Viser server running at http://localhost:{server.get_port()}")
print(f"Pointmap viewer with {len(episodes_data)} episodes")
print("Press Ctrl+C to exit")
print("=" * 60)

# Imported once here rather than on every loop iteration.
import time

# viser serves from background threads; keep the main thread alive until Ctrl+C.
try:
    while True:
        time.sleep(0.1)
except KeyboardInterrupt:
    pass

print("\nDone!")
