"""Simple inference from a single RGB image: compute camera pose, extract DINO features, and predict 3D keypoints."""
import sys
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import functional as TF
import argparse
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.spatial.transform import Rotation as R
import xml.etree.ElementTree as ET
import mink


sys.path.insert(0,"/Users/cameronsmith/Projects/robotics_testing/3dkeygrip")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../"))
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX, unproject_patch_to_groundplane, compute_volume_mask_for_patches
from utils import project_3d_to_2d, rescale_coords, recover_3d_from_direct_keypoint_and_height, post_process_predictions, ik_to_keypoint_and_rotation
from model import TrajectoryPredictor, MIN_HEIGHT, MAX_HEIGHT, GROUNDPLANE_X_MIN, GROUNDPLANE_X_MAX, GROUNDPLANE_Z_MIN, GROUNDPLANE_Z_MAX
from exo_utils import render_from_camera_pose, get_link_poses_from_robot
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes
import mujoco

from robot_models.so100_controller import Arm
import pickle

# DINO configuration constants
PATCH_SIZE = 16
IMAGE_SIZE = 768
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
N_LAYERS = 12

# Model constants
RES_LOW = 256
H_orig = 1080
W_orig = 1920
MAX_TIMESTEPS = 3  # Extrema points: close, open, end
GROUNDPLANE_RANGE = 1.0

