"""Render 3D keypoint predictions using MuJoCo."""
import argparse
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import mujoco
import xml.etree.ElementTree as ET
from scipy.spatial.transform import Rotation as R
import torch.nn.functional as F
from geom_utils import procrustes_alignment
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import combine_xmls, get_link_poses_from_robot, position_exoskeleton_meshes, render_from_camera_pose, detect_and_set_link_poses
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG

# Gripper keypoints expressed in the gripper's local frame, one (x, y, z) row
# per keypoint. Authored in millimetres; KEYPOINTS_LOCAL_M holds the same
# points converted to metres.
KEYPOINTS_LOCAL_MM = np.array(
    [
        (13.25, -91.42, 15.9),
        (10.77, -99.6, 0.0),
        (13.25, -91.42, -15.9),
        (17.96, -83.96, 0.0),
        (22.86, -70.46, 0.0),
    ],
    dtype=np.float64,
)
KEYPOINTS_LOCAL_M = KEYPOINTS_LOCAL_MM / 1000.0

def create_mujoco_model_with_grippers():
    """Build a MuJoCo model with the robot, the alignment board and two gripper meshes.

    Side effect: sets the class-level ``exo_alpha``, ``exo_link_alpha`` and
    ``aruco_alpha`` attributes of ``SO100AdhesiveConfig`` to 1.0 before any
    instance is created.

    The combined robot/board XML is extended with a ``kptest_fixed_grip`` mesh
    asset and two mocap bodies that share it: ``gripper_gt`` (blue) and
    ``gripper_pred`` (orange). Both geoms have contype/conaffinity set to 0.

    Returns:
        tuple: ``(mj_model, config)`` where ``mj_model`` is the compiled
        ``mujoco.MjModel`` and ``config`` is a freshly constructed
        ``SO100AdhesiveConfig`` instance.
    """
    for alpha_attr in ('exo_alpha', 'exo_link_alpha', 'aruco_alpha'):
        setattr(SO100AdhesiveConfig, alpha_attr, 1.0)
    base_xml = combine_xmls(SO100AdhesiveConfig().xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())

    tree_root = ET.fromstring(base_xml)

    def _child(tag):
        # Reuse the existing top-level section, appending a new one if absent.
        node = tree_root.find(tag)
        return node if node is not None else ET.SubElement(tree_root, tag)

    # Shared gripper mesh asset (file units are scaled to metres via 0.001).
    ET.SubElement(_child('asset'), 'mesh', {
        'name': 'kptest_fixed_grip',
        'file': '../../so100_blender_testings/kp_testing_fixedgrip.stl',
        'scale': '.001 .001 .001'
    })

    # One mocap body per gripper: GT in blue, prediction in orange.
    world_node = _child('worldbody')
    gripper_variants = (
        ('gripper_gt', '0.2 0.6 0.9 0.8'),
        ('gripper_pred', '1.0 0.65 0.0 0.8'),
    )
    for body_name, color in gripper_variants:
        mocap_body = ET.SubElement(world_node, 'body', {
            'mocap': 'true',
            'name': body_name
        })
        ET.SubElement(mocap_body, 'geom', {
            'type': 'mesh',
            'mesh': 'kptest_fixed_grip',
            'contype': '0',
            'conaffinity': '0',
            'rgba': color
        })

    model = mujoco.MjModel.from_xml_string(ET.tostring(tree_root, encoding='unicode'))
    return model, SO100AdhesiveConfig()

# Command-line interface: which split to render and where inputs/outputs live.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_episodes",
    type=int,
    default=30,
    help="Maximum number of episodes to load",
)
parser.add_argument(
    "--processed_dir",
    type=str,
    default="scratch/processed_grasp_dataset_keyboard",
    help="Directory with processed episodes",
)
parser.add_argument(
    "--pred_dir",
    type=str,
    default="scratch/pred/3d_keypoint_predictor",
    help="Directory with predictions",
)
parser.add_argument(
    "--split",
    type=str,
    choices=["train", "val"],
    default="val",
    help="Which split to visualize (train or val)",
)
args = parser.parse_args()

# Expose the parsed options as module-level names used by the rest of the script.
processed_dir, pred_dir = args.processed_dir, args.pred_dir
max_episodes, split = args.max_episodes, args.split

# Renders are written to <pred_dir>/renders (created if missing).
render_output_dir = os.path.join(pred_dir, "renders")
os.makedirs(render_output_dir, exist_ok=True)

print("=" * 60)
print(f"Loading predictions from {pred_dir}")
print("=" * 60)

