"""Visualize MoGe pointmaps in robot frame with a slider to switch between episodes."""
import argparse
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import viser
import cv2
import matplotlib.pyplot as plt
import trimesh
from scipy.spatial.transform import Rotation as R
from scipy.ndimage import gaussian_filter
import torch.nn.functional as F
from geom_utils import procrustes_alignment

# Grasp keypoints expressed in the gripper's local frame. The table below is
# written in millimetres; KEYPOINTS_LOCAL_M holds the same points in metres.
KEYPOINTS_LOCAL_MM = np.asarray(
    [
        [13.25, -91.42, 15.9],
        [10.77, -99.6, 0],
        [13.25, -91.42, -15.9],
        [17.96, -83.96, 0],
        [22.86, -70.46, 0],
    ],
    dtype=np.float64,
)
KEYPOINTS_LOCAL_M = np.divide(KEYPOINTS_LOCAL_MM, 1000.0)

# Command-line interface for the viewer.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_episodes",
    type=int,
    default=30,
    help="Maximum number of episodes to load",
)
parser.add_argument(
    "--processed_dir",
    type=str,
    default="scratch/processed_grasp_dataset_keyboard",
    help="Directory with processed episodes",
)
parser.add_argument(
    "--pred_dir",
    type=str,
    default="scratch/pred/3d_keypoint_predictor",
    help="Directory with predictions",
)
parser.add_argument(
    "--split",
    type=str,
    choices=["train", "val"],
    default="val",
    help="Which split to visualize (train or val)",
)
parser.add_argument(
    "--use_argmax",
    action="store_true",
    help="Use argmax on Gaussian-weighted map instead of weighted sum",
)
args = parser.parse_args()

# Configuration: expose the parsed options as plain module-level names.
processed_dir, pred_dir, max_episodes, split = (
    args.processed_dir,
    args.pred_dir,
    args.max_episodes,
    args.split,
)

print("=" * 60)
print(f"Loading predictions from {pred_dir}")
print("=" * 60)

# Load predictions
pred_offset_fields_path = os.path.join(pred_dir, f"{split}_offset_fields_pred.pt")
gt_offset_fields_path = os.path.join(pred_dir, f"{split}_offset_fields_gt.pt")
sequence_ids_path = os.path.join(pred_dir, f"{split}_sequence_ids.txt")

# map_location="cpu" lets tensors saved from a GPU run load on a CPU-only
# machine; later code calls .numpy() on them, which needs CPU tensors anyway.
# Expected shapes (per the producer's notes): (N, 5, 3, H, W), downsampled.
pred_offset_fields = torch.load(pred_offset_fields_path, map_location="cpu")
gt_offset_fields = torch.load(gt_offset_fields_path, map_location="cpu")

with open(sequence_ids_path, "r", encoding="utf-8") as f:
    sequence_ids = [line.strip() for line in f]
# Drop trailing blank lines (e.g. an extra newline at EOF) so they are not
# treated as sequence IDs. Interior lines are kept so ID i still pairs with
# prediction i.
while sequence_ids and not sequence_ids[-1]:
    sequence_ids.pop()

print(f"Loaded {len(sequence_ids)} predictions")
print(f"Prediction shape: {pred_offset_fields.shape}")
print(f"GT shape: {gt_offset_fields.shape}")
if len(sequence_ids) != len(pred_offset_fields):
    # Episodes are matched to predictions purely by row index, so a count
    # mismatch means the IDs file and the prediction tensors are out of sync.
    print(
        f"WARNING: {len(sequence_ids)} sequence IDs but "
        f"{len(pred_offset_fields)} predictions; pairing by index"
    )

# Limit to max_episodes
sequences = sequence_ids[:max_episodes]
pred_offset_fields = pred_offset_fields[:max_episodes]
gt_offset_fields = gt_offset_fields[:max_episodes]

print(f"\nVisualizing first {len(sequences)} sequences:")
for i, seq_id in enumerate(sequences):
    print(f"  {i}: {seq_id}")

print("\n" + "=" * 60)
print("Loading episode data and extracting predicted keypoints")
print("=" * 60)

# Load data for all episodes
episodes_data = []

