"""Visualize 3D keypoint predictions from a raw RGB image in MuJoCo."""
import sys
import os
from pathlib import Path
import cv2
import mujoco
import numpy as np
import xml.etree.ElementTree as ET
import torch
import argparse
import time
from scipy.spatial.transform import Rotation as R
from PIL import Image
import torchvision.transforms.functional as TF
import pickle

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes, get_link_poses_from_robot, render_from_camera_pose
from model import TokenSelectionPredictor
from utils import project_3d_to_2d, rescale_coords, post_process_predictions, ik_to_keypoint_and_rotation, build_patch_positions
from vis_utils import visualize_evaluation_full
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX
import mink

# Trajectory window length; passed to TokenSelectionPredictor and to the 2D visualizer.
WINDOW_SIZE = 10

# Median gripper orientation (3x3 rotation matrix), hardcoded from the whole dataset
# (474 gripper poses across 10 episodes). Used as the fixed target orientation for IK.
median_dataset_rotation = np.array(
    (
        (-0.99912433, -0.03007201, -0.02909046),
        (-0.04176828, 0.67620482, 0.73552869),
        (-0.00244771, 0.73609967, -0.67686874),
    )
)

# DINO configuration.
# The DINOv3 checkout and checkpoint default to the original author's machine; set the
# DINOV3_REPO_DIR / DINOV3_WEIGHTS_PATH environment variables to run elsewhere.
REPO_DIR = os.environ.get(
    "DINOV3_REPO_DIR",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
)
WEIGHTS_PATH = os.environ.get(
    "DINOV3_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)
PATCH_SIZE = 16  # ViT patch size in pixels
IMAGE_SIZE = 768  # target image height before patching (see resize_transform)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
N_LAYERS = 12  # number of intermediate layers requested from the backbone

def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Resize *img* to a patch-aligned size and convert it to a tensor.

    The target height is ``int(image_size / patch_size) * patch_size``. The target width
    scales the input width by the same factor as the height (``image_size / h``), then is
    truncated to a whole number of patches. The resized PIL image is converted with
    ``TF.to_tensor``.
    """
    width, height = img.size
    rows = int(image_size / patch_size)
    cols = int((width * image_size) / (height * patch_size))
    target_hw = (rows * patch_size, cols * patch_size)
    resized = TF.resize(img, target_hw)
    return TF.to_tensor(resized)

# Command-line interface: input image, and which visualizations to produce.
parser = argparse.ArgumentParser(
    description="Visualize 3D predictions from raw RGB image in MuJoCo",
)
parser.add_argument(
    "--image_path",
    type=str,
    default="scratch/parsed_propercup_train/episode_001/000000.png",
    help="Path to RGB image",
)
parser.add_argument(
    "--render", "-r",
    action="store_true",
    help="Render",
)
parser.add_argument(
    "--vis_2d",
    action="store_true",
    help="Show 2D evaluation visualization",
)
args = parser.parse_args()

# Pick the compute device: Apple MPS if available, otherwise CUDA, otherwise CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Hardcoded original image resolution (before downsampling). Keypoint coordinates are
# rescaled from, and renders are produced at, this resolution regardless of the loaded image.
H_orig = 1080
W_orig = 1920

# Load RGB image. cv2.imread returns None (instead of raising) when the file can't be read.
bgr_np = cv2.imread(args.image_path)
if bgr_np is None:
    print(f"Failed to read image: {args.image_path}")
    sys.exit(1)
rgb_np = cv2.cvtColor(bgr_np, cv2.COLOR_BGR2RGB)
# Only a float image can be in [0, 1]; a very dark uint8 image must not be rescaled.
if np.issubdtype(rgb_np.dtype, np.floating) and rgb_np.max() <= 1.0:
    rgb_np = (rgb_np * 255).astype(np.uint8)
if rgb_np.shape[:2] != (H_orig, W_orig):
    print(f"⚠ Image is {rgb_np.shape[0]}x{rgb_np.shape[1]}, expected {H_orig}x{W_orig}; "
          f"coordinate rescaling and the render overlay assume the latter")
print(f"Loaded image: {args.image_path} (using hardcoded resolution {H_orig}x{W_orig})")

# Build the MuJoCo model of the robot; it is used for camera pose estimation and IK.
robot_config = SO100AdhesiveConfig()
mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
mj_data = mujoco.MjData(mj_model)

# Detect the robot in the image to get link poses, the camera pose and the camera
# intrinsics; the remaining three return values are unused here.
(link_poses, camera_pose_world, cam_K, _, _, _) = detect_and_set_link_poses(
    rgb_np, mj_model, mj_data, robot_config
)
camera_pose = camera_pose_world
print("✓ Estimated camera pose and intrinsics from image")

# Fit a joint configuration to the detected link poses and load it into MuJoCo
# (positions, plus the leading entries as actuator controls).
configuration, _ = estimate_robot_state(
    mj_model, mj_data, robot_config, link_poses, ik_iterations=55
)
joint_q = configuration.q
mj_data.qpos[:] = joint_q
mj_data.ctrl[:] = joint_q[: len(mj_data.ctrl)]
mujoco.mj_forward(mj_model, mj_data)

# Place the exoskeleton meshes at the detected link poses, then recompute kinematics.
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)

# Compute the tracked keypoint (KEYPOINTS_LOCAL_M_ALL[KP_INDEX], expressed in the
# "Fixed_Jaw" body frame) in world coordinates for the current robot state, and
# project it into the image.
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
gripper_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "Fixed_Jaw")
if gripper_body_id < 0:
    # mj_name2id returns -1 for an unknown name; xpos[-1] would silently read another body.
    print("Body 'Fixed_Jaw' not found in model.")
    sys.exit(1)
gripper_pos = mj_data.xpos[gripper_body_id].copy()
gripper_quat = mj_data.xquat[gripper_body_id].copy()
# MuJoCo stores quaternions as (w, x, y, z); SciPy's from_quat expects (x, y, z, w).
gripper_rot = R.from_quat(gripper_quat[[1, 2, 3, 0]]).as_matrix()
current_kp_3d = gripper_rot @ kp_local + gripper_pos
current_kp_2d = project_3d_to_2d(current_kp_3d, camera_pose, cam_K)
if current_kp_2d is None:
    print("Failed to project current KP.")
    sys.exit(1)

# Load the DINOv3 backbone from a local checkout via torch.hub (source='local').
print("Computing DINO features...")
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)
dinov3_model = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH).to(device)
dinov3_model.eval()

# Load the PCA embedder used to reduce DINO patch features (stored under key 'pca';
# presumably a fitted sklearn-style object, since only .transform is used).
pca_path = Path("scratch/dino_pca_embedder.pkl")
if not pca_path.exists():
    print(f"PCA embedder not found at {pca_path}")
    sys.exit(1)
# NOTE: unpickling can execute arbitrary code; only load a trusted, locally produced file.
with pca_path.open('rb') as f:
    pca_data = pickle.load(f)
pca_embedder = pca_data['pca']

# Prepare the image for DINO: resize to a patch-aligned size, then ImageNet-normalize.
img_pil = Image.fromarray(rgb_np).convert("RGB")
image_resized = resize_transform(img_pil)  # (C, H, W); H and W are multiples of PATCH_SIZE
image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Extract DINO features. Input and weights are already float32, and autocasting to
# float32 would not change the computation, so no autocast context is used.
with torch.no_grad():
    feats = dinov3_model.get_intermediate_layers(
        image_resized_norm.unsqueeze(0).to(device),
        n=range(N_LAYERS),
        reshape=True,
        norm=True,
    )
    # Last requested layer, first (only) batch element. Index the batch explicitly
    # instead of squeeze() so a patch grid with a size-1 side keeps its shape.
    x = feats[-1][0].detach().cpu()  # (D, H_patches, W_patches)
dim = x.shape[0]
x = x.reshape(dim, -1).permute(1, 0).numpy()  # (H_patches * W_patches, D)

# Project into PCA space and restore the patch-grid layout.
pca_features_all = pca_embedder.transform(x)  # (H_patches * W_patches, n_components)
h_patches, w_patches = (d // PATCH_SIZE for d in image_resized.shape[1:])
pca_features_patches = pca_features_all.reshape(h_patches, w_patches, -1)  # (H_patches, W_patches, n_components)

# One token per patch; the feature width follows the PCA output instead of a fixed 32,
# so a mismatch with the model's dino_feat_dim fails loudly rather than reshaping silently.
dino_tokens = torch.from_numpy(
    pca_features_patches.reshape(-1, pca_features_patches.shape[-1])
).float()  # (num_patches, n_components)
num_patches = dino_tokens.shape[0]
patch_positions_np, H_patches, W_patches = build_patch_positions(num_patches, H_patches=h_patches, W_patches=w_patches)
patch_positions = torch.from_numpy(patch_positions_np).float()

# Rescale the current keypoint from original-image pixels to patch coordinates.
# rescale_coords may return shape (2,) or (1, 2) for a single point; take that one point.
current_kp_patches = np.atleast_2d(
    rescale_coords(current_kp_2d.reshape(1, 2), H_orig, W_orig, H_patches, W_patches)
)[0]
current_eef_pos = torch.from_numpy(current_kp_patches).float()

print(f"DINO features: {dino_tokens.shape}, Patch resolution: {H_patches}x{W_patches}")

# Build the token-selection predictor; load trained weights if the checkpoint exists.
model_path = Path("clean_token_selection_keypoints/tmpstorage/model.pt")
model = TokenSelectionPredictor(
    dino_feat_dim=32,
    window_size=WINDOW_SIZE,
    num_layers=3,
    num_heads=4,
    hidden_dim=128,
    num_pos_bands=4,
).to(device)
if not model_path.exists():
    print(f"⚠ Model not found at {model_path}, using random weights")
else:
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"✓ Loaded model from {model_path}")
model.eval()

# Single forward pass on a batch of one; outputs are moved to NumPy with the batch dim removed.
with torch.no_grad():
    dino_b, patch_b, current_b = (
        t.unsqueeze(0).to(device) for t in (dino_tokens, patch_positions, current_eef_pos)
    )
    scores_out, heights_out = model(dino_b, patch_b, current_b)
    pixel_scores = scores_out.squeeze(0).cpu().numpy()
    heights_pred = heights_out.squeeze(0).cpu().numpy()

# Turn the raw patch scores / heights into 3D trajectory points, their image
# coordinates and de-normalized heights (per the names of the returned values).
trajectory_pred_3d, pred_image_coords, heights_pred_denorm = post_process_predictions(
    pixel_scores, heights_pred, H_patches, W_patches, H_orig, W_orig, camera_pose, cam_K,
)

print(f"Predicted {len(trajectory_pred_3d)} trajectory points")

# Optional 2D visualization: RGB, DINO PCA features, predicted trajectory and scores.
if args.vis_2d:
    import matplotlib.pyplot as plt
    RES_LOW = 224

    # Show the first 3 PCA channels as RGB, min-max normalized per channel
    # (a constant channel becomes mid-gray).
    dino_vis = pca_features_patches[:, :, :3].copy()
    for i in range(3):
        channel = dino_vis[:, :, i]
        min_val, max_val = channel.min(), channel.max()
        if max_val > min_val:
            dino_vis[:, :, i] = (channel - min_val) / (max_val - min_val)
        else:
            dino_vis[:, :, i] = 0.5
    dino_vis = np.clip(dino_vis, 0, 1)

    # Downscale RGB to a square low-res image for display.
    rgb_lowres = cv2.resize(rgb_np, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

    # Convert predicted trajectory and current keypoint into low-res and patch coordinates.
    predicted_trajectory_lowres = rescale_coords(pred_image_coords, H_orig, W_orig, RES_LOW, RES_LOW)
    predicted_trajectory_patches = rescale_coords(pred_image_coords, H_orig, W_orig, H_patches, W_patches)
    # rescale_coords may return shape (2,) or (1, 2) for a single point; take that one point
    # (a bare [0] would yield only the x coordinate in the 1D case).
    current_kp_2d_lowres = np.atleast_2d(
        rescale_coords(current_kp_2d.reshape(1, 2), H_orig, W_orig, RES_LOW, RES_LOW)
    )[0]
    current_kp_2d_patches = current_kp_patches

    # Create visualization (no GT trajectories). save_path is handed to
    # visualize_evaluation_full (presumably written there), so ensure its directory exists.
    image_name = Path(args.image_path).stem
    save_path = f"token_selection_keypoints/vis_raw_2d_{image_name}.png"
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    visualize_evaluation_full(
        rgb_lowres, dino_vis,
        None, None,  # No GT trajectories
        predicted_trajectory_lowres, predicted_trajectory_patches,
        current_kp_2d_lowres, current_kp_2d_patches,
        pixel_scores, H_patches, W_patches,
        heights_pred, heights_gt=None,
        episode_id=image_name, start_idx=0, window_size=WINDOW_SIZE,
        save_path=save_path
    )
    plt.show()

# Rebuild the MuJoCo model with one sphere site per predicted trajectory point.
# Color goes from pure green (first point) toward olive (last point).
xml_root = ET.fromstring(robot_config.xml)
worldbody = xml_root.find('worldbody')

denom = max(len(trajectory_pred_3d) - 1, 1)
for idx, point in enumerate(trajectory_pred_3d):
    frac = idx / denom
    red, green = frac * 0.5, 1.0 - frac * 0.5
    site_attrs = {
        'name': f'pred_kp_{idx}',
        'type': 'sphere',
        'size': '0.015',
        'pos': f'{point[0]} {point[1]} {point[2]}',
        'rgba': f'{red} {green} 0 0.8',
    }
    ET.SubElement(worldbody, 'site', site_attrs)

mj_model = mujoco.MjModel.from_xml_string(ET.tostring(xml_root, encoding='unicode'))
mj_data = mujoco.MjData(mj_model)

# The rebuilt model has fresh state: detect link poses again and refit the joints.
(link_poses, _, _, _, _, _) = detect_and_set_link_poses(rgb_np, mj_model, mj_data, robot_config)
configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=55)
q_fit = configuration.q
mj_data.qpos[:] = q_fit
mj_data.ctrl[:] = q_fit[: len(mj_data.ctrl)]
mujoco.mj_forward(mj_model, mj_data)
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)

# mink IK configuration seeded with the fitted joint positions.
ik_configuration = mink.Configuration(mj_model)
ik_configuration.update(mj_data.qpos)

def _step_to_keypoint(target_kp_pos):
    """Drive the robot toward one predicted keypoint and refresh the scene.

    Calls ik_to_keypoint_and_rotation with *target_kp_pos* and the fixed
    median_dataset_rotation as the gripper orientation target. It then re-positions the
    exoskeleton meshes at the resulting link poses and re-runs forward kinematics, so the
    viewer and renderer both see a consistent state.
    """
    ik_to_keypoint_and_rotation(
        target_kp_pos, median_dataset_rotation, ik_configuration, robot_config, mj_model, mj_data
    )
    new_link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
    position_exoskeleton_meshes(robot_config, mj_model, mj_data, new_link_poses)
    mujoco.mj_forward(mj_model, mj_data)


# Animate through predicted trajectory
animating_traj = trajectory_pred_3d

if not args.render:
    # `import mujoco` does not load the viewer submodule; import it explicitly.
    import mujoco.viewer

    viewer = mujoco.viewer.launch_passive(mj_model, mj_data, show_left_ui=False, show_right_ui=False)
    print(f"Predicted pts: {len(trajectory_pred_3d)}")
    print("✓ Animating robot through predicted trajectory (forward and backward loop). Close viewer to exit.")

    # One animation cycle: forward through the trajectory, then back in reverse.
    ping_pong = list(animating_traj) + list(reversed(animating_traj))
    while viewer.is_running():
        if not ping_pong:
            # Nothing to animate; avoid spinning at full CPU while the viewer is open.
            time.sleep(0.1)
            continue
        for target_kp_pos in ping_pong:
            if not viewer.is_running():
                break
            _step_to_keypoint(target_kp_pos)
            viewer.sync()
            time.sleep(0.1)
else:
    import matplotlib.pyplot as plt

    # Render each predicted step from the estimated camera and compare it with the input.
    for i, target_kp_pos in enumerate(animating_traj):
        # Same state update as the viewer branch, so exoskeleton meshes follow the robot.
        _step_to_keypoint(target_kp_pos)

        rendered = render_from_camera_pose(mj_model, mj_data, camera_pose, cam_K, H_orig, W_orig)
        overlay = (rgb_np * 0.5 + rendered * 0.5).astype(np.uint8)

        # Display results: original RGB, rendered, overlay
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (rgb_np, 'Original RGB'),
            (rendered, f'Rendered (t+{i+1})'),
            (overlay, 'Overlay'),
        )
        for ax, (img, title) in zip(axes, panels):
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(title)
        plt.tight_layout()
        plt.show()