# Load predictions. map_location="cpu" lets files written by a GPU run load on
# CPU-only machines and keeps the later .numpy() conversions valid.
pred_offset_fields_path = os.path.join(pred_dir, f"{split}_offset_fields_pred.pt")
gt_offset_fields_path = os.path.join(pred_dir, f"{split}_offset_fields_gt.pt")
sequence_ids_path = os.path.join(pred_dir, f"{split}_sequence_ids.txt")

pred_offset_fields = torch.load(pred_offset_fields_path, map_location="cpu")  # (N, 5, 3, H, W) - downsampled
gt_offset_fields = torch.load(gt_offset_fields_path, map_location="cpu")  # (N, 5, 3, H, W) - downsampled

# One sequence id per line; row i of the prediction tensors belongs to line i,
# so lines are kept as-is (no filtering) to preserve that alignment.
with open(sequence_ids_path, 'r', encoding='utf-8') as f:
    sequence_ids = [line.strip() for line in f]

if len(sequence_ids) != len(pred_offset_fields):
    print(
        f"WARNING: {len(sequence_ids)} sequence ids but {len(pred_offset_fields)} "
        "predictions; sequence ids and predictions may be misaligned"
    )

print(f"Loaded {len(sequence_ids)} predictions")
print(f"Prediction shape: {pred_offset_fields.shape}")
print(f"GT shape: {gt_offset_fields.shape}")

# Limit to max_episodes
sequences = sequence_ids[:max_episodes]
pred_offset_fields = pred_offset_fields[:max_episodes]
gt_offset_fields = gt_offset_fields[:max_episodes]

print(f"\nRendering first {len(sequences)} sequences:")
for i, seq_id in enumerate(sequences):
    print(f"  {i}: {seq_id}")

print("\n" + "=" * 60)
print("Loading episode data and extracting predicted keypoints")
print("=" * 60)

# Create MuJoCo model once; its data is updated in place for every sequence.
mj_model, robot_config = create_mujoco_model_with_grippers()
mj_data = mujoco.MjData(mj_model)