# Keypoint-extraction hyperparameters. They do not change between episodes or
# keypoints, so they are defined once here rather than inside the loops.
temperature = 0.1  # Softmax temperature for weighting (lower = sharper, higher = smoother)
inv_dist_sigma = 8.0  # Gaussian blur sigma (full-resolution pixels) for the inverse-distance field
eps = 1e-6  # Guards the divisions below against zero
num_keypoints = KEYPOINTS_LOCAL_M.shape[0]  # 5 for the gripper keypoint table

for idx, seq_id in enumerate(sequences):
    sequence_dir = os.path.join(processed_dir, seq_id)

    # Load start image to get H, W. cv2.imread returns None (it does not raise)
    # for a missing/unreadable file, so check explicitly. With the default
    # IMREAD_COLOR flag the result is always 8-bit, so no [0, 1] rescale is needed.
    start_img_path = os.path.join(sequence_dir, "start.png")
    start_image = cv2.imread(start_img_path)
    if start_image is None:
        raise FileNotFoundError(f"Could not read start image: {start_img_path}")
    start_image = cv2.cvtColor(start_image, cv2.COLOR_BGR2RGB)
    H, W = start_image.shape[:2]

    # Load raw pointmap (camera frame, flattened H*W x 3). It is not used below,
    # but the reshape checks that it matches the start image resolution.
    pointmap_raw_path = os.path.join(sequence_dir, "pointmap_start_raw.pt")
    pointmap_raw = torch.load(pointmap_raw_path, map_location="cpu")
    points_cam_flat = pointmap_raw["points"].cpu().numpy()  # (H*W, 3) in camera frame
    colors_flat = pointmap_raw["colors"].cpu().numpy()  # (H*W, 3)
    mask_flat = pointmap_raw["mask"].cpu().numpy()  # (H*W,)
    points_cam = points_cam_flat.reshape(H, W, 3)  # (H, W, 3) in camera frame

    # Load robot-aligned pointmap (filtered points only)
    pointmap_path = os.path.join(sequence_dir, "pointmap_start.pt")
    pointmap = torch.load(pointmap_path, map_location="cpu")
    points = pointmap["points"].cpu().numpy()  # (N, 3) in robot frame
    colors = pointmap["colors"].cpu().numpy()  # (N, 3) RGB colors

    # Ensure colors are uint8 [0-255]
    if colors.dtype != np.uint8:
        if colors.max() <= 1.0:
            colors = (colors * 255).astype(np.uint8)
        else:
            colors = colors.astype(np.uint8)

    # Load DINO features and map the first 3 PCA components to RGB via a sigmoid.
    dino_features_path = os.path.join(sequence_dir, "dino_features.pt")
    dino_features = torch.load(dino_features_path, map_location="cpu").numpy()  # (N, 32)
    dino_pca_rgb = dino_features[:, :3]  # (N, 3)
    dino_pca_rgb_normalized = torch.sigmoid(torch.from_numpy(dino_pca_rgb).mul(2.0)).numpy()
    dino_colors = (dino_pca_rgb_normalized * 255).astype(np.uint8)

    # Load the MoGe edge mask (robot/human masks are not loaded by this viewer).
    moge_edge_mask_path = os.path.join(sequence_dir, "moge_edge_mask.png")
    moge_edge_mask = plt.imread(moge_edge_mask_path)
    if len(moge_edge_mask.shape) == 3:
        moge_edge_mask = moge_edge_mask[:, :, 0]
    if moge_edge_mask.max() <= 1.0:
        moge_edge_mask = (moge_edge_mask * 255).astype(np.uint8)

    # Load gripper pose (4x4 homogeneous transformation matrix)
    gripper_pose_path = os.path.join(sequence_dir, "gripper_pose_grasp.npy")
    gripper_pose = np.load(gripper_pose_path)

    # Transform keypoints from gripper local frame to robot frame (GT)
    gripper_rot = gripper_pose[:3, :3]
    gripper_pos = gripper_pose[:3, 3]
    keypoints_robot_gt = (gripper_rot @ KEYPOINTS_LOCAL_M.T).T + gripper_pos.reshape(1, 3)

    # Predicted offset fields for this sample: (num_keypoints, 3, H_pred, W_pred),
    # at the predictor's (downsampled) resolution.
    pred_offsets = pred_offset_fields[idx]
    H_pred, W_pred = pred_offsets.shape[2], pred_offsets.shape[3]

    # Load GT offset fields at full resolution to reconstruct the robot-frame pointmap.
    gt_offsets_full_path = os.path.join(sequence_dir, "offset_fields.pt")
    gt_offsets_full = torch.load(gt_offsets_full_path, map_location="cpu").numpy()  # (5, H, W, 3) in robot frame

    # offset = keypoint - pointmap_point, so pointmap_point = keypoint - offset.
    # Every keypoint yields the same pointmap, so use the first one.
    keypoint_gt = keypoints_robot_gt[0]  # (3,)
    offset_gt = gt_offsets_full[0]  # (H, W, 3)
    points_robot_full = keypoint_gt.reshape(1, 1, 3) - offset_gt  # (H, W, 3) in robot frame

    # Upsample predicted offset fields to full resolution (tensor or numpy input).
    if isinstance(pred_offsets, torch.Tensor):
        pred_offsets_t = pred_offsets.float()
    else:
        pred_offsets_t = torch.from_numpy(pred_offsets).float()

    # Flatten keypoints x channels into one channel axis for interpolation.
    pred_offsets_flat = pred_offsets_t.reshape(num_keypoints * 3, H_pred, W_pred)
    pred_offsets_upsampled = F.interpolate(
        pred_offsets_flat.unsqueeze(0),  # (1, K*3, H_pred, W_pred)
        size=(H, W),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0)  # (K*3, H, W)

    # Reshape back to (K, 3, H, W), then permute to (K, H, W, 3)
    pred_offsets_upsampled = (
        pred_offsets_upsampled.reshape(num_keypoints, 3, H, W).permute(0, 2, 3, 1).numpy()
    )

    # Extract each keypoint using a softmax over the smoothed inverse offset distance.
    keypoints_robot_pred = np.zeros((num_keypoints, 3), dtype=np.float32)
    weight_masks = []  # One (H, W) weight map per keypoint, kept for visualization

    for kp_idx in range(num_keypoints):
        # Each pixel's vote for the keypoint: keypoint = pointmap + offset
        keypoint_locations = points_robot_full + pred_offsets_upsampled[kp_idx]  # (H, W, 3)

        # Pixels with small predicted offsets lie close to the keypoint, so weight
        # them by inverse offset magnitude, smoothed to suppress isolated outliers.
        offset_magnitude = np.linalg.norm(pred_offsets_upsampled[kp_idx], axis=2)  # (H, W)
        inv_dist = 1.0 / (offset_magnitude + eps)
        inv_dist_smoothed = gaussian_filter(inv_dist, sigma=inv_dist_sigma)

        # Temperature-scaled softmax; subtracting the max keeps exp() from overflowing.
        inv_dist_scaled = inv_dist_smoothed / temperature
        weights = np.exp(inv_dist_scaled - inv_dist_scaled.max())
        weights = weights / (weights.sum() + eps)

        if args.use_argmax:
            # Take the vote of the single highest-weight pixel.
            max_idx = np.unravel_index(np.argmax(weights), weights.shape)
            keypoint_pred = keypoint_locations[max_idx]  # (3,)
        else:
            # Weighted mean of all pixel votes.
            keypoint_pred = np.sum(weights[:, :, np.newaxis] * keypoint_locations, axis=(0, 1))

        keypoints_robot_pred[kp_idx] = keypoint_pred
        weight_masks.append(weights)

    # Procrustes transform aligning GT keypoints to predicted keypoints; applying
    # it to the GT gripper pose yields the predicted gripper pose.
    T_procrustes, scale, rotation, translation = procrustes_alignment(
        keypoints_robot_pred, keypoints_robot_gt  # Align GT to predicted
    )
    gripper_pose_pred = T_procrustes @ gripper_pose

    episodes_data.append({
        'sequence_id': seq_id,
        'points': points,
        'colors': colors,
        'dino_colors': dino_colors,
        'start_image': start_image,
        'moge_edge_mask': moge_edge_mask,
        'gripper_pose': gripper_pose,
        'gripper_pose_pred': gripper_pose_pred,
        'keypoints_robot_gt': keypoints_robot_gt,
        'keypoints_robot_pred': keypoints_robot_pred,
        'points_robot_full': points_robot_full,  # Full HxW pointmap
        'weight_masks': weight_masks,  # list of per-keypoint (H, W) weight maps
        'H': H,
        'W': W,
    })

    print(f"  ✓ Loaded {seq_id}: {len(points)} points, extracted {len(keypoints_robot_pred)} predicted keypoints")