def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Resize *img* so both sides are exact multiples of *patch_size*, then convert to a tensor.

    The height is snapped to ``image_size`` (in whole patches) and the width is
    scaled to approximately preserve the aspect ratio, also in whole patches.
    """
    width, height = img.size
    rows_of_patches = image_size // patch_size
    cols_of_patches = int((width * image_size) / (height * patch_size))
    target_hw = (rows_of_patches * patch_size, cols_of_patches * patch_size)
    resized = TF.resize(img, target_hw)
    return TF.to_tensor(resized)

def extract_dino_features(rgb, model_dino, pca_embedder, device):
    """Extract PCA-reduced DINO patch features from an RGB image.

    Args:
        rgb: HxWx3 uint8 RGB image (numpy array).
        model_dino: DINOv3 backbone exposing ``get_intermediate_layers``.
        pca_embedder: fitted PCA object with a ``transform`` method
            (loaded from the pickle at --pca_path).
        device: torch.device the backbone lives on.

    Returns:
        Tuple ``(features, h_patches, w_patches)`` where ``features`` is a
        float32 tensor of shape (h_patches, w_patches, n_pca_components).
    """
    H, W = rgb.shape[:2]

    # Load and preprocess image for DINO
    img_pil = Image.fromarray(rgb).convert("RGB")
    image_resized = resize_transform(img_pil)
    image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)

    # Extract DINO features
    with torch.inference_mode():
        # BUG FIX: device_type was previously hardcoded to 'mps'/'cpu',
        # mismatching the compute device when running on CUDA (the script
        # prefers CUDA when available). Use the actual device type.
        # With dtype=float32 autocast is effectively disabled either way,
        # so numerics are unchanged.
        with torch.autocast(device_type=device.type, dtype=torch.float32):
            feats = model_dino.get_intermediate_layers(
                image_resized_norm.unsqueeze(0).to(device),
                n=range(N_LAYERS),
                reshape=True,
                norm=True
            )
            # Only the last requested layer is used.
            x = feats[-1].squeeze().detach().cpu()  # (D, H_patches, W_patches)
            dim = x.shape[0]
            x = x.view(dim, -1).permute(1, 0).numpy()  # (H_patches * W_patches, D)

            # Apply PCA to reduce to 32 dimensions
            pca_features_all = pca_embedder.transform(x)  # (H_patches * W_patches, 32)

            # Get patch resolution from the resized image's pixel dimensions
            h_patches, w_patches = [int(d / PATCH_SIZE) for d in image_resized.shape[1:]]
            pca_features_patches = pca_features_all.reshape(h_patches, w_patches, -1)  # (H_patches, W_patches, 32)

    return torch.from_numpy(pca_features_patches).float(), h_patches, w_patches

# Command-line interface
parser = argparse.ArgumentParser(description="Infer 3D keypoints from a single RGB image")
parser.add_argument("--image", "-i", type=str, required=True, help="Path to RGB image")
parser.add_argument("--model_path", "-m", type=str, default="keypoint_testing2/tmpstorage/model.pt", help="Path to trained model")
parser.add_argument("--pca_path", type=str, default="scratch/dino_pca_embedder.pkl", help="Path to PCA embedder")
parser.add_argument("--use_arm", action="store_true", help="Use arm in inference")
parser.add_argument("--dont_show_plots", action="store_true", help="Don't show plots")
args = parser.parse_args()

if args.use_arm:
    # Load calibrated joint offsets and connect to the physical robot.
    # BUG FIX: the original `open(...)` inside the Arm(...) call leaked the
    # file handle; a context manager closes it deterministically.
    with open("robot_models/arm_offsets/rescrew_fromimg.pkl", 'rb') as offsets_file:
        arm = Arm(pickle.load(offsets_file))
    print("✓ Connected to robot for direct joint state reading")

# DINOv3 configuration (hardcoded like add_dino_features.py)
REPO_DIR = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3"
WEIGHTS_PATH = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"

# Load image
image_path = Path(args.image)
if not image_path.exists():
    print(f"Error: Image not found at {image_path}")
    exit(1)

# cv2 reads BGR; convert to RGB for the rest of the pipeline.
rgb = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
# NOTE(review): cv2.imread returns uint8 in [0, 255], so this branch only
# fires for an (almost) all-black image — it looks intended for float images
# in [0, 1]; confirm whether it can misfire on very dark inputs.
if rgb.max() <= 1.0:
    rgb = (rgb * 255).astype(np.uint8)

H_loaded, W_loaded = rgb.shape[:2]
print(f"Loaded image: {W_loaded}x{H_loaded}")

# Resize RGB to original resolution for consistency
rgb_resized = cv2.resize(rgb, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)

# Setup MuJoCo model
print("Setting up MuJoCo model...")
robot_config = SO100AdhesiveConfig()
mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
mj_data = mujoco.MjData(mj_model)

# Detect ArUco markers to get camera pose and intrinsics
print("Detecting ArUco markers and computing camera pose...")
try:
    # detect_and_set_link_poses returns per-link poses, the camera pose
    # matrix, the intrinsic matrix, and ArUco corner bookkeeping.
    link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
        rgb_resized, mj_model, mj_data, robot_config
    )
    position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
    # Estimate joint configuration from the detected link poses via IK.
    configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=35)
    mj_data.qpos[:] = configuration.q
    # ctrl may be shorter than qpos, hence the slice.
    mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
    mujoco.mj_forward(mj_model, mj_data)
    print(f"✓ Camera pose computed")
    print(f"  Camera position: {camera_pose_world[:3, 3]}")
    print(f"  Camera intrinsics shape: {cam_K.shape}")
except Exception as e:
    # Marker detection is required by everything downstream, so bail out.
    print(f"Error detecting link poses: {e}")
    exit(1)

# Load DINO model and PCA embedder
print("Loading DINO model and PCA embedder...")
# Prefer CUDA, then Apple MPS, falling back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Load DINOv3 model (hardcoded like add_dino_features.py)
print("Loading DINOv3 model...")
with torch.inference_mode():
    model_dino = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH).to(device)
    model_dino.eval()
print(f"✓ Loaded DINOv3 model from {WEIGHTS_PATH}")

# Load PCA embedder.  (The redundant mid-file `import pickle` was removed;
# pickle is already imported at the top of the file.)
pca_path = Path(args.pca_path)
if not pca_path.exists():
    print(f"Error: PCA embedder not found at {pca_path}")
    exit(1)
print(f"Loading PCA embedder from {pca_path}")
with open(pca_path, 'rb') as f:
    pca_data = pickle.load(f)
    pca_embedder = pca_data['pca']
    print(f"✓ PCA embedder: {pca_embedder.n_components_} dimensions")

# Extract DINO features
print("Extracting DINO features...")
dino_features, H_patches, W_patches = extract_dino_features(rgb, model_dino, pca_embedder, device)
dino_feat_dim = dino_features.shape[2]
print(f"✓ DINO features extracted: {H_patches}x{W_patches} patches, {dino_feat_dim} dims")

# Flatten DINO features for model input
dino_tokens_flat = dino_features.view(H_patches * W_patches, dino_feat_dim).float()

# Compute ground-plane coordinates for each patch by unprojecting every
# patch onto the ground plane (ground_y=0.0).
print("Computing groundplane coordinates...")
groundplane_coords_list = []
y_coords, x_coords = np.meshgrid(np.arange(H_patches), np.arange(W_patches), indexing='ij')
for patch_y, patch_x in zip(y_coords.flatten(), x_coords.flatten()):
    x_gp, z_gp = unproject_patch_to_groundplane(
        patch_x, patch_y, H_patches, W_patches, H_orig, W_orig,
        camera_pose_world, cam_K, ground_y=0.0
    )
    # Patches whose ray misses the ground plane get a (0, 0) placeholder.
    if x_gp is not None and z_gp is not None:
        groundplane_coords_list.append([x_gp, z_gp])
    else:
        groundplane_coords_list.append([0.0, 0.0])

groundplane_coords = torch.from_numpy(np.array(groundplane_coords_list, dtype=np.float32))
print(f"✓ Groundplane coordinates computed")

# Compute volume mask for patches (nonzero where a patch counts as "valid";
# presumably the reachable workspace volume — see
# compute_volume_mask_for_patches for the exact criterion).
print("Computing volume mask...")
volume_mask = compute_volume_mask_for_patches(H_patches, W_patches, H_orig, W_orig, camera_pose_world, cam_K)
volume_mask_flat = torch.from_numpy(volume_mask.flatten().astype(np.float32))  # (num_patches,)
print(f"✓ Volume mask computed: {volume_mask.sum()} / {volume_mask.size} patches valid")

# Compute current EEF 2D position and height
print("Computing current EEF position...")
# Keypoint offset in the gripper's local frame.
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

# Get current gripper pose from MuJoCo (via the exoskeleton mocap body)
exo_mesh_body_name = "fixed_gripper_exo_mesh"
exo_mesh_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, exo_mesh_body_name)
exo_mesh_mocap_id = mj_model.body_mocapid[exo_mesh_body_id]
gripper_pos = mj_data.mocap_pos[exo_mesh_mocap_id].copy()
gripper_quat = mj_data.mocap_quat[exo_mesh_mocap_id].copy()

# Convert quaternion to rotation matrix.  `R` is already imported at the top
# of the file; the duplicate scipy import that used to live here was removed.
# MuJoCo stores quaternions as wxyz while scipy expects xyzw, hence the reorder.
gripper_rot = R.from_quat(gripper_quat[[1, 2, 3, 0]]).as_matrix()  # wxyz to xyzw

# Compute current keypoint 3D position in world coordinates
current_kp_3d = gripper_rot @ kp_local + gripper_pos

# Project current EEF to 2D
current_kp_2d_image = project_3d_to_2d(current_kp_3d, camera_pose_world, cam_K)
if current_kp_2d_image is None:
    print("Error: Current EEF projection failed")
    exit(1)

# Rescale to patch coordinates
current_kp_2d_patches = rescale_coords(
    current_kp_2d_image.reshape(1, 2),
    H_orig, W_orig,
    H_patches, W_patches
)[0]

# Find nearest patch for current EEF (clamped to the patch grid)
current_eef_patch_x = int(np.round(np.clip(current_kp_2d_patches[0], 0, W_patches - 1)))
current_eef_patch_y = int(np.round(np.clip(current_kp_2d_patches[1], 0, H_patches - 1)))
current_eef_patch_idx = current_eef_patch_y * W_patches + current_eef_patch_x

# Get current EEF height, normalized to [0, 1] between MIN_HEIGHT and MAX_HEIGHT
current_eef_height_norm = float(np.clip((current_kp_3d[2] - MIN_HEIGHT) / (MAX_HEIGHT - MIN_HEIGHT), 0.0, 1.0))
print(f"✓ Current EEF: patch ({current_eef_patch_x}, {current_eef_patch_y}), height={current_eef_height_norm:.3f}")

# Load model
print(f"Loading model from {args.model_path}...")
model_path = Path(args.model_path)
if not model_path.exists():
    print(f"Error: Model not found at {model_path}")
    exit(1)

# Detect max_timesteps and dino_feat_dim from checkpoint.
# SECURITY NOTE: weights_only=False unpickles arbitrary objects — only load
# checkpoints from trusted sources.
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
detected_max_timesteps = MAX_TIMESTEPS
if 'height_timestep_emb.weight' in checkpoint:
    # Infer the timestep count from the embedding table's first dimension.
    detected_max_timesteps = checkpoint['height_timestep_emb.weight'].shape[0]
    print(f"Detected max_timesteps={detected_max_timesteps} from checkpoint")
    if detected_max_timesteps != MAX_TIMESTEPS:
        print(f"⚠ Warning: Checkpoint has max_timesteps={detected_max_timesteps}, but code expects {MAX_TIMESTEPS}")
        print(f"Using checkpoint's max_timesteps={detected_max_timesteps} for model loading")
else:
    print(f"Could not detect max_timesteps from checkpoint, using {MAX_TIMESTEPS}")

if 'dino_proj.weight' in checkpoint:
    # Infer the feature dimension from the projection layer's input width.
    detected_dino_feat_dim = checkpoint['dino_proj.weight'].shape[1]
    print(f"Detected dino_feat_dim={detected_dino_feat_dim} from checkpoint")
    dino_feat_dim = detected_dino_feat_dim  # Use detected value
else:
    print(f"Could not detect dino_feat_dim from checkpoint, using {dino_feat_dim}")

model = TrajectoryPredictor(
    dino_feat_dim=dino_feat_dim,
    max_timesteps=detected_max_timesteps,  # Use detected max_timesteps
    num_layers=3,
    num_heads=4,
    hidden_dim=128,
    num_pos_bands=4,
    groundplane_range=GROUNDPLANE_RANGE
).to(device)
model.load_state_dict(checkpoint, strict=True)
model.eval()
print(f"✓ Model loaded")

# Run inference
print("Running inference...")
with torch.no_grad():
    # Add a batch dimension of 1 to every model input.
    dino_tokens_batch = dino_tokens_flat.unsqueeze(0).to(device)  # (1, num_patches, dino_feat_dim)
    groundplane_coords_batch = groundplane_coords.unsqueeze(0).to(device)  # (1, num_patches, 2)
    current_eef_patch_idx_batch = torch.tensor([current_eef_patch_idx], dtype=torch.long).to(device)
    current_eef_height_batch = torch.tensor([current_eef_height_norm], dtype=torch.float32).to(device)
    
    volume_mask_batch = volume_mask_flat.unsqueeze(0).to(device)  # (1, num_patches)
    
    attention_scores, heights_pred, grippers_pred = model(
        dino_tokens_batch,
        groundplane_coords_batch,
        current_eef_patch_idx_batch,
        current_eef_height_batch,
        volume_mask=volume_mask_batch,  # Pass volume mask
        use_attention_mask=True  # Use volume mask during inference
    )
    
    # Drop the batch dimension and move predictions back to numpy.
    attention_scores = attention_scores.squeeze(0).cpu().numpy()  # (max_timesteps, num_patches)
    heights_pred = heights_pred.squeeze(0).cpu().numpy()  # (max_timesteps,)
    grippers_pred = grippers_pred.squeeze(0).cpu().numpy()  # (max_timesteps,)

print("✓ Inference complete")

# Post-process predictions to get predicted 2D direct keypoint coordinates and heights
_, pred_image_coords, heights_pred_denorm = post_process_predictions(
    attention_scores, heights_pred, H_patches, W_patches, H_orig, W_orig, camera_pose_world, cam_K
)

print(f"✓ Predicted {len(pred_image_coords)} trajectory points")

# Lift predicted 2D direct keypoint + predicted height to 3D.
# Points whose recovery fails (None) are silently dropped.
print("Lifting 2D predictions to 3D...")
lifted_points = (
    recover_3d_from_direct_keypoint_and_height(coord_2d, z_height, camera_pose_world, cam_K)
    for coord_2d, z_height in zip(pred_image_coords, heights_pred_denorm)
)
trajectory_3d = np.array([pt for pt in lifted_points if pt is not None])
print(f"✓ Lifted to 3D: {len(trajectory_3d)} keypoints")

# Rescale predicted trajectory to low-res for visualization
rgb_lowres = cv2.resize(rgb, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)
predicted_trajectory_lowres = rescale_coords(pred_image_coords, H_orig, W_orig, RES_LOW, RES_LOW)

# DINO features visualization: show the first 3 PCA channels as RGB, with
# each channel independently min-max normalized to [0, 1].
dino_vis = dino_features[:, :, :3].numpy()  # (H_patches, W_patches, 3)
for ch in range(3):
    band = dino_vis[:, :, ch]
    lo, hi = band.min(), band.max()
    # Flat channels (hi == lo) are set to a neutral mid-gray.
    dino_vis[:, :, ch] = (band - lo) / (hi - lo) if hi > lo else 0.5
dino_vis = np.clip(dino_vis, 0, 1)
# Upscale DINO vis to low-res display resolution
dino_vis_upscaled = cv2.resize(dino_vis, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

# Prepare attention map visualization (first timestep).
# BUG FIX (latent): np.reshape returns a *view* here, so the in-place masking
# below used to write back into `attention_scores`; copy to keep the source
# array pristine.
attention_map = attention_scores[0].reshape(H_patches, W_patches).copy()
# Use volume mask for visualization
volume_mask_2d = volume_mask  # (H_patches, W_patches) boolean
# Replace masked values with minimum value from valid region for better visualization
if np.any(volume_mask_2d):
    min_valid_value = attention_map[volume_mask_2d].min()
    attention_map[~volume_mask_2d] = min_valid_value
# Normalize to [0, 1]; epsilon guards a constant map.
attention_map_norm = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
attention_map_upscaled = cv2.resize(attention_map_norm, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

# Create dense groundplane coordinate map (like render_dense_groundplane.py):
# for every pixel, cast a ray and intersect it with the ground plane, then
# color the pixel by the intersection's (X, Z) coordinates.
print("Computing dense groundplane coordinate map...")
u, v = np.meshgrid(np.arange(W_orig), np.arange(H_orig))
u_flat = u.flatten()
v_flat = v.flatten()

# Convert pixels to normalized camera coordinates (vectorized)
K_inv = np.linalg.inv(cam_K)
pixels_h = np.stack([u_flat, v_flat, np.ones(len(u_flat))], axis=1).T  # (3, N)
rays_cam = (K_inv @ pixels_h).T  # (N, 3)

# Transform rays to world coordinates.
# NOTE(review): camera_pose_world is inverted here (and used directly when
# projecting world->image elsewhere), so it appears to be the world-to-camera
# extrinsic — confirm against detect_and_set_link_poses.
cam_pose_inv = np.linalg.inv(camera_pose_world)
rays_world = (cam_pose_inv[:3, :3] @ rays_cam.T).T  # (N, 3) - direction vectors
cam_pos = cam_pose_inv[:3, 3]  # (3,)

# Find intersection with ground plane (Y = 0, where Y is index 2)
ray_dir_z = rays_world[:, 2]  # Z component of ray direction (Y is index 2)
t = (0.0 - cam_pos[2]) / (ray_dir_z + 1e-10)  # Avoid division by zero

# Valid intersections: t > 0 and not parallel to plane
valid_mask_flat = (t > 0) & (np.abs(ray_dir_z) > 1e-6)

# Compute 3D points on ground plane
points_3d_flat = cam_pos[None, :] + t[:, None] * rays_world  # (N, 3)

# Reshape to image dimensions
points_3d = points_3d_flat.reshape(H_orig, W_orig, 3)
valid_mask = valid_mask_flat.reshape(H_orig, W_orig)

# Extract X and Z coordinates
x_coords = points_3d[:, :, 0]  # X coordinate
z_coords = points_3d[:, :, 1]  # Z coordinate

# Normalize coordinates to [0, 1] for color mapping
grid_range = 1.0
x_norm = np.clip((x_coords + grid_range) / (2 * grid_range + 1e-6), 0, 1)
z_norm = np.clip((z_coords + grid_range) / (2 * grid_range + 1e-6), 0, 1)

# Create RGB colors based on XZ coordinates
red_channel = 1.0 - x_norm * 0.5
green_channel_x = x_norm
blue_channel = 1.0 - z_norm
green_channel_z = z_norm

coord_colors = np.stack([
    red_channel,
    np.clip(green_channel_x + green_channel_z, 0, 1),
    blue_channel
], axis=2)  # (H_orig, W_orig, 3)

# Set invalid pixels to white
coord_colors[~valid_mask] = 1.0

# Resize coord_colors to low-res for visualization
coord_colors_lowres = cv2.resize(coord_colors, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

# Generate groundplane grid for line plot visualization (like render_groundplane.py)
grid_spacing = 0.1  # metres between grid lines
grid_range = 1.0  # half-extent of the grid in metres
grid_x = np.arange(-grid_range, grid_range + grid_spacing, grid_spacing)
grid_z = np.arange(-grid_range, grid_range + grid_spacing, grid_spacing)
grid_xx, grid_zz = np.meshgrid(grid_x, grid_z)
# Stack as [x, z, y=0] where y is at index 2 (last dimension)
grid_points_3d = np.stack([grid_xx.flatten(), grid_zz.flatten(), np.zeros_like(grid_xx.flatten())], axis=1)

# Project grid points to 2D.  NaN marks grid vertices with no valid projection.
num_x = len(grid_x)
num_z = len(grid_z)
grid_2d_reshaped = np.full((num_z, num_x, 2), np.nan)

# Ensure Y=0 for all points
grid_points_3d_flat = grid_points_3d.copy()
grid_points_3d_flat[:, 2] = 0.0  # Set Y=0 (index 2, last dimension)

# Project all points at once (vectorized, homogeneous coordinates)
points_3d_h = np.column_stack([grid_points_3d_flat, np.ones(len(grid_points_3d_flat))])
points_cam = (camera_pose_world @ points_3d_h.T).T[:, :3]  # (N, 3)

# Filter points in front of camera (positive depth)
valid_mask = points_cam[:, 2] > 0

# Project to 2D for valid points (perspective divide)
points_2d_h = (cam_K @ points_cam[valid_mask].T).T  # (N_valid, 3)
points_2d = points_2d_h[:, :2] / points_2d_h[:, 2:3]  # (N_valid, 2)

# Fill in valid projections
valid_indices = np.where(valid_mask)[0]
for idx, point_2d_val in zip(valid_indices, points_2d):
    # Recover the (row, col) location from the flattened grid index.
    z_idx = idx // num_x
    x_idx = idx % num_x
    grid_2d_reshaped[z_idx, x_idx, :] = point_2d_val

# Create 2D visualization
print("Creating 2D visualization...")
fig = plt.figure(figsize=(36, 14))
gs = GridSpec(3, 6, figure=fig, hspace=0.3, wspace=0.3, height_ratios=[1, 0.4, 0.4])

# Row 1: RGB with trajectory, Attention map, Groundplane trajectory overlay, DINO features, Groundplane UV colors, Groundplane grid lines
ax_rgb = fig.add_subplot(gs[0, 0])
ax_attention = fig.add_subplot(gs[0, 1])
ax_overlay = fig.add_subplot(gs[0, 2])
ax_dino = fig.add_subplot(gs[0, 3])
ax_groundplane = fig.add_subplot(gs[0, 4])
ax_grid = fig.add_subplot(gs[0, 5])
ax_height = fig.add_subplot(gs[1, :])  # Height chart spans all columns
ax_gripper = fig.add_subplot(gs[2, :])  # Gripper chart spans all columns

# Panel 1: RGB with predicted groundplane trajectory.  Waypoint markers are
# colored by timestep via the plasma colormap (same styling in every panel).
ax_rgb.imshow(rgb_lowres)
if len(predicted_trajectory_lowres) > 0:
    ax_rgb.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1], 
               'r-', linewidth=2, alpha=0.8, label='Pred Direct Keypoint Traj', zorder=10)
    for i, (x, y) in enumerate(predicted_trajectory_lowres):
        color = plt.cm.plasma(i / len(predicted_trajectory_lowres))
        ax_rgb.plot(x, y, 'x', color=color, markersize=6, markeredgewidth=2, zorder=11)
ax_rgb.set_title("RGB with Predicted Groundplane Trajectory", fontsize=12)
ax_rgb.legend(loc='upper right', fontsize=9)
ax_rgb.axis('off')

# Panel 2: attention map for the first timestep, with the trajectory on top.
im_attn = ax_attention.imshow(attention_map_upscaled, cmap='hot', vmin=0, vmax=1)
if len(predicted_trajectory_lowres) > 0:
    ax_attention.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1], 
                     'w-', linewidth=2, alpha=0.8, label='Pred Traj', zorder=10)
    for i, (x, y) in enumerate(predicted_trajectory_lowres):
        color = plt.cm.plasma(i / len(predicted_trajectory_lowres))
        ax_attention.plot(x, y, 'x', color=color, markersize=6, markeredgewidth=2, zorder=11)
ax_attention.set_title("Attention Map (t=1) with Trajectory", fontsize=12)
ax_attention.legend(loc='upper right', fontsize=9)
ax_attention.axis('off')

# Panel 3: 50/50 blend of the RGB image and the attention heatmap.
rgb_normalized = rgb_lowres.astype(np.float32) / 255.0
# Convert attention map to RGB for overlay
attention_rgb = plt.cm.hot(attention_map_upscaled)[:, :, :3]  # (H, W, 3)
overlay = 0.5 * rgb_normalized + 0.5 * attention_rgb
ax_overlay.imshow(overlay)
if len(predicted_trajectory_lowres) > 0:
    ax_overlay.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1], 
                   'r-', linewidth=3, alpha=0.9, label='Pred Direct Keypoint Traj', zorder=10)
    for i, (x, y) in enumerate(predicted_trajectory_lowres):
        color = plt.cm.plasma(i / len(predicted_trajectory_lowres))
        ax_overlay.plot(x, y, 'x', color=color, markersize=8, markeredgewidth=2, zorder=11)
ax_overlay.set_title("Overlay: RGB + Attention + Trajectory", fontsize=12)
ax_overlay.legend(loc='upper right', fontsize=9)
ax_overlay.axis('off')

# Panel 4: first-3-PCA-channel DINO feature image with the trajectory.
ax_dino.imshow(dino_vis_upscaled)
if len(predicted_trajectory_lowres) > 0:
    ax_dino.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1], 
               'w-', linewidth=2, alpha=0.8, label='Pred Traj', zorder=10)
    for i, (x, y) in enumerate(predicted_trajectory_lowres):
        color = plt.cm.plasma(i / len(predicted_trajectory_lowres))
        ax_dino.plot(x, y, 'x', color=color, markersize=6, markeredgewidth=2, zorder=11)
ax_dino.set_title("DINO Features with Trajectory", fontsize=12)
ax_dino.legend(loc='upper right', fontsize=9)
ax_dino.axis('off')

# Panel 5: dense groundplane coordinate map (UV colors) with the trajectory.
ax_groundplane.imshow(coord_colors_lowres)
if len(predicted_trajectory_lowres) > 0:
    ax_groundplane.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1], 
               'w-', linewidth=2, alpha=0.8, label='Pred Traj', zorder=10)
    for i, (x, y) in enumerate(predicted_trajectory_lowres):
        color = plt.cm.plasma(i / len(predicted_trajectory_lowres))
        ax_groundplane.plot(x, y, 'x', color=color, markersize=6, markeredgewidth=2, zorder=11)
ax_groundplane.set_title("Groundplane UV Colors with Trajectory", fontsize=12)
ax_groundplane.legend(loc='upper right', fontsize=9)
ax_groundplane.axis('off')

# Panel 6: groundplane grid lines visualization (like render_groundplane.py)
ax_grid.imshow(rgb_lowres)

# Draw grid lines (horizontal lines - constant Z).  NaN entries (invalid
# projections) are filtered out before rescaling to display resolution.
for z_idx in range(num_z):
    line_points = grid_2d_reshaped[z_idx, :, :]
    valid_mask = ~np.isnan(line_points[:, 0])
    if np.sum(valid_mask) > 1:
        line_points_valid = line_points[valid_mask]
        # Rescale to low-res coordinates
        line_points_lowres = rescale_coords(line_points_valid, H_orig, W_orig, RES_LOW, RES_LOW)
        # Filter to image bounds
        in_bounds = (line_points_lowres[:, 0] >= 0) & (line_points_lowres[:, 0] < RES_LOW) & \
                   (line_points_lowres[:, 1] >= 0) & (line_points_lowres[:, 1] < RES_LOW)
        if np.sum(in_bounds) > 1:
            line_points_final = line_points_lowres[in_bounds]
            ax_grid.plot(line_points_final[:, 0], line_points_final[:, 1], 
                       'b-', linewidth=0.5, alpha=0.6)

# Draw grid lines (vertical lines - constant X); same filtering as above.
for x_idx in range(num_x):
    line_points = grid_2d_reshaped[:, x_idx, :]
    valid_mask = ~np.isnan(line_points[:, 0])
    if np.sum(valid_mask) > 1:
        line_points_valid = line_points[valid_mask]
        # Rescale to low-res coordinates
        line_points_lowres = rescale_coords(line_points_valid, H_orig, W_orig, RES_LOW, RES_LOW)
        # Filter to image bounds
        in_bounds = (line_points_lowres[:, 0] >= 0) & (line_points_lowres[:, 0] < RES_LOW) & \
                   (line_points_lowres[:, 1] >= 0) & (line_points_lowres[:, 1] < RES_LOW)
        if np.sum(in_bounds) > 1:
            line_points_final = line_points_lowres[in_bounds]
            ax_grid.plot(line_points_final[:, 0], line_points_final[:, 1], 
                       'b-', linewidth=0.5, alpha=0.6)

# Draw grid points colored by 3D coordinates
all_points_2d = grid_2d_reshaped.reshape(-1, 2)
all_points_3d_flat = grid_points_3d_flat

# Get valid 2D points and their corresponding 3D coordinates
valid_2d_mask = ~np.isnan(all_points_2d[:, 0])
valid_points_2d = all_points_2d[valid_2d_mask]
valid_points_3d = all_points_3d_flat[valid_2d_mask]

if len(valid_points_2d) > 0:
    # Filter to image bounds (original resolution)
    in_bounds = (valid_points_2d[:, 0] >= 0) & (valid_points_2d[:, 0] < W_orig) & \
               (valid_points_2d[:, 1] >= 0) & (valid_points_2d[:, 1] < H_orig)
    valid_points_2d_final = valid_points_2d[in_bounds]
    valid_points_3d_final = valid_points_3d[in_bounds]
    
    if len(valid_points_2d_final) > 0:
        # Rescale to low-res for display
        valid_points_2d_lowres = rescale_coords(valid_points_2d_final, H_orig, W_orig, RES_LOW, RES_LOW)
        
        # Extract X and Z coordinates (same column convention as the dense map:
        # column 0 is X, column 1 is the "Z" grid axis, column 2 is Y=0)
        x_coords = valid_points_3d_final[:, 0]  # X coordinate
        z_coords = valid_points_3d_final[:, 1]  # Z coordinate
        
        # Normalize coordinates to [0, 1] for color mapping
        x_min, x_max = x_coords.min(), x_coords.max()
        z_min, z_max = z_coords.min(), z_coords.max()
        
        x_norm = (x_coords - x_min) / (x_max - x_min + 1e-6)
        z_norm = (z_coords - z_min) / (z_max - z_min + 1e-6)
        
        # Create RGB colors based on XZ coordinates (matches the dense map's scheme)
        red_channel = 1.0 - x_norm * 0.5
        green_channel_x = x_norm
        blue_channel = 1.0 - z_norm
        green_channel_z = z_norm
        
        colors = np.stack([
            red_channel,
            np.clip(green_channel_x + green_channel_z, 0, 1),
            blue_channel
        ], axis=1)
        
        ax_grid.scatter(valid_points_2d_lowres[:, 0], valid_points_2d_lowres[:, 1], 
                      c=colors, s=2, alpha=0.8, zorder=5)

# Draw predicted trajectory on grid
if len(predicted_trajectory_lowres) > 0:
    ax_grid.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1], 
               'r-', linewidth=2, alpha=0.8, label='Pred Traj', zorder=10)
    for i, (x, y) in enumerate(predicted_trajectory_lowres):
        color = plt.cm.plasma(i / len(predicted_trajectory_lowres))
        ax_grid.plot(x, y, 'x', color=color, markersize=6, markeredgewidth=2, zorder=11)

# Draw origin (0, 0, 0) if visible in the image
origin_2d = project_3d_to_2d(np.array([0.0, 0.0, 0.0]), camera_pose_world, cam_K)
if origin_2d is not None:
    origin_2d_lowres = rescale_coords(origin_2d.reshape(1, 2), H_orig, W_orig, RES_LOW, RES_LOW)[0]
    if 0 <= origin_2d_lowres[0] < RES_LOW and 0 <= origin_2d_lowres[1] < RES_LOW:
        ax_grid.plot(origin_2d_lowres[0], origin_2d_lowres[1], 'go', markersize=10, 
                   markeredgecolor='white', markeredgewidth=2, label='Origin (0,0,0)', zorder=10)

ax_grid.set_title("Groundplane Grid Lines with Trajectory", fontsize=12)
ax_grid.legend(loc='upper right', fontsize=9)
ax_grid.axis('off')

# Height chart (row 2): predicted height per timestep as a bar chart.
ax_height.clear()
timesteps = np.arange(1, len(heights_pred_denorm) + 1)
# NOTE(review): the y-values are sliced to MAX_TIMESTEPS but `timesteps` is
# not — the bar() call would fail if heights_pred_denorm ever had more than
# MAX_TIMESTEPS entries; confirm the lengths always match.
ax_height.bar(timesteps, heights_pred_denorm[:MAX_TIMESTEPS], width=0.6, alpha=0.7, color='red', label='Pred Height')
ax_height.set_xlabel('Timestep', fontsize=10)
ax_height.set_ylabel('Height (m)', fontsize=10)
ax_height.set_title('Predicted Height Trajectory', fontsize=12)
ax_height.legend(fontsize=9)
ax_height.grid(alpha=0.3)

# Gripper chart (row 3): predicted gripper command per timestep.
ax_gripper.clear()
timesteps_gripper = np.arange(1, len(grippers_pred) + 1)
ax_gripper.bar(timesteps_gripper, grippers_pred[:MAX_TIMESTEPS], width=0.6, alpha=0.7, color='orange', label='Pred Gripper')
ax_gripper.set_xlabel('Timestep', fontsize=10)
ax_gripper.set_ylabel('Gripper Value', fontsize=10)
ax_gripper.set_title('Predicted Gripper Open/Close Trajectory', fontsize=12)
ax_gripper.legend(fontsize=9)
ax_gripper.grid(alpha=0.3)

plt.tight_layout()
# Save the full dashboard, named after the input image's stem.
output_path = Path(f'keypoint_testing2/test_scripts/test_ik_lifting_raw_{image_path.stem}.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"✓ Saved 2D visualization to {output_path}")
if not args.dont_show_plots:plt.show()
plt.close()

print("✓ 2D Visualization complete!")

# MuJoCo visualization
print("\nSetting up MuJoCo visualization...")

# Add predicted trajectory spheres to XML so they render in the scene
xml_root = ET.fromstring(robot_config.xml)
worldbody = xml_root.find('worldbody')

# Predicted Lifted trajectory spheres (red to yellow) - from PREDICTED 2D groundplane + PREDICTED height
if len(trajectory_3d) > 0:
    for i, kp_pos in enumerate(trajectory_3d):
        # Interpolate sphere color along the trajectory; max(..., 1) guards
        # the single-point case against division by zero.
        red = 1.0 - (i / max(len(trajectory_3d) - 1, 1)) * 0.5
        green = i / max(len(trajectory_3d) - 1, 1) * 0.5
        ET.SubElement(worldbody, 'site', {
            'name': f'pred_lifted_kp_{i}', 'type': 'sphere', 'size': '0.015',
            'pos': f'{kp_pos[0]} {kp_pos[1]} {kp_pos[2]}', 'rgba': f'{red} {green} 0 0.8'
        })

# Build a fresh model/data pair from the augmented XML.
mj_model_viz = mujoco.MjModel.from_xml_string(ET.tostring(xml_root, encoding='unicode'))
mj_data_viz = mujoco.MjData(mj_model_viz)

# Use the already estimated robot state (from earlier)
# We already have mj_data with the correct state, so we can copy it
mj_data_viz.qpos[:] = mj_data.qpos
mj_data_viz.ctrl[:] = mj_data.ctrl
mujoco.mj_forward(mj_model_viz, mj_data_viz)
position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses)
mujoco.mj_forward(mj_model_viz, mj_data_viz)

# Use computed camera pose and intrinsics
cam_K_for_render = cam_K

# Setup IK configuration
ik_configuration = mink.Configuration(mj_model_viz)
ik_configuration.update(mj_data_viz.qpos)

# Hardcoded median rotation (like image_inference.py and baseline_training/test_scripts/test_ik.py)
median_dataset_rotation = np.array([[-0.99912433, -0.03007201, -0.02909046],
                                    [-0.04176828,  0.67620482,  0.73552869],
                                    [-0.00244771,  0.73609967, -0.67686874]])

# Render the robot through the PREDICTED LIFTED trajectory: for each 3D
# waypoint, solve IK toward it (with a fixed target rotation), apply the
# predicted gripper value, verify the reached position, and render an overlay.
trajectory_joints=[]
if len(trajectory_3d) > 0:
    print(f"Rendering robot through PREDICTED LIFTED trajectory ({len(trajectory_3d)} points from predicted 2D direct keypoint + predicted height)...")
    for i, target_pos in enumerate(trajectory_3d):
        target_gripper_rot = median_dataset_rotation

        # Update IK configuration from current state before solving
        link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
        position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
        mujoco.mj_forward(mj_model_viz, mj_data_viz)
        ik_configuration.update(mj_data_viz.qpos)

        # Solve IK with median rotation constraint (like image_inference.py)
        ik_to_keypoint_and_rotation(target_pos, target_gripper_rot, ik_configuration, robot_config, mj_model_viz, mj_data_viz)

        # Update link poses and forward kinematics after IK
        link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
        position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)

        # Gripper command for this step.  BUG FIX: `predicted_gripper_val` was
        # previously assigned only inside the `if`, so when the trajectory
        # outran grippers_pred the append below raised NameError (first
        # iteration) or silently reused a stale value from an earlier step.
        predicted_gripper_val = grippers_pred[i] if i < len(grippers_pred) else 0.0
        if i < len(grippers_pred):
            # Set the last joint (gripper) to the predicted value
            if len(mj_data_viz.qpos) > 0:
                mj_data_viz.qpos[-1] = predicted_gripper_val
            if len(mj_data_viz.ctrl) > 0:
                mj_data_viz.ctrl[-1] = predicted_gripper_val
        mujoco.mj_forward(mj_model_viz, mj_data_viz)
        trajectory_joints.append((mj_data_viz.qpos.copy(), predicted_gripper_val))

        # Verify IK result against the requested keypoint position
        kp_body_id = mujoco.mj_name2id(mj_model_viz, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
        achieved_kp_pos = mj_data_viz.xpos[kp_body_id].copy()
        ik_error = np.linalg.norm(achieved_kp_pos - target_pos)
        gripper_val = grippers_pred[i] if i < len(grippers_pred) else 0.0
        if i < 3:  # Print first few for debugging
            print(f"  t={i+1}: Target=[{target_pos[0]:.4f}, {target_pos[1]:.4f}, {target_pos[2]:.4f}], "
                  f"Achieved=[{achieved_kp_pos[0]:.4f}, {achieved_kp_pos[1]:.4f}, {achieved_kp_pos[2]:.4f}], "
                  f"IK Error={ik_error:.4f}m, Gripper={gripper_val:.4f}")

        # Render at original resolution with cam_K (calibrated for H_orig x W_orig)
        rendered = render_from_camera_pose(mj_model_viz, mj_data_viz, camera_pose_world, cam_K_for_render, H_orig, W_orig)
        # Resize rendered to match loaded image size for display
        rendered_resized = cv2.resize(rendered, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)
        rgb_display = cv2.resize(rgb_resized, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR) if rgb_resized.shape[:2] != (H_loaded, W_loaded) else rgb_resized

        # Display original / render / 50-50 overlay side by side
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, img in zip(axes, [rgb_display, rendered_resized, (rgb_display * 0.5 + rendered_resized * 0.5).astype(np.uint8)]):
            ax.imshow(img)
            ax.axis('off')
        axes[0].set_title('Original RGB')
        axes[1].set_title(f'Rendered Robot (Pred Lifted t={i+1}/{len(trajectory_3d)})')
        axes[2].set_title('Overlay')
        plt.tight_layout()
        if not args.dont_show_plots:plt.show()
        plt.close()
else:
    print("⚠ No predicted lifted trajectory available for MuJoCo rendering")

# (Removed a leftover `import pdb; pdb.set_trace()` debug breakpoint that
# halted every unattended run at this point.)

print("✓ MuJoCo visualization complete!")

# Final summary: print every predicted 3D keypoint with its height and
# gripper value, framed by a separator banner.
banner = "=" * 60
print("\n" + banner)
print("PREDICTED 3D KEYPOINTS:")
print(banner)
for idx, point in enumerate(trajectory_3d):
    # Gripper predictions may be shorter than the lifted trajectory.
    grip = grippers_pred[idx] if idx < len(grippers_pred) else 0.0
    print(f"  t={idx+1:2d}: [{point[0]:8.4f}, {point[1]:8.4f}, {point[2]:8.4f}] (height={heights_pred_denorm[idx]:.4f}m, gripper={grip:.4f})")


print("\n" + banner)
print(f"Total trajectory length: {len(trajectory_3d)} keypoints")
print(banner)


