"""Visualize 3D keypoint predictions vs GT in MuJoCo."""
import sys
import os
from pathlib import Path
import cv2
import mujoco
import numpy as np
import xml.etree.ElementTree as ET
import torch
import argparse
import time
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes, get_link_poses_from_robot
from model import TokenSelectionPredictor
from utils import project_3d_to_2d, rescale_coords, post_process_predictions, ik_to_keypoint_and_rotation, load_gt_trajectory_3d, load_dino_features, build_patch_positions, load_cam_data
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX
import mink

# Number of future frames per window: GT poses are loaded for offsets
# 1..WINDOW_SIZE after the start frame, and the model is built with this size.
WINDOW_SIZE = 10

# Fixed gripper rotation (3x3), precomputed offline as the median over the
# whole dataset (474 gripper poses across 10 episodes).
# Selected at animation time by --use_global_median_rotation.
median_dataset_rotation = np.array(
    [
        [-0.99912433, -0.03007201, -0.02909046],
        [-0.04176828, 0.67620482, 0.73552869],
        [-0.00244771, 0.73609967, -0.67686874],
    ]
)

# Command-line interface. Flag names and defaults are unchanged; the help text
# now describes what each flag actually does in this script.
parser = argparse.ArgumentParser(description="Visualize 3D predictions in MuJoCo")
parser.add_argument("--dataset_dir", "-d", default="scratch/parsed_propercup_train", type=str, help="Dataset directory")
parser.add_argument("--episode_idx", default=0, type=int, help="Episode index (reset to 0 if too large)")
parser.add_argument("--start_frame", "--sf", default=0, type=int, help="Start frame index (reset to 0 if out of range)")
parser.add_argument("--render", "-r", action="store_true",
                    help="Render each step offline with matplotlib instead of opening the interactive MuJoCo viewer")
parser.add_argument("--load_past_pred", "-lp", action="store_true",
                    help="Reuse the model outputs cached by the previous run instead of running the model")
parser.add_argument("--use_median_rotation", action="store_true",
                    help="Use the mean (quaternion-averaged) GT gripper rotation over this window instead of per-timestep GT rotation")
parser.add_argument("--use_global_median_rotation", action="store_true",
                    help="Use the hard-coded median rotation from the entire dataset instead of per-timestep GT rotation "
                         "(takes precedence over --use_median_rotation)")
args = parser.parse_args()

# Prefer Apple MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# Select the episode directory and start frame, and load the camera calibration.
# Invalid indices fall back to 0 (as before), but now say so instead of
# silently changing what is visualized.
dataset_dir = Path(args.dataset_dir)
if not dataset_dir.is_dir():
    print(f"Dataset directory not found: {dataset_dir}")
    sys.exit(1)
episode_dirs = sorted(d for d in dataset_dir.iterdir() if d.is_dir() and d.name.startswith("episode_"))
if not episode_dirs:
    print("No episodes found.")
    sys.exit(1)
if args.episode_idx >= len(episode_dirs):
    print(f"Episode index {args.episode_idx} out of range ({len(episode_dirs)} episodes); using 0.")
    args.episode_idx = 0
episode_dir = episode_dirs[args.episode_idx]
episode_id = episode_dir.name
print(f"Episode: {episode_id}")

# Frames are PNGs with purely numeric stems; a window needs the start frame
# plus WINDOW_SIZE following frames.
frame_files = sorted([f for f in episode_dir.glob("*.png") if f.stem.isdigit()])
if len(frame_files) < WINDOW_SIZE + 1:
    print("Not enough frames.")
    sys.exit(1)

start_idx = args.start_frame
# Valid starts leave at least WINDOW_SIZE frames after start_idx.
if start_idx < 0 or start_idx >= len(frame_files) - WINDOW_SIZE:
    print(f"Start frame {start_idx} out of range for {len(frame_files)} frames; using 0.")
    start_idx = 0

camera_pose, cam_K = load_cam_data(episode_dir, frame_files[start_idx])
if camera_pose is None or cam_K is None:
    print("Camera data missing.")
    sys.exit(1)

# Original image resolution (before any downsampling). cam_K is calibrated at
# this resolution (1080x1920), so it is used directly without rescaling.
H_orig = 1080
W_orig = 1920