print(f"\nSuccessfully loaded {len(episodes_data)} episodes")

def _load_mesh_in_meters(mesh_path):
    """Load one trimesh mesh from ``mesh_path`` and return it in metre units.

    If the file loads as a Scene, its first geometry is used. A mesh whose
    largest bounding-box extent exceeds 1.0 is treated as millimetres and
    scaled by 0.001.
    """
    loaded = trimesh.load(mesh_path)
    if isinstance(loaded, trimesh.Scene):
        loaded = list(loaded.geometry.values())[0]
    lower, upper = loaded.bounds
    if np.max(upper - lower) > 1.0:
        loaded.apply_scale(0.001)
    return loaded


# Load gripper mesh once
print("\n" + "=" * 60)
print("Loading gripper mesh")
print("=" * 60)
# Alternative mesh: robot_models/so100_model/assets/Fixed_Jaw.stl
fixed_gripper_stl_path = "robot_models/so100_blender_testings/kp_testing_fixedgrip.stl"
fixed_gripper_mesh = _load_mesh_in_meters(fixed_gripper_stl_path)
print(f"  ✓ Loaded gripper mesh: {len(fixed_gripper_mesh.vertices)} vertices")

# Load ball mesh for keypoint visualization
print("\n" + "=" * 60)
print("Loading ball mesh for keypoints")
print("=" * 60)
ball_stl_path = "robot_models/so100_blender_testings/ball.stl"
ball_mesh = _load_mesh_in_meters(ball_stl_path)
print(f"  ✓ Loaded ball mesh: {len(ball_mesh.vertices)} vertices")

