"""Test script to render gripper position on grasp frame from processed keyboard grasp dataset."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import mujoco
import colorsys

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import (
    detect_and_set_link_poses,
)

# Gripper keypoints expressed in the gripper's local frame.
# Authored in millimetres, one row per keypoint with columns (x, y, z).
KEYPOINTS_LOCAL_MM = np.array(
    [
        [13.25, -91.42, 15.9],
        [10.77, -99.6, 0],
        [13.25, -91.42, -15.9],
        [17.96, -83.96, 0],
        [22.86, -70.46, 0],
    ]
)
# Same keypoints in metres. Presumably this matches the unit of the
# gripper_pose translations these are combined with; confirm against the dataset.
KEYPOINTS_LOCAL_M = KEYPOINTS_LOCAL_MM / 1000.0


def project_3d_to_2d(point_3d, camera_pose, cam_K):
    """Pinhole-project one world-frame point into pixel coordinates.

    Args:
        point_3d: (3,) point in the world frame.
        camera_pose: (4, 4) homogeneous transform that maps world-frame points
            into the camera frame (world -> camera).
        cam_K: (3, 3) camera intrinsic matrix.

    Returns:
        (2,) array ``(u, v)`` of image coordinates, or ``None`` when the point's
        camera-frame depth is not positive (on or behind the camera plane).
    """
    # Lift to homogeneous coordinates, move into the camera frame, drop w.
    in_camera = (camera_pose @ np.append(point_3d, 1.0))[:3]

    depth = in_camera[2]
    if depth <= 0:
        return None

    # Apply intrinsics, then perform the perspective divide.
    pixel_h = cam_K @ in_camera
    return pixel_h[:2] / pixel_h[2]


def draw_keypoint_on_image(img, point_2d, color=(255, 0, 0), size=15):
    """Draw a keypoint marker on ``img`` in place.

    The marker is a filled circle of ``color`` with a white outline and a
    white cross through its centre.

    Args:
        img: Image array of shape (H, W, ...); it is modified in place.
        point_2d: ``(x, y)`` pixel coordinates (sub-pixel floats allowed), or None.
        color: Fill colour, in the same channel order as ``img``.
        size: Marker radius in pixels.

    Returns:
        ``img`` itself. Nothing is drawn when ``point_2d`` is None, contains a
        non-finite value, or its rounded centre lies outside the image.
    """
    if point_2d is None:
        return img

    u, v = float(point_2d[0]), float(point_2d[1])
    # int() on NaN/inf raises; treat such projections as "not visible" instead.
    if not (np.isfinite(u) and np.isfinite(v)):
        return img

    # Round to the nearest pixel. int() truncates toward zero, which biased
    # markers up/left and let coordinates in (-1, 0) pass the bounds check.
    x, y = round(u), round(v)

    H, W = img.shape[:2]
    if not (0 <= x < W and 0 <= y < H):
        return img

    # Filled marker body.
    cv2.circle(img, (x, y), size, color, -1)
    # White outline so the marker stays visible on any background.
    cv2.circle(img, (x, y), size + 2, (255, 255, 255), 2)
    # White cross marking the exact centre.
    cv2.line(img, (x - size, y), (x + size, y), (255, 255, 255), 2)
    cv2.line(img, (x, y - size), (x, y + size), (255, 255, 255), 2)

    return img


if __name__ == "__main__":
    # Configuration
    dataset_dir = Path("scratch/processed_grasp_dataset_keyboard")
    
    if not dataset_dir.exists():
        print(f"Error: Dataset directory not found: {dataset_dir}")
        sys.exit(1)
    
    # Find all sequence directories
    sequence_dirs = sorted([d for d in dataset_dir.iterdir() if d.is_dir()])
    
    if len(sequence_dirs) == 0:
        print(f"Error: No sequences found in {dataset_dir}")
        sys.exit(1)
    
    print(f"Found {len(sequence_dirs)} sequences")
    
    # Setup MuJoCo model
    robot_config = SO100AdhesiveConfig()
    mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
    mj_data = mujoco.MjData(mj_model)
    
    # Process each sequence
    images_with_keypoints = []
    sequence_ids = []
    
    def render_keypoints_on_image(rgb_image, gripper_pose, camera_pose_world, cam_K):
        """Render all keypoints and gripper origin on an image."""
        rgb_with_keypoints = rgb_image.copy()
        
        # Extract gripper rotation and position
        gripper_rot = gripper_pose[:3, :3]
        gripper_pos = gripper_pose[:3, 3]  # Gripper origin/center
        
        # Project and draw gripper origin (center) - use white color
        origin_2d = project_3d_to_2d(gripper_pos, camera_pose_world, cam_K)
        if origin_2d is not None:
            rgb_with_keypoints = draw_keypoint_on_image(rgb_with_keypoints, origin_2d, color=(255, 255, 255), size=15)
        
        # Project and draw all 5 keypoints with different colors
        for kp_idx in range(5):
            # Transform keypoint from gripper local frame to robot frame
            keypoint_local = KEYPOINTS_LOCAL_M[kp_idx]
            keypoint_robot = gripper_rot @ keypoint_local + gripper_pos
            
            # Project to 2D
            point_2d = project_3d_to_2d(keypoint_robot, camera_pose_world, cam_K)
            
            if point_2d is not None:
                # Use rainbow color based on keypoint index
                hue = kp_idx / 4.0  # Normalize to [0, 1]
                rgb_color = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
                color = tuple(int(c * 255) for c in rgb_color)
                
                # Draw keypoint
                rgb_with_keypoints = draw_keypoint_on_image(rgb_with_keypoints, point_2d, color=color, size=12)
        
        return rgb_with_keypoints
    
    for seq_dir in sequence_dirs:
        seq_id = seq_dir.name
        print(f"Processing {seq_id}...")
        
        # Load start and grasp images
        start_img_path = seq_dir / "start.png"
        grasp_img_path = seq_dir / "grasp.png"
        
        if not start_img_path.exists():
            print(f"  ⚠ Skipping {seq_id}: start.png not found")
            continue
        if not grasp_img_path.exists():
            print(f"  ⚠ Skipping {seq_id}: grasp.png not found")
            continue
        
        rgb_start = cv2.cvtColor(cv2.imread(str(start_img_path)), cv2.COLOR_BGR2RGB)
        rgb_grasp = cv2.cvtColor(cv2.imread(str(grasp_img_path)), cv2.COLOR_BGR2RGB)
        if rgb_start.max() <= 1.0:
            rgb_start = (rgb_start * 255).astype(np.uint8)
        if rgb_grasp.max() <= 1.0:
            rgb_grasp = (rgb_grasp * 255).astype(np.uint8)
        
        # Load gripper pose (4x4 transformation matrix)
        gripper_pose_path = seq_dir / "gripper_pose_grasp.npy"
        if not gripper_pose_path.exists():
            print(f"  ⚠ Skipping {seq_id}: gripper_pose_grasp.npy not found")
            continue
        
        gripper_pose = np.load(gripper_pose_path)  # (4, 4)
        
        # Estimate camera pose and intrinsics from start image
        try:
            mj_data.qpos[:] = 0  # Reset joint state
            mj_data.ctrl[:] = 0
            mujoco.mj_forward(mj_model, mj_data)
            
            link_poses_start, camera_pose_world_start, cam_K_start, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
                rgb_start, mj_model, mj_data, robot_config, cam_K=None
            )
        except Exception as e:
            print(f"  ⚠ Failed to detect camera pose from start image for {seq_id}: {e}")
            continue
        
        # Estimate camera pose and intrinsics from grasp image
        try:
            mj_data.qpos[:] = 0  # Reset joint state
            mj_data.ctrl[:] = 0
            mujoco.mj_forward(mj_model, mj_data)
            
            link_poses_grasp, camera_pose_world_grasp, cam_K_grasp, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
                rgb_grasp, mj_model, mj_data, robot_config, cam_K=None
            )
        except Exception as e:
            print(f"  ⚠ Failed to detect camera pose from grasp image for {seq_id}: {e}")
            continue
        
        # Render keypoints on start image
        rgb_start_with_keypoints = render_keypoints_on_image(rgb_start, gripper_pose, camera_pose_world_start, cam_K_start)
        
        # Render keypoints on grasp image
        rgb_grasp_with_keypoints = render_keypoints_on_image(rgb_grasp, gripper_pose, camera_pose_world_grasp, cam_K_grasp)
        
        # Save rendered images to sequence directory
        start_output_path = seq_dir / "start_kprender.png"
        grasp_output_path = seq_dir / "grasp_kprender.png"
        
        cv2.imwrite(str(start_output_path), cv2.cvtColor(rgb_start_with_keypoints, cv2.COLOR_RGB2BGR))
        cv2.imwrite(str(grasp_output_path), cv2.cvtColor(rgb_grasp_with_keypoints, cv2.COLOR_RGB2BGR))
        
        print(f"  ✓ Saved {start_output_path.name} and {grasp_output_path.name}")
        
        # Add to visualization grid (use grasp image)
        images_with_keypoints.append(rgb_grasp_with_keypoints)
        sequence_ids.append(seq_id)
    
    # Create visualization grid
    num_images = len(images_with_keypoints)
    if num_images == 0:
        print("No valid images to display")
        sys.exit(1)
    
    cols = 4
    rows = (num_images + cols - 1) // cols
    
    fig, axes = plt.subplots(rows, cols, figsize=(20, 5 * rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    axes = axes.flatten()
    
    for i, (img, seq_id) in enumerate(zip(images_with_keypoints, sequence_ids)):
        ax = axes[i]
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"{seq_id}")
    
    # Hide unused subplots
    for i in range(num_images, len(axes)):
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.savefig("scratch/test_2d_keypoint_grasp_render.png", dpi=150, bbox_inches='tight')
    print(f"\nSaved visualization to: scratch/test_2d_keypoint_grasp_render.png")
    
    plt.show()
    
    print(f"\nProcessed {num_images}/{len(sequence_dirs)} sequences")