# Start-frame RGB image.
start_frame_file = frame_files[start_idx]
bgr_start = cv2.imread(str(start_frame_file))
if bgr_start is None:  # cv2.imread signals failure by returning None
    print(f"Could not read start frame: {start_frame_file}")
    sys.exit(1)
rgb_np = cv2.cvtColor(bgr_start, cv2.COLOR_BGR2RGB)
# Only rescale genuinely normalized float images; a very dark uint8 frame
# (max <= 1) must not be multiplied by 255.
if np.issubdtype(rgb_np.dtype, np.floating) and rgb_np.max() <= 1.0:
    rgb_np = (rgb_np * 255).astype(np.uint8)

# DINO tokens for the start frame and the matching patch-grid positions.
start_frame_str = f"{int(start_frame_file.stem):06d}"
dino_path = episode_dir / f"dino_features_{start_frame_str}.pt"
if not dino_path.exists():
    print(f"Missing DINO features: {dino_path}")
    sys.exit(1)
dino_tokens, H_patches_loaded, W_patches_loaded = load_dino_features(dino_path)
num_patches = dino_tokens.shape[0]
patch_positions_np, H_patches, W_patches = build_patch_positions(num_patches, H_patches=H_patches_loaded, W_patches=W_patches_loaded)
patch_positions = torch.from_numpy(patch_positions_np).float()

# Current keypoint: local keypoint -> world (via start gripper pose) ->
# image pixels -> patch-grid coordinates.
current_pose_path = episode_dir / f"{start_frame_str}_gripper_pose.npy"
if not current_pose_path.exists():
    print("Missing start gripper pose.")
    sys.exit(1)
current_pose = np.load(current_pose_path)
current_rot = current_pose[:3, :3]
current_pos = current_pose[:3, 3]
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
current_kp_3d = current_rot @ kp_local + current_pos
current_kp_2d = project_3d_to_2d(current_kp_3d, camera_pose, cam_K)
if current_kp_2d is None:
    print("Failed to project current KP.")
    sys.exit(1)
current_kp_patches = rescale_coords(current_kp_2d.reshape(1, 2), H_orig, W_orig, H_patches, W_patches)
# The return may be a single point or batched; keep exactly one point either way.
current_kp_patches = np.atleast_2d(current_kp_patches)[0]
current_eef_pos = torch.from_numpy(current_kp_patches).float()

# GT 3D trajectory and orientations
trajectory_gt_3d, orientations_gt = load_gt_trajectory_3d(episode_dir, frame_files, start_idx, WINDOW_SIZE, kp_local, return_orientations=True)

# Load GT gripper poses (4x4 SE3) for future offsets 1..WINDOW_SIZE.
# Index k of gripper_poses_gt is used as the rotation for trajectory step k
# (offset k + 1). Stop at the first missing pose rather than skipping it:
# skipping would shift every later pose onto the wrong timestep.
gripper_poses_gt = []
for offset in range(1, WINDOW_SIZE + 1):
    f_idx = start_idx + offset
    if f_idx >= len(frame_files):
        break
    frame_str = f"{int(frame_files[f_idx].stem):06d}"
    pose_path = episode_dir / f"{frame_str}_gripper_pose.npy"
    if not pose_path.exists():
        print(f"Missing GT gripper pose {pose_path}; using the first {len(gripper_poses_gt)} poses only.")
        break
    gripper_poses_gt.append(np.load(pose_path))
gripper_poses_gt = np.array(gripper_poses_gt) if gripper_poses_gt else np.empty((0, 4, 4))
print(f"Loaded {len(gripper_poses_gt)} GT gripper poses")

# Mean gripper rotation over this window (kept under the historical name
# `median_gripper_rot`; it is a normalized quaternion mean, not a median).
median_gripper_rot = None
if args.use_median_rotation and len(gripper_poses_gt) > 0:
    quats = R.from_matrix(gripper_poses_gt[:, :3, :3]).as_quat()  # (N, 4), xyzw
    # q and -q encode the same rotation; flip all quaternions into the
    # hemisphere of the first so opposite-sign copies don't cancel out.
    quats[quats @ quats[0] < 0] *= -1
    mean_quat = quats.mean(axis=0)
    mean_quat /= np.linalg.norm(mean_quat)
    median_gripper_rot = R.from_quat(mean_quat).as_matrix()
    print(f"Using mean rotation computed from {len(gripper_poses_gt)} timesteps")