# Launch viser visualization
_banner = "=" * 60
print(f"\n{_banner}")
print("Launching viser dataset viewer")
print(_banner)
server = viser.ViserServer()

# Index of the episode currently on screen. Held in a one-element list so
# update_episode() can change it without a `global` statement.
current_episode_idx = [0]

# RGBA colors for scene objects.
_GT_GRIPPER_RGBA = (100, 150, 200, 255)  # Blue/cyan
_PRED_GRIPPER_RGBA = (255, 165, 0, 255)  # Orange
_GT_KEYPOINT_RGBA = (0, 255, 0, 255)  # Green
_PRED_KEYPOINT_RGBA = (255, 0, 0, 255)  # Red
_IDENTITY_WXYZ = (1.0, 0.0, 0.0, 0.0)  # No rotation


def _pose_to_wxyz(pose):
    """Return the rotation block of a 4x4 transform as a (w, x, y, z) quaternion."""
    quat_xyzw = R.from_matrix(pose[:3, :3]).as_quat()  # SciPy order is (x, y, z, w)
    return quat_xyzw[[3, 0, 1, 2]]


def _add_gripper_mesh(name, mesh, pose, fallback_rgba):
    """Place ``mesh`` in the scene under ``name`` at the 4x4 homogeneous ``pose``.

    Uses ``add_mesh_trimesh`` with the pose split into position + quaternion. If
    that call raises, falls back to transforming the vertices here and adding a
    plain mesh colored ``fallback_rgba``.
    """
    wxyz = _pose_to_wxyz(pose)
    position = pose[:3, 3]
    try:
        server.scene.add_mesh_trimesh(
            name=name,
            mesh=mesh,
            wxyz=wxyz,
            position=position,
        )
    except Exception:
        # Catch Exception rather than a bare except so KeyboardInterrupt and
        # SystemExit are not silently swallowed by the fallback.
        vertices_h = np.hstack([mesh.vertices, np.ones((mesh.vertices.shape[0], 1))])
        transformed_vertices = (pose @ vertices_h.T).T[:, :3]
        server.scene.add_mesh(
            name=name,
            vertices=transformed_vertices.astype(np.float32),
            faces=mesh.faces.astype(np.int32),
            color=fallback_rgba,
        )