def _to_numpy(value):
    """Return *value* as a NumPy array, detaching torch tensors and moving them to CPU."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _load_rgb_image(path):
    """Load the image at *path* as an RGB array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
            signals failure by returning ``None`` rather than raising).
    """
    image_bgr = cv2.imread(path)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    # Rescale only genuinely floating-point [0, 1] images; a dark uint8 image
    # whose maximum happens to be 0 or 1 must not be multiplied by 255.
    if np.issubdtype(image.dtype, np.floating) and image.max() <= 1.0:
        image = (image * 255).astype(np.uint8)
    return image


def _upsample_offsets(pred_offsets, size):
    """Bilinearly upsample (K, 3, h, w) offsets to *size*; return a (K, H, W, 3) NumPy array."""
    if isinstance(pred_offsets, torch.Tensor):
        offsets_t = pred_offsets.detach().cpu().float()
    else:
        offsets_t = torch.from_numpy(np.asarray(pred_offsets)).float()
    num_kp, channels, h, w = offsets_t.shape
    # Fold keypoints into the channel axis so one interpolate call covers all of them.
    upsampled = F.interpolate(
        offsets_t.reshape(1, num_kp * channels, h, w),
        size=tuple(size),
        mode='bilinear',
        align_corners=False,
    )
    return upsampled.reshape(num_kp, channels, *size).permute(0, 2, 3, 1).numpy()


def _extract_predicted_keypoints(points_robot, pred_offsets_full):
    """Pick one predicted 3D location per keypoint.

    For each keypoint, choose the pixel whose predicted offset has the smallest
    magnitude and return ``points_robot[pixel] + offset[pixel]``. Pixels whose
    point or offset is non-finite are ignored, unless no pixel is finite, in
    which case every pixel is considered.

    Args:
        points_robot: (H, W, 3) robot-frame points.
        pred_offsets_full: (K, H, W, 3) predicted offsets at the same resolution.

    Returns:
        (K, 3) float32 array of predicted keypoints.
    """
    magnitudes = np.linalg.norm(pred_offsets_full, axis=-1)  # (K, H, W)
    valid = np.isfinite(magnitudes) & np.isfinite(points_robot).all(axis=-1)  # (K, H, W)
    keypoints = np.zeros((pred_offsets_full.shape[0], 3), dtype=np.float32)
    for kp_idx, (magnitude, mask) in enumerate(zip(magnitudes, valid)):
        scores = np.where(mask, magnitude, np.inf) if mask.any() else magnitude
        row, col = np.unravel_index(np.argmin(scores), scores.shape)
        keypoints[kp_idx] = points_robot[row, col] + pred_offsets_full[kp_idx, row, col]
    return keypoints


def _set_mocap_pose(model, data, body_name, pose):
    """Place the mocap body *body_name* at the 4x4 homogeneous transform *pose*."""
    body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, body_name)
    mocap_id = model.body_mocapid[body_id]
    quat_xyzw = R.from_matrix(pose[:3, :3]).as_quat()
    data.mocap_pos[mocap_id] = pose[:3, 3]
    # SciPy returns scalar-last (x, y, z, w); MuJoCo stores scalar-first (w, x, y, z).
    data.mocap_quat[mocap_id] = np.roll(quat_xyzw, 1)


# Process each episode
for idx, seq_id in enumerate(sequences):
    print(f"\nProcessing {idx+1}/{len(sequences)}: {seq_id}")
    sequence_dir = os.path.join(processed_dir, seq_id)

    start_image = _load_rgb_image(os.path.join(sequence_dir, "start.png"))
    H, W = start_image.shape[:2]

    # GT gripper pose at grasp time (4x4 transformation matrix).
    gripper_pose = np.load(os.path.join(sequence_dir, "gripper_pose_grasp.npy"))

    # Transform keypoints from gripper local frame to robot frame (GT)
    gripper_rot = gripper_pose[:3, :3]
    gripper_pos = gripper_pose[:3, 3]
    keypoints_robot_gt = (gripper_rot @ KEYPOINTS_LOCAL_M.T).T + gripper_pos.reshape(1, 3)

    # Full-resolution GT offset fields, (5, H, W, 3) in robot frame.
    gt_offsets_full = _to_numpy(
        torch.load(os.path.join(sequence_dir, "offset_fields.pt"), map_location="cpu")
    )

    # Reconstruct the robot-frame pointmap: offset = keypoint - point, so
    # point = keypoint - offset. Every keypoint gives the same pointmap, so
    # keypoint 0 suffices.
    points_robot_full = keypoints_robot_gt[0].reshape(1, 1, 3) - gt_offsets_full[0]  # (H, W, 3)

    # Upsample predictions to the pointmap's own resolution so the per-pixel
    # addition below always lines up.
    pred_offsets_upsampled = _upsample_offsets(pred_offset_fields[idx], points_robot_full.shape[:2])
    keypoints_robot_pred = _extract_predicted_keypoints(points_robot_full, pred_offsets_upsampled)

    # Transform relating GT keypoints to predicted keypoints. The argument order
    # (predicted, GT) is kept from the original, whose comment states this
    # aligns GT to predicted. NOTE(review): confirm against geom_utils.
    T_procrustes, _scale, _rotation, _translation = procrustes_alignment(
        keypoints_robot_pred, keypoints_robot_gt
    )
    # Apply the same transform to the GT gripper pose to get the predicted pose.
    gripper_pose_pred = T_procrustes @ gripper_pose

    # Set robot state in MuJoCo
    joint_states = np.load(os.path.join(sequence_dir, "joint_states_grasp.npy"))
    mj_data.qpos[:] = joint_states
    mj_data.ctrl[:] = joint_states[:len(mj_data.ctrl)]
    mujoco.mj_forward(mj_model, mj_data)

    # Position exoskeleton meshes
    link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
    position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
    mujoco.mj_forward(mj_model, mj_data)

    # Position GT (blue) and predicted (orange) grippers
    _set_mocap_pose(mj_model, mj_data, "gripper_gt", gripper_pose)
    _set_mocap_pose(mj_model, mj_data, "gripper_pred", gripper_pose_pred)
    mujoco.mj_forward(mj_model, mj_data)

    # Load or detect camera pose; detection is always needed for cam_K.
    camera_pose_path = os.path.join(sequence_dir, "robot_camera_pose.npy")
    if os.path.exists(camera_pose_path):
        camera_pose_world = np.load(camera_pose_path)
        _, _, cam_K, _, _, _ = detect_and_set_link_poses(start_image, mj_model, mj_data, robot_config)
    else:
        _, camera_pose_world, cam_K, _, _, _ = detect_and_set_link_poses(start_image, mj_model, mj_data, robot_config)

    # Render scene from camera pose and blend 50/50 with the photo.
    rendered = render_from_camera_pose(mj_model, mj_data, camera_pose_world, cam_K, H, W)
    overlay = (start_image.astype(float) * 0.5 + rendered.astype(float) * 0.5).astype(np.uint8)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panels = (
        (start_image, 'Original Image'),
        (rendered, 'Simulation Render (GT blue, Pred orange)'),
        (overlay, 'Overlay'),
    )
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()

    output_path = os.path.join(render_output_dir, f"{split}_{seq_id}.png")
    fig.savefig(output_path, dpi=100, bbox_inches='tight')
    print(f"  ✓ Saved render to {output_path}")

    # Show only the first figure interactively, then always release it so
    # non-interactive backends (where show() is a no-op) do not keep it alive.
    if idx == 0:
        plt.show()
    plt.close(fig)

print(f"\n✓ Done! Renders saved to {render_output_dir}")