# Inference: run the token-selection model on the start frame, or reuse the
# outputs cached by a previous run when --load_past_pred is given.
# NOTE: the cache file is not keyed by episode or start frame, so
# --load_past_pred reuses whatever the last run saved.
tmppath = Path("clean_token_selection_keypoints/test_scripts/tmp.pt")
with torch.no_grad():
    dino_b = dino_tokens.unsqueeze(0).to(device)
    patch_b = patch_positions.unsqueeze(0).to(device)
    current_b = current_eef_pos.unsqueeze(0).to(device)

    if args.load_past_pred and tmppath.exists():
        print(f"Loading cached predictions from {tmppath}")
        pixel_scores, heights_pred = torch.load(tmppath, map_location=device)
    else:
        model = TokenSelectionPredictor(dino_feat_dim=32, window_size=WINDOW_SIZE, num_layers=3, num_heads=4, hidden_dim=128, num_pos_bands=4).to(device)
        model_path = Path("clean_token_selection_keypoints/tmpstorage/model.pt")
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"✓ Loaded model from {model_path}")
        model.eval()
        pixel_scores, heights_pred = model(dino_b, patch_b, current_b)
        # Create the cache directory on first use so saving cannot fail.
        tmppath.parent.mkdir(parents=True, exist_ok=True)
        torch.save((pixel_scores, heights_pred), tmppath)

    # Drop the batch dimension and move to NumPy for post-processing.
    pixel_scores = pixel_scores.squeeze(0).cpu().numpy()
    heights_pred = heights_pred.squeeze(0).cpu().numpy()

trajectory_pred_3d, pred_image_coords, heights_pred_denorm = post_process_predictions(
    pixel_scores, heights_pred, H_patches, W_patches, H_orig, W_orig, camera_pose, cam_K
)

# Build the MuJoCo scene: the robot XML plus one sphere <site> per trajectory
# point (GT and predicted). Then pose the robot from the start image.
robot_config = SO100AdhesiveConfig()
xml_root = ET.fromstring(robot_config.xml)
worldbody = xml_root.find('worldbody')


def _add_sphere_site(name, pos, rgba):
    """Append a sphere <site> of size 0.015 named `name` at `pos` (x y z) with colour string `rgba` to worldbody."""
    ET.SubElement(worldbody, 'site', {
        'name': name,
        'type': 'sphere',
        'size': '0.015',
        'pos': f'{pos[0]} {pos[1]} {pos[2]}',
        'rgba': rgba,
    })


# GT keypoints: orange (green channel 0.5) fading to pure red along the trajectory.
for idx, point in enumerate(trajectory_gt_3d):
    fade = idx / max(len(trajectory_gt_3d) - 1, 1)
    r, g, b = 1.0, 0.5 * (1 - fade), 0.0
    _add_sphere_site(f'gt_kp_{idx}', point, f'{r} {g} {b} 0.8')

# Predicted keypoints: green shading towards yellow-ish along the trajectory.
for idx, point in enumerate(trajectory_pred_3d):
    frac = idx / max(len(trajectory_pred_3d) - 1, 1)
    _add_sphere_site(f'pred_kp_{idx}', point, f'{frac * 0.5} {1.0 - frac * 0.5} 0 0.8')

scene_xml = ET.tostring(xml_root, encoding='unicode')
mj_model = mujoco.MjModel.from_xml_string(scene_xml)
mj_data = mujoco.MjData(mj_model)

# Estimate the robot state from the start image and write it into the sim.
link_poses, _, _, _, _, _ = detect_and_set_link_poses(rgb_np, mj_model, mj_data, robot_config)
configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=55)
joint_q = configuration.q
mj_data.qpos[:] = joint_q
mj_data.ctrl[:] = joint_q[:len(mj_data.ctrl)]
mujoco.mj_forward(mj_model, mj_data)

# Align the exoskeleton meshes with the detected link poses, then recompute.
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)

# What to animate: predicted keypoint positions, paired with GT gripper
# rotations (cheating for now). The rotation flags can substitute a
# median/mean rotation at animation time.
animating_traj = trajectory_pred_3d
animating_gripper_poses = gripper_poses_gt