def _add_keypoint_balls(prefix, keypoints, rgba):
    """Add one ``ball_mesh`` sphere per keypoint, named ``f"{prefix}_{i}"`` and colored ``rgba``.

    Does nothing when ``keypoints`` is None. The colored mesh copy is made once
    and shared by all keypoints rather than re-copied per keypoint.
    """
    if keypoints is None:
        return
    colored_ball = ball_mesh.copy()
    colored_ball.visual.vertex_colors = rgba
    for i, kp_pos in enumerate(keypoints):
        server.scene.add_mesh_trimesh(
            name=f"{prefix}_{i}",
            mesh=colored_ball,
            wxyz=_IDENTITY_WXYZ,
            position=kp_pos.astype(np.float32),
        )


def update_episode(episode_idx):
    """Show episode ``episode_idx`` in the viser scene.

    Adds the robot-frame RGB point cloud, the GT (blue/cyan) and predicted
    (orange) gripper meshes, and the GT (green) and predicted (red) keypoint
    spheres. Scene nodes use fixed names, so each call is expected to overwrite
    the previous episode's nodes. Records the index in ``current_episode_idx[0]``.
    """
    current_episode_idx[0] = episode_idx
    episode = episodes_data[episode_idx]

    # Point cloud with RGB colors (robot frame). The episode also carries
    # 'dino_colors' for a DINO-PCA-colored view, which is currently not shown.
    server.scene.add_point_cloud(
        name="/moge_pointmap_robot_rgb",
        points=episode['points'].astype(np.float32),
        colors=episode['colors'].astype(np.uint8),
        point_size=0.002,
    )

    if fixed_gripper_mesh is not None:
        # GT gripper (blue/cyan), colored the same on both rendering paths.
        gt_gripper_mesh = fixed_gripper_mesh.copy()
        gt_gripper_mesh.visual.vertex_colors = _GT_GRIPPER_RGBA
        _add_gripper_mesh("/gripper_gt", gt_gripper_mesh, episode['gripper_pose'], _GT_GRIPPER_RGBA)

        # Predicted gripper (orange)
        gripper_pose_pred = episode.get('gripper_pose_pred')
        if gripper_pose_pred is not None:
            pred_gripper_mesh = fixed_gripper_mesh.copy()
            pred_gripper_mesh.visual.vertex_colors = _PRED_GRIPPER_RGBA
            _add_gripper_mesh("/gripper_pred", pred_gripper_mesh, gripper_pose_pred, _PRED_GRIPPER_RGBA)

    if ball_mesh is not None:
        # GT keypoints (green) and predicted keypoints (red)
        _add_keypoint_balls("/keypoint_gt", episode.get('keypoints_robot_gt'), _GT_KEYPOINT_RGBA)
        _add_keypoint_balls("/keypoint_pred", episode.get('keypoints_robot_pred'), _PRED_KEYPOINT_RGBA)

import time  # Used by the keep-alive loop at the end of the script

# With no episodes (e.g. --max_episodes 0 or an empty IDs file), exit with a
# clear message instead of an IndexError from update_episode(0) and a slider
# whose max is -1.
if not episodes_data:
    raise SystemExit("No episodes were loaded; nothing to visualize.")

# Initialize with first episode
update_episode(0)

# Add slider to switch between episodes
slider = server.gui.add_slider(
    "episode",
    0,
    len(episodes_data) - 1,
    initial_value=0,
    step=1,
)


@slider.on_update
def _(_):
    """Switch the scene to the slider's episode when the selection changes."""
    new_idx = int(slider.value)
    if new_idx != current_episode_idx[0]:
        update_episode(new_idx)


print(f"\nAdded slider to switch between {len(episodes_data)} episodes")
print("Use slider to navigate through episodes")
print("Episode names:")
for i, ep in enumerate(episodes_data):
    print(f"  {i}: {ep['sequence_id']}")

print("\n" + "=" * 60)
print("Viser server running at http://localhost:8080")
print(f"Dataset viewer with {len(episodes_data)} episodes")
print("Press Ctrl+C to exit")
print("=" * 60)

# Keep the main thread alive so the viewer stays up until Ctrl+C.
try:
    while True:
        time.sleep(0.1)
except KeyboardInterrupt:
    pass

print("\nDone!")

