"""Test script to render robot overlays on RGB images using joint states."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import mujoco
import colorsys

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import (
    get_link_poses_from_robot,
    position_exoskeleton_meshes,
    render_from_camera_pose,
    detect_and_set_link_poses,
)
from scipy.spatial.transform import Rotation as R


def get_gripper_position_from_joint_state(joint_state, model, robot_config):
    """Compute the world-frame gripper position for a given joint state.

    A fresh MjData is created so the caller's simulation state is untouched.

    Args:
        joint_state: array of joint positions matching the model's qpos layout.
        model: the MuJoCo MjModel for the robot.
        robot_config: robot configuration passed through to the exo helpers.

    Returns:
        (3,) array: position of the fixed gripper exoskeleton mocap body.
    """
    sim_data = mujoco.MjData(model)
    sim_data.qpos[:] = joint_state
    sim_data.ctrl[:] = joint_state[:len(sim_data.ctrl)]
    mujoco.mj_forward(model, sim_data)

    # Place the exoskeleton meshes according to the forward-kinematics link
    # poses, then recompute so the mocap positions reflect the placement.
    poses = get_link_poses_from_robot(robot_config, model, sim_data)
    position_exoskeleton_meshes(robot_config, model, sim_data, poses)
    mujoco.mj_forward(model, sim_data)

    # The gripper position is read from the mocap slot that carries the
    # fixed gripper exoskeleton mesh body.
    body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, "fixed_gripper_exo_mesh")
    mocap_id = model.body_mocapid[body_id]
    return sim_data.mocap_pos[mocap_id].copy()


def project_3d_to_2d(point_3d, camera_pose, cam_K):
    """Project a world-frame 3D point into pixel coordinates.

    Args:
        point_3d: (3,) array in world frame
        camera_pose: (4, 4) transformation matrix from world to camera
        cam_K: (3, 3) camera intrinsic matrix

    Returns:
        point_2d: (2,) array of image coordinates, or None if behind camera
    """
    # Homogeneous transform into the camera frame.
    point_cam = (camera_pose @ np.append(point_3d, 1.0))[:3]

    # Points at or behind the image plane cannot be projected.
    if point_cam[2] <= 0:
        return None

    # Pinhole projection followed by the perspective divide.
    homogeneous_px = cam_K @ point_cam
    return homogeneous_px[:2] / homogeneous_px[2]


def draw_keypoint_on_image(img, point_2d, color=(255, 0, 0), size=10):
    """Draw a circled cross-hair keypoint on `img` (in place) and return it.

    Nothing is drawn when `point_2d` is None or its center lies outside
    the image bounds.
    """
    if point_2d is None:
        return img

    height, width = img.shape[:2]
    x = int(point_2d[0])
    y = int(point_2d[1])

    # Guard: only draw when the center pixel is inside the image.
    if not (0 <= x < width and 0 <= y < height):
        return img

    # Filled disc in the requested color, ringed in yellow.
    cv2.circle(img, (x, y), size, color, -1)
    cv2.circle(img, (x, y), size + 2, (255, 255, 0), 2)
    # White cross through the center for visibility.
    cv2.line(img, (x - size, y), (x + size, y), (255, 255, 255), 2)
    cv2.line(img, (x, y - size), (x, y + size), (255, 255, 255), 2)

    return img


if __name__ == "__main__":
    # Process each recorded episode: render the robot from the detected
    # camera pose, blend it over the RGB frames, and visualize the 2D
    # gripper trajectory.
    for episode_idx in range(1, 20):
        episode_dir = Path("scratch/parsed_episodes_cup_synch/episode_%03d" % episode_idx)

        if not episode_dir.exists():
            # BUG FIX: sys.exit(1) here aborted the whole run on the first
            # missing episode; skip instead (numbering may have gaps).
            print(f"Warning: episode directory not found, skipping: {episode_dir}")
            continue

        print(f"Loading episode: {episode_dir.name}")

        # Per-episode output filenames — BUG FIX: all episodes previously
        # wrote the same three files, so only the last episode's plots survived.
        ep_suffix = f"_ep{episode_idx:03d}"
        grid_path = f"scratch/test_2d_keypoint_traj{ep_suffix}.png"
        comparison_path = f"scratch/test_2d_keypoint_traj_comparison{ep_suffix}.png"
        trajectory_path = f"scratch/test_2d_keypoint_traj_overlay{ep_suffix}.png"

        # Setup MuJoCo model (use base robot config, not combined with
        # the alignment board).
        robot_config = SO100AdhesiveConfig()
        mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)

        # A timestep is usable only when both the PNG frame and the matching
        # joint-state .npy file exist.
        image_files = sorted(episode_dir.glob("*.png"))
        timesteps = []
        for img_path in image_files:
            timestep_str = img_path.stem
            joint_path = episode_dir / f"{timestep_str}.npy"
            if joint_path.exists():
                timesteps.append(timestep_str)

        print(f"Found {len(timesteps)} timesteps")

        images = []                # raw RGB frames
        overlays = []              # RGB blended with robot render (None on failure)
        rendered_images = []       # robot-only renders (None on failure)
        gripper_positions_2d = []  # projected gripper keypoints (None on failure)

        # MuJoCo data reused across timesteps for rendering.
        mj_data = mujoco.MjData(mj_model)

        for timestep_str in timesteps:
            # Load image (OpenCV reads BGR; convert to RGB for matplotlib).
            image_path = episode_dir / f"{timestep_str}.png"
            rgb = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
            if rgb.max() <= 1.0:
                rgb = (rgb * 255).astype(np.uint8)
            images.append(rgb)

            # Load joint state and push it into the simulation.
            joint_path = episode_dir / f"{timestep_str}.npy"
            joint_state = np.load(joint_path)

            mj_data.qpos[:] = joint_state
            mj_data.ctrl[:] = joint_state[:len(mj_data.ctrl)]
            mujoco.mj_forward(mj_model, mj_data)

            try:
                # Detect camera pose and intrinsics from ArUco markers in the image.
                link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
                    rgb, mj_model, mj_data, robot_config, cam_K=None
                )

                # detect_and_set_link_poses already calls
                # position_exoskeleton_meshes internally, but re-apply the
                # joint state to ensure it is set correctly.
                mj_data.qpos[:] = joint_state
                mj_data.ctrl[:] = joint_state[:len(mj_data.ctrl)]
                mujoco.mj_forward(mj_model, mj_data)
                position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)

                # Render the robot from the detected camera pose.
                H, W = rgb.shape[:2]
                rendered = render_from_camera_pose(mj_model, mj_data, camera_pose_world, cam_K, H, W)
                rendered_images.append(rendered)

                # 50/50 alpha blend of the RGB frame and the render.
                overlay = (rgb.astype(float) * 0.5 + rendered.astype(float) * 0.5).astype(np.uint8)

                # Project the gripper position into the image.
                gripper_pos_robot = get_gripper_position_from_joint_state(
                    joint_state, mj_model, robot_config
                )
                point_2d = project_3d_to_2d(gripper_pos_robot, camera_pose_world, cam_K)
                gripper_positions_2d.append(point_2d)

                # Rainbow color per timestep: hue sweeps [0, 1] over the episode
                # (saturation = value = 1).
                num_timesteps = len(timesteps)
                timestep_idx = len(overlays)  # current index in the sequence
                hue = timestep_idx / max(1, num_timesteps - 1)
                rgb_color = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
                color = tuple(int(c * 255) for c in rgb_color)

                overlay = draw_keypoint_on_image(overlay, point_2d, color=color, size=20)

                overlays.append(overlay)
            except Exception as e:
                # Record None placeholders so all lists stay aligned with `timesteps`.
                print(f"  ⚠ Failed to detect camera pose for {timestep_str}: {e}")
                rendered_images.append(None)
                overlays.append(None)
                gripper_positions_2d.append(None)

        # --- Grid visualization of overlays -------------------------------
        num_images = len(images)
        cols = 4
        rows = (num_images + cols - 1) // cols

        fig, axes = plt.subplots(rows, cols, figsize=(20, 5 * rows))
        if rows == 1:
            axes = axes.reshape(1, -1)
        axes = axes.flatten()

        # NOTE: inner loop variable renamed from `i` so it no longer shadows
        # the outer episode index.
        for ax_idx, (img, overlay, timestep_str) in enumerate(zip(images, overlays, timesteps)):
            ax = axes[ax_idx]
            # Show overlay if available, otherwise the original image.
            display_img = overlay if overlay is not None else img
            ax.imshow(display_img)
            ax.axis('off')
            ax.set_title(f"Timestep {timestep_str}")

        # Hide unused subplots.
        for ax_idx in range(num_images, len(axes)):
            axes[ax_idx].axis('off')

        plt.tight_layout()
        plt.savefig(grid_path, dpi=150, bbox_inches='tight')
        print(f"\nSaved visualization to: {grid_path}")

        # --- Side-by-side comparison (RGB, rendered, overlay) --------------
        if len(images) > 0 and any(o is not None for o in overlays):
            # Show RGB, rendered, and overlay for the first few frames.
            num_show = min(4, len(images))
            fig2, axes2 = plt.subplots(num_show, 3, figsize=(15, 5 * num_show))
            if num_show == 1:
                axes2 = axes2.reshape(1, -1)

            for row in range(num_show):
                if overlays[row] is not None:
                    axes2[row, 0].imshow(images[row])
                    axes2[row, 0].set_title(f"RGB - {timesteps[row]}")
                    axes2[row, 0].axis('off')

                    axes2[row, 1].imshow(rendered_images[row])
                    axes2[row, 1].set_title(f"Rendered - {timesteps[row]}")
                    axes2[row, 1].axis('off')

                    axes2[row, 2].imshow(overlays[row])
                    axes2[row, 2].set_title(f"Overlay - {timesteps[row]}")
                    axes2[row, 2].axis('off')

            plt.tight_layout()
            plt.savefig(comparison_path, dpi=150, bbox_inches='tight')
            print(f"Saved comparison to: {comparison_path}")

        # --- Trajectory visualization on the first frame -------------------
        if len(images) > 0 and any(p is not None for p in gripper_positions_2d):
            base_img = images[0].copy()

            # Keep only timesteps whose projection succeeded.
            valid_points = [(idx, p) for idx, p in enumerate(gripper_positions_2d) if p is not None]
            if len(valid_points) > 1:
                points_array = np.array([p for _, p in valid_points])
                H, W = base_img.shape[:2]

                # Keep only points within image bounds.
                valid_mask = (points_array[:, 0] >= 0) & (points_array[:, 0] < W) & \
                            (points_array[:, 1] >= 0) & (points_array[:, 1] < H)
                points_array = points_array[valid_mask]
                valid_indices = [idx for (idx, _), keep in zip(valid_points, valid_mask) if keep]

                if len(points_array) > 1:
                    # Rainbow-colored line segments between consecutive points.
                    for j in range(len(points_array) - 1):
                        pt1 = points_array[j].astype(int)
                        pt2 = points_array[j + 1].astype(int)

                        # Segment color follows the first point's timestep.
                        hue = valid_indices[j] / max(1, len(timesteps) - 1)
                        rgb_color = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
                        color = tuple(int(c * 255) for c in rgb_color)

                        cv2.line(base_img, tuple(pt1), tuple(pt2), color, 3)

                    # Keypoints with per-timestep colors.
                    for j, pt in enumerate(points_array):
                        pt_int = pt.astype(int)
                        hue = valid_indices[j] / max(1, len(timesteps) - 1)
                        rgb_color = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
                        color = tuple(int(c * 255) for c in rgb_color)

                        # Filled disc with a white edge.
                        cv2.circle(base_img, tuple(pt_int), 8, color, -1)
                        cv2.circle(base_img, tuple(pt_int), 10, (255, 255, 255), 2)

            plt.figure(figsize=(12, 8))
            plt.imshow(base_img)
            plt.axis('off')
            plt.title(f"Gripper Trajectory Overlay - {episode_dir.name}")
            plt.tight_layout()
            plt.savefig(trajectory_path, dpi=150, bbox_inches='tight')
            print(f"Saved trajectory overlay to: {trajectory_path}")

        plt.show()
        # BUG FIX: close all figures so 19 episodes do not accumulate open
        # matplotlib figures (memory growth and "too many figures" warnings).
        plt.close('all')

        print(f"\nProcessed {len(timesteps)} timesteps")
        print(f"Valid overlays: {sum(1 for o in overlays if o is not None)}/{len(overlays)}")