# Mink IK state, seeded from the current mj_data.qpos.
ik_configuration = mink.Configuration(mj_model)
ik_configuration.update(mj_data.qpos)

def _target_rotation(step):
    """Return the 3x3 gripper rotation to target at trajectory `step`, or None.

    Priority follows the CLI flags:
    1. --use_global_median_rotation: the hard-coded dataset rotation.
    2. --use_median_rotation: this window's mean rotation, when it was computed.
    3. Otherwise: the per-step GT rotation.

    Returns None only when a per-step GT rotation is needed but unavailable.
    """
    if args.use_global_median_rotation:
        return median_dataset_rotation
    if args.use_median_rotation and median_gripper_rot is not None:
        return median_gripper_rot
    if animating_gripper_poses is not None and step < len(animating_gripper_poses):
        return animating_gripper_poses[step][:3, :3]
    return None


def _ik_to_step(step):
    """Run IK toward trajectory point `step`; return False (and do nothing) if no target rotation exists."""
    target_rot = _target_rotation(step)
    if target_rot is None:
        return False
    ik_to_keypoint_and_rotation(animating_traj[step], target_rot, ik_configuration, robot_config, mj_model, mj_data)
    return True


def _load_frame_rgb(frame_idx):
    """Load frame `frame_idx` as an RGB array, or fall back to the start frame if it is out of range or unreadable."""
    if frame_idx >= len(frame_files):
        return rgb_np
    bgr = cv2.imread(str(frame_files[frame_idx]))
    if bgr is None:  # cv2.imread returns None on failure
        print(f"Could not read {frame_files[frame_idx]}; showing start frame instead.")
        return rgb_np
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if np.issubdtype(rgb.dtype, np.floating) and rgb.max() <= 1.0:
        rgb = (rgb * 255).astype(np.uint8)
    return rgb


if not args.render:
    import mujoco.viewer  # `import mujoco` alone does not load the viewer submodule

    # Visit points 0..n-1, then n-1..0, repeating until the viewer is closed.
    n_steps = len(animating_traj)
    ping_pong = list(range(n_steps)) + list(reversed(range(n_steps)))

    with mujoco.viewer.launch_passive(mj_model, mj_data, show_left_ui=False, show_right_ui=False) as viewer:
        print(f"Episode: {episode_id} | Start Frame: {start_idx} | GT pts: {len(trajectory_gt_3d)} | Pred pts: {len(trajectory_pred_3d)}")
        print("✓ Animating robot through predicted trajectory (forward and backward loop). Close viewer to exit.")

        while viewer.is_running():
            if not ping_pong:
                # Nothing to animate: keep the window alive without busy-spinning.
                viewer.sync()
                time.sleep(0.1)
                continue
            for step in ping_pong:
                if not viewer.is_running():
                    break
                _ik_to_step(step)
                link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
                position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
                mujoco.mj_forward(mj_model, mj_data)
                viewer.sync()
                time.sleep(0.1)  # 100ms delay between trajectory points
else:
    import matplotlib.pyplot as plt
    from exo_utils import render_from_camera_pose

    # Offline render: one figure per predicted trajectory point.
    for step in range(len(animating_traj)):
        if not _ik_to_step(step):
            print(f"No target rotation for step {step}; rendering the previous robot pose.")
        rendered = render_from_camera_pose(mj_model, mj_data, camera_pose, cam_K, H_orig, W_orig)
        rgb_current = _load_frame_rgb(start_idx + step + 1)

        # Panels: start RGB, current RGB, render, and 50/50 blends of each RGB with the render.
        overlay_start = (rgb_np * 0.5 + rendered * 0.5).astype(np.uint8)
        overlay_current = (rgb_current * 0.5 + rendered * 0.5).astype(np.uint8)
        panels = [
            (rgb_np, 'Start Frame RGB'),
            (rgb_current, f'Current Frame RGB (t+{step+1})'),
            (rendered, 'Rendered'),
            (overlay_start, 'Start + Rendered'),
            (overlay_current, 'Current + Rendered'),
        ]
        fig, axes = plt.subplots(1, len(panels), figsize=(25, 5))
        for ax, (img, title) in zip(axes, panels):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure before the next step
