"""Test 3D lifting: Load GT groundplane trajectory and height, visualize in 2D, then verify 3D lifting."""
import sys
import os
from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../"))
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX, unproject_patch_to_groundplane
from utils import project_3d_to_2d, rescale_coords, recover_3d_from_keypoint_and_height, post_process_predictions
from model import TrajectoryPredictor, MIN_HEIGHT, MAX_HEIGHT, GROUNDPLANE_X_MIN, GROUNDPLANE_X_MAX, GROUNDPLANE_Z_MIN, GROUNDPLANE_Z_MAX
import torch

# Constants
RES_LOW = 256  # side length (pixels) of the square low-res images used for plotting
H_orig, W_orig = 1080, 1920  # resolution the frame is resized to and 2D coordinates are rescaled from
MAX_TIMESTEPS = 50  # trajectory horizon; GT sequences are padded/truncated to this length
GROUNDPLANE_RANGE = 1.0  # forwarded to TrajectoryPredictor as groundplane_range

# Command-line interface: every option is described as (flags, keyword
# arguments) and registered in declaration order.
import argparse

_CLI_OPTIONS = (
    (("--dataset_dir", "-d"), {"default": "scratch/parsed_judy_train", "type": str, "help": "Dataset directory"}),
    (("--episode_id", "-e"), {"default": 1, "type": int, "help": "Episode ID"}),
    (("--start_frame", "--sf"), {"default": 0, "type": int, "help": "Start frame index"}),
)

parser = argparse.ArgumentParser(
    description="Test 3D lifting with GT groundplane and height trajectories"
)
for _flags, _settings in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_settings)
args = parser.parse_args()

# Resolve the episode directory: <dataset_dir>/episode_<id, zero-padded to 3 digits>
dataset_dir = Path(args.dataset_dir)
episode_dir = dataset_dir / f"episode_{args.episode_id:03d}"
episode_id = episode_dir.name

# Find all frame files: PNGs whose stem is purely numeric. Sort by the numeric
# value of the stem so the temporal order is correct even if the stems are not
# zero-padded to a common width.
frame_files = sorted(
    (f for f in episode_dir.glob("*.png") if f.stem.isdigit()),
    key=lambda f: int(f.stem),
)
if not frame_files:
    print(f"No frames found in {episode_dir}")
    sys.exit(1)

# Reject negative indices as well: they would silently index from the end of
# the list and make "all frames after start" wrap around.
start_idx = args.start_frame
if not 0 <= start_idx < len(frame_files):
    print(f"Start frame {start_idx} out of range (max: {len(frame_files)-1})")
    sys.exit(1)

# Load start frame; the 6-digit frame string is the prefix of the per-frame side files
start_frame_file = frame_files[start_idx]
start_frame_str = f"{int(start_frame_file.stem):06d}"

# Load RGB image. cv2.imread signals an unreadable/corrupt file by returning
# None instead of raising, so check before converting BGR -> RGB.
bgr = cv2.imread(str(start_frame_file))
if bgr is None:
    print(f"Failed to read image {start_frame_file}")
    sys.exit(1)
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
# If every value is <= 1 the image is treated as being in [0, 1] and rescaled to 0-255 uint8
if rgb.max() <= 1.0:
    rgb = (rgb * 255).astype(np.uint8)
H_loaded, W_loaded = rgb.shape[:2]

# Resize RGB to original resolution for visualization consistency
rgb_resized = cv2.resize(rgb, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)
rgb_lowres = cv2.resize(rgb, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

# Load camera pose and intrinsics saved next to the start frame
camera_pose_path = episode_dir / f"{start_frame_str}_camera_pose.npy"
cam_K_path = episode_dir / f"{start_frame_str}_cam_K.npy"
if not camera_pose_path.exists() or not cam_K_path.exists():
    print(f"Missing camera pose or intrinsics for frame {start_frame_str}")
    sys.exit(1)

camera_pose = np.load(camera_pose_path)
cam_K = np.load(cam_K_path)

# Load GT trajectory 3D keypoints: for every frame after the start frame that
# has a gripper pose, transform the local keypoint by that pose.
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
trajectory_gt_3d = []
# Frame number (file stem) of each entry in trajectory_gt_3d. Frames without a
# pose file are skipped, so the list position alone does not identify the frame.
trajectory_frame_ids = []
for frame_file in frame_files[start_idx + 1:]:  # All frames after start
    frame_idx_file = int(frame_file.stem)
    pose_path = episode_dir / f"{frame_idx_file:06d}_gripper_pose.npy"
    if not pose_path.exists():
        continue
    pose = np.load(pose_path)
    rot = pose[:3, :3]
    pos = pose[:3, 3]
    trajectory_gt_3d.append(rot @ kp_local + pos)
    trajectory_frame_ids.append(frame_idx_file)

if not trajectory_gt_3d:
    print("No valid trajectory points found")
    sys.exit(1)

trajectory_gt_3d = np.array(trajectory_gt_3d)
gt_trajectory_len = len(trajectory_gt_3d)  # Store original length before padding

# Project to the groundplane by zeroing the height axis (index 2)
trajectory_groundplane_3d = trajectory_gt_3d.copy()
trajectory_groundplane_3d[:, 2] = 0.0

# Project groundplane trajectory to 2D image coordinates; points whose
# projection fails (None) are dropped together with their height/gripper value.
trajectory_points_orig = []
heights_gt = []
grippers_gt = []
for i, kp_3d_gp in enumerate(trajectory_groundplane_3d):
    kp_2d_image = project_3d_to_2d(kp_3d_gp, camera_pose, cam_K)
    if kp_2d_image is None:
        continue
    trajectory_points_orig.append(kp_2d_image)

    # Height from original 3D keypoint (before groundplane projection), normalized to [0, 1]
    height_norm = np.clip((trajectory_gt_3d[i, 2] - MIN_HEIGHT) / (MAX_HEIGHT - MIN_HEIGHT), 0.0, 1.0)
    heights_gt.append(height_norm)

    # Load gripper value from the joint state file of the SAME frame the pose
    # came from (looked up by recorded frame number, not by list position).
    joint_state_path = episode_dir / f"{trajectory_frame_ids[i]:06d}.npy"
    if joint_state_path.exists():
        joint_state = np.load(joint_state_path)
        grippers_gt.append(float(joint_state[-1]))  # Last value is gripper open/close
    else:
        grippers_gt.append(0.0)

trajectory_points_orig = np.array(trajectory_points_orig) if trajectory_points_orig else None
heights_gt = np.array(heights_gt) if heights_gt else None
grippers_gt = np.array(grippers_gt) if grippers_gt else None

if trajectory_points_orig is None or len(trajectory_points_orig) == 0:
    print("No valid trajectory points found")
    sys.exit(1)

# Bring the GT sequences to exactly MAX_TIMESTEPS entries for comparison with
# predictions: short sequences repeat their final sample, long ones are cut.
traj_len = len(trajectory_points_orig)
pad_count = MAX_TIMESTEPS - traj_len
if pad_count > 0:
    last_point = trajectory_points_orig[-1]
    last_height = heights_gt[-1] if len(heights_gt) > 0 else 0.0
    last_gripper = grippers_gt[-1] if len(grippers_gt) > 0 else 0.0

    point_padding = np.tile(last_point.reshape(1, -1), (pad_count, 1))
    height_padding = np.full(pad_count, last_height)
    gripper_padding = np.full(pad_count, last_gripper)

    trajectory_points_orig = np.concatenate([trajectory_points_orig, point_padding])
    heights_gt = np.concatenate([heights_gt, height_padding])
    grippers_gt = np.concatenate([grippers_gt, gripper_padding])
elif pad_count < 0:
    trajectory_points_orig, heights_gt, grippers_gt = (
        sequence[:MAX_TIMESTEPS]
        for sequence in (trajectory_points_orig, heights_gt, grippers_gt)
    )

# Map the normalized heights back to [MIN_HEIGHT, MAX_HEIGHT] (done after padding)
height_span = MAX_HEIGHT - MIN_HEIGHT
heights_gt_denorm = heights_gt * height_span + MIN_HEIGHT

# Low-res copy of the 2D trajectory, used only for plotting on RES_LOW x RES_LOW images
trajectory_points_lowres = rescale_coords(
    trajectory_points_orig, H_orig, W_orig, RES_LOW, RES_LOW
)

# ========== MODEL INFERENCE ==========
print("\nRunning model inference...")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Load DINO features
dino_features_path = episode_dir / f"dino_features_{start_frame_str}.pt"
if not dino_features_path.exists():
    print(f"DINO features not found: {dino_features_path}")
    sys.exit(1)

# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load feature files that come from a trusted source.
dino_features = torch.load(dino_features_path, map_location='cpu', weights_only=False)
if isinstance(dino_features, np.ndarray):
    dino_features = torch.from_numpy(dino_features)

# Handle different possible shapes of dino features; the result is (H, W, D)
if dino_features.dim() == 2:
    # Flattened format: (H*W, D). Use the factorisation H*W whose H is the
    # largest divisor not exceeding sqrt(H*W), i.e. the most square grid.
    num_pixels = dino_features.shape[0]
    if num_pixels == 0:
        print(f"DINO features are empty: {dino_features_path}")
        sys.exit(1)
    for rows in range(int(np.sqrt(num_pixels)), 0, -1):
        if num_pixels % rows == 0:
            H_patches = rows
            W_patches = num_pixels // rows
            break
    # reshape (not view) so a non-contiguous tensor is handled as well
    dino_features = dino_features.reshape(H_patches, W_patches, -1)
elif dino_features.dim() == 3:
    # NOTE(review): this heuristic treats the tensor as (D, H, W) only when the
    # first axis is smaller than the last; an (H, W, D) tensor with H < D would
    # also match. Confirm against the feature extractor's actual layout.
    if dino_features.shape[0] < dino_features.shape[2]:  # (D, H, W)
        dino_features = dino_features.permute(1, 2, 0)  # (H, W, D)
else:
    print(f"Unsupported DINO feature shape: {tuple(dino_features.shape)}")
    sys.exit(1)

H_patches, W_patches = dino_features.shape[0], dino_features.shape[1]
dino_feat_dim = dino_features.shape[2]
print(f"DINO features shape: {dino_features.shape}")
print(f"Patch resolution: {H_patches}x{W_patches}")

# Extract first 3 channels for visualization and normalize each to [0, 1].
# Tensor.numpy() shares memory with the tensor, so work on an explicit copy:
# normalizing the shared view in place would overwrite the first three
# feature channels that are fed to the model below.
dino_vis = dino_features[:, :, :3].detach().float().numpy().copy()  # (H_patches, W_patches, <=3)
for i in range(dino_vis.shape[2]):
    channel = dino_vis[:, :, i]
    min_val, max_val = channel.min(), channel.max()
    if max_val > min_val:
        dino_vis[:, :, i] = (channel - min_val) / (max_val - min_val)
    else:
        dino_vis[:, :, i] = 0.5  # constant channel: show mid-grey
dino_vis = np.clip(dino_vis, 0, 1)

# Flatten DINO features for model input: (H_patches * W_patches, D), float32
dino_tokens_flat = dino_features.reshape(H_patches * W_patches, dino_feat_dim).float()

# Ground-plane (x, z) coordinate of every patch, visited row by row so that
# entry y * W_patches + x belongs to patch (x, y). Patches for which the
# unprojection yields None are stored as (0.0, 0.0).
y_coords, x_coords = np.meshgrid(np.arange(H_patches), np.arange(W_patches), indexing='ij')
groundplane_coords_list = []
for patch_y, patch_x in zip(y_coords.ravel(), x_coords.ravel()):
    x_gp, z_gp = unproject_patch_to_groundplane(
        patch_x, patch_y, H_patches, W_patches, H_orig, W_orig,
        camera_pose, cam_K, ground_y=0.0
    )
    unprojection_failed = x_gp is None or z_gp is None
    groundplane_coords_list.append([0.0, 0.0] if unprojection_failed else [x_gp, z_gp])

# (num_patches, 2) float32 tensor
groundplane_coords = torch.from_numpy(
    np.array(groundplane_coords_list, dtype=np.float32)
)

# Compute current EEF 2D position and height (from start frame)
start_pose_path = episode_dir / f"{start_frame_str}_gripper_pose.npy"
if not start_pose_path.exists():
    print("Current gripper pose not found")
    sys.exit(1)

# Transform the local keypoint by the start frame's gripper pose (4x4: rotation + translation)
start_pose = np.load(start_pose_path)
start_rot = start_pose[:3, :3]
start_pos = start_pose[:3, 3]
current_kp_3d = start_rot @ kp_local + start_pos

# Project current EEF to 2D
current_kp_2d_image = project_3d_to_2d(current_kp_3d, camera_pose, cam_K)
if current_kp_2d_image is None:
    print("Current EEF projection failed")
    sys.exit(1)

# Rescale to patch coordinates
current_kp_2d_patches = rescale_coords(
    current_kp_2d_image.reshape(1, 2),
    H_orig, W_orig,
    H_patches, W_patches
)[0]

# Find nearest patch for current EEF: clamp into the grid, round, then
# convert (x, y) to a row-major flat index.
current_eef_patch_x = int(np.round(np.clip(current_kp_2d_patches[0], 0, W_patches - 1)))
current_eef_patch_y = int(np.round(np.clip(current_kp_2d_patches[1], 0, H_patches - 1)))
current_eef_patch_idx = current_eef_patch_y * W_patches + current_eef_patch_x

# Get current EEF height (normalized to [0, 1] over [MIN_HEIGHT, MAX_HEIGHT])
current_eef_height_norm = float(np.clip((current_kp_3d[2] - MIN_HEIGHT) / (MAX_HEIGHT - MIN_HEIGHT), 0.0, 1.0))

# Load model checkpoint from a fixed path relative to the working directory
model_path = Path("groundplane_testing/tmpstorage/model.pt")
if not model_path.exists():
    print(f"Model not found at {model_path}")
    sys.exit(1)

# NOTE(review): the architecture hyper-parameters below are hard-coded and
# presumably have to match the ones the checkpoint was trained with — confirm
# against the training script.
model = TrajectoryPredictor(
    dino_feat_dim=dino_feat_dim,
    max_timesteps=MAX_TIMESTEPS,
    num_layers=3,
    num_heads=4,
    hidden_dim=128,
    num_pos_bands=4,
    groundplane_range=GROUNDPLANE_RANGE
).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
print(f"✓ Loaded model from {model_path}")
model.eval()

# Run inference on a single sample (batch size 1), without gradient tracking
print("Running inference...")
with torch.no_grad():
    # Every input gets a leading batch dimension of 1 and is moved to `device`
    tokens_in = dino_tokens_flat.unsqueeze(0).to(device)
    coords_in = groundplane_coords.unsqueeze(0).to(device)
    eef_patch_in = torch.tensor([current_eef_patch_idx], dtype=torch.long).to(device)
    eef_height_in = torch.tensor([current_eef_height_norm], dtype=torch.float32).to(device)

    model_outputs = model(
        tokens_in,
        coords_in,
        eef_patch_in,
        eef_height_in,
        use_attention_mask=True  # Use groundplane range mask during inference
    )

    # Drop the batch dimension and convert each of the three outputs to NumPy.
    # Presumably (max_timesteps, num_patches), (max_timesteps,), (max_timesteps,)
    # per the original annotations — confirm against TrajectoryPredictor.
    attention_scores, heights_pred, grippers_pred = (
        output.squeeze(0).cpu().numpy() for output in model_outputs
    )

print("✓ Inference complete")

# Turn the raw model outputs into predicted 2D image coordinates and heights
# (the first returned value is not needed here)
_, pred_image_coords, heights_pred_denorm = post_process_predictions(
    attention_scores, heights_pred, H_patches, W_patches, H_orig, W_orig, camera_pose, cam_K
)

print(f"✓ Predicted {len(pred_image_coords)} trajectory points")

# Low-res copy of the predicted trajectory, used only for plotting
predicted_trajectory_lowres = rescale_coords(pred_image_coords, H_orig, W_orig, RES_LOW, RES_LOW)


def _lift_to_3d(points_2d, point_heights):
    """Lift each 2D point with its height to 3D, dropping points whose lifting returns None."""
    lifted = []
    for step, point_2d in enumerate(points_2d):
        point_3d = recover_3d_from_keypoint_and_height(
            point_2d, point_heights[step], camera_pose, cam_K
        )
        if point_3d is not None:
            lifted.append(point_3d)
    return lifted


# 3D lifting of the PREDICTED 2D groundplane points + PREDICTED heights
trajectory_pred_lifted_3d = np.array(_lift_to_3d(pred_image_coords, heights_pred_denorm))

# 3D lifting of the GT 2D groundplane points + GT heights (None when nothing could be lifted)
_gt_lifted_points = _lift_to_3d(trajectory_points_orig, heights_gt_denorm)
trajectory_lifted_3d = np.array(_gt_lifted_points) if len(_gt_lifted_points) > 0 else None

# Report how far the lifted GT points are from the original GT 3D points.
# Only the first gt_trajectory_len entries are compared so that padded
# timesteps do not contribute.
if trajectory_lifted_3d is not None and len(trajectory_lifted_3d) > 0:
    valid_len = min(len(trajectory_lifted_3d), gt_trajectory_len)
    lifting_residuals = trajectory_lifted_3d[:valid_len] - trajectory_gt_3d[:valid_len]
    lifting_errors = np.linalg.norm(lifting_residuals, axis=1)
    print(f"3D Lifting Verification (GT):")
    print(f"  Mean error: {np.mean(lifting_errors):.6f}m")
    print(f"  Max error: {np.max(lifting_errors):.6f}m")
    print(f"  Min error: {np.min(lifting_errors):.6f}m")
    if len(lifting_errors) > 0:
        # Details of the first timestep, for eyeballing
        print(f"  First point error: {lifting_errors[0]:.6f}m")
        print(f"    GT 3D: {trajectory_gt_3d[0]}")
        print(f"    Lifted 3D: {trajectory_lifted_3d[0]}")
        print(f"    GT 2D groundplane: {trajectory_points_orig[0]}")
        print(f"    GT height: {heights_gt_denorm[0]:.4f}m")

# Create dense groundplane coordinate map (like render_dense_groundplane.py):
# intersect the viewing ray of every full-resolution pixel with the plane
# where world axis 2 equals zero, then colour-code the hit point.
print("Computing dense groundplane coordinate map...")
u, v = np.meshgrid(np.arange(W_orig), np.arange(H_orig))
u_flat, v_flat = u.flatten(), v.flatten()

# Homogeneous pixel coordinates -> ray directions in the camera frame (vectorized)
K_inv = np.linalg.inv(cam_K)
pixels_h = np.stack([u_flat, v_flat, np.ones(len(u_flat))], axis=1).T  # (3, N)
rays_cam = (K_inv @ pixels_h).T  # (N, 3)

# Rotate the rays with the inverse of camera_pose; its translation column is
# used as the ray origin.
cam_pose_inv = np.linalg.inv(camera_pose)
rays_world = (cam_pose_inv[:3, :3] @ rays_cam.T).T  # (N, 3) - direction vectors
cam_pos = cam_pose_inv[:3, 3]  # (3,)

# Ray parameter at which axis 2 reaches zero; the small epsilon keeps the
# division finite for rays (nearly) parallel to the plane.
ray_dir_up = rays_world[:, 2]
ray_t = (0.0 - cam_pos[2]) / (ray_dir_up + 1e-10)

# A hit is valid when it lies in front of the origin and the ray is not parallel to the plane
valid_mask_flat = (ray_t > 0) & (np.abs(ray_dir_up) > 1e-6)

# Intersection points, first flat and then per pixel
points_3d_flat = cam_pos[None, :] + ray_t[:, None] * rays_world  # (N, 3)
points_3d = points_3d_flat.reshape(H_orig, W_orig, 3)
valid_mask = valid_mask_flat.reshape(H_orig, W_orig)

# The two in-plane coordinates of each hit (world axes 0 and 1)
x_coords = points_3d[:, :, 0]
z_coords = points_3d[:, :, 1]

# Map [-grid_range, grid_range] to [0, 1] (clipped) for colour coding
grid_range = 1.0


def _to_unit_interval(values):
    """Linearly map [-grid_range, grid_range] onto [0, 1] and clip."""
    return np.clip((values + grid_range) / (2 * grid_range + 1e-6), 0, 1)


x_norm = _to_unit_interval(x_coords)
z_norm = _to_unit_interval(z_coords)

# Colour from the two coordinates: red fades with the first, blue fades with
# the second, green grows with both.
coord_colors = np.stack(
    [
        1.0 - x_norm * 0.5,
        np.clip(x_norm + z_norm, 0, 1),
        1.0 - z_norm,
    ],
    axis=2,
)  # (H_orig, W_orig, 3)

# Pixels without a valid plane hit are drawn white
coord_colors[np.logical_not(valid_mask)] = 1.0

# Low-res copy for visualization
coord_colors_lowres = cv2.resize(coord_colors, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

# Create visualization: three image panels on top, two full-width charts below
fig = plt.figure(figsize=(20, 14))
gs = GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3, height_ratios=[1, 0.4, 0.4])

ax_rgb = fig.add_subplot(gs[0, 0])
ax_groundplane = fig.add_subplot(gs[0, 1])
ax_overlay = fig.add_subplot(gs[0, 2])
ax_height = fig.add_subplot(gs[1, :])  # Height chart spans all columns
ax_gripper = fig.add_subplot(gs[2, :])  # Gripper chart spans all columns


def _draw_trajectories(ax, title, line_width, line_alpha, marker_size, gt_edge_width):
    """Draw the GT (white line, viridis dots) and predicted (red line, plasma
    crosses) low-res trajectories on `ax`, then add title and legend."""
    if len(trajectory_points_lowres) > 0:
        ax.plot(trajectory_points_lowres[:, 0], trajectory_points_lowres[:, 1],
                'w-', linewidth=line_width, alpha=line_alpha, label='GT Groundplane Traj', zorder=10)
        gt_count = len(trajectory_points_lowres)
        for step, (px, py) in enumerate(trajectory_points_lowres):
            # Marker colour encodes progress along the trajectory
            ax.plot(px, py, 'o', color=plt.cm.viridis(step / gt_count), markersize=marker_size,
                    markeredgecolor='white', markeredgewidth=gt_edge_width, zorder=11)
    if len(predicted_trajectory_lowres) > 0:
        ax.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1],
                'r-', linewidth=line_width, alpha=line_alpha, label='Pred Groundplane Traj', zorder=10)
        pred_count = len(predicted_trajectory_lowres)
        for step, (px, py) in enumerate(predicted_trajectory_lowres):
            ax.plot(px, py, 'x', color=plt.cm.plasma(step / pred_count), markersize=marker_size,
                    markeredgewidth=2, zorder=11)
    ax.set_title(title, fontsize=12)
    ax.legend(loc='upper right', fontsize=9)
    ax.axis('off')


# Panel 1: RGB with groundplane trajectory (GT and Predicted)
ax_rgb.imshow(rgb_lowres)
_draw_trajectories(ax_rgb, "RGB with Groundplane Trajectory",
                   line_width=2, line_alpha=0.8, marker_size=6, gt_edge_width=1)

# Panel 2: Dense groundplane coordinate map
ax_groundplane.imshow(coord_colors_lowres)
ax_groundplane.set_title("Dense Groundplane Coordinate Map", fontsize=12)
ax_groundplane.axis('off')

# Panel 3: 50/50 blend of RGB and the coordinate map, with both trajectories
rgb_normalized = rgb_lowres.astype(np.float32) / 255.0
overlay = 0.5 * rgb_normalized + 0.5 * coord_colors_lowres
ax_overlay.imshow(overlay)
_draw_trajectories(ax_overlay, "Overlay: RGB + Groundplane + Trajectory",
                   line_width=3, line_alpha=0.9, marker_size=8, gt_edge_width=2)

def _fit_to_length(values, length):
    """Return `values` with exactly `length` entries: short input is extended
    by repeating its last value (0.0 if empty), long input is cut."""
    if len(values) < length:
        fill_value = values[-1] if len(values) > 0 else 0.0
        return np.concatenate([values, np.full(length - len(values), fill_value)])
    if len(values) > length:
        return values[:length]
    return values


def _finish_chart(ax, y_label, title):
    """Apply the axis labels, title, legend and grid shared by both bar charts."""
    ax.set_xlabel('Timestep', fontsize=10)
    ax.set_ylabel(y_label, fontsize=10)
    ax.set_title(title, fontsize=12)
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)


# Height chart (GT and Predicted): paired bars per timestep, GT left / prediction right
timesteps = np.arange(1, MAX_TIMESTEPS + 1)
heights_gt_denorm_padded = _fit_to_length(
    heights_gt * (MAX_HEIGHT - MIN_HEIGHT) + MIN_HEIGHT, MAX_TIMESTEPS
)

ax_height.bar(timesteps - 0.2, heights_gt_denorm_padded, width=0.4, alpha=0.7, color='green', label='GT Height')
if len(heights_pred_denorm) > 0:
    ax_height.bar(timesteps + 0.2, heights_pred_denorm[:MAX_TIMESTEPS], width=0.4, alpha=0.7, color='red', label='Pred Height')
_finish_chart(ax_height, 'Height (m)', 'Height Trajectory (GT vs Predicted)')

# Gripper chart (GT and Predicted), same layout
grippers_gt_padded = _fit_to_length(grippers_gt.copy(), MAX_TIMESTEPS)

ax_gripper.bar(timesteps - 0.2, grippers_gt_padded, width=0.4, alpha=0.7, color='blue', label='GT Gripper')
if len(grippers_pred) > 0:
    ax_gripper.bar(timesteps + 0.2, grippers_pred[:MAX_TIMESTEPS], width=0.4, alpha=0.7, color='orange', label='Pred Gripper')
_finish_chart(ax_gripper, 'Gripper Value', 'Gripper Open/Close Trajectory (GT vs Predicted)')

# Summary line under the charts: mean 3D error of the lifted GT trajectory
# and of the lifted predicted trajectory, each against the GT 3D keypoints.
summary_parts = []
have_gt_lift = trajectory_lifted_3d is not None and len(trajectory_lifted_3d) > 0
have_pred_lift = trajectory_pred_lifted_3d is not None and len(trajectory_pred_lifted_3d) > 0

if have_gt_lift:
    summary_parts.append(f"GT 3D Lifting: Mean Error = {np.mean(lifting_errors):.6f}m")

if have_pred_lift:
    # Compare only as many steps as the unpadded GT trajectory has
    valid_len_pred = min(len(trajectory_pred_lifted_3d), gt_trajectory_len)
    pred_residuals = trajectory_pred_lifted_3d[:valid_len_pred] - trajectory_gt_3d[:valid_len_pred]
    pred_lifting_errors = np.linalg.norm(pred_residuals, axis=1)
    summary_parts.append(f"Pred 3D Lifting: Mean Error = {np.mean(pred_lifting_errors):.6f}m")

if summary_parts:
    fig.text(0.5, 0.02, " | ".join(summary_parts), ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()

# Save the figure (creating the output directory if needed), then show it
output_dir = Path('groundplane_testing/test_scripts')
output_path = output_dir / f'test_ik_lifting_{episode_id}_frame{start_idx}.png'
output_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"✓ Saved 2D visualization to {output_path}")
plt.show()

print("✓ 2D Visualization complete!")

# MuJoCo visualization (from test_ik.py)
print("\nSetting up MuJoCo visualization...")
import mujoco
import xml.etree.ElementTree as ET
from exo_utils import render_from_camera_pose
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes, get_link_poses_from_robot
import mink

# Parse the robot XML so trajectory markers can be added to <worldbody>
robot_config = SO100AdhesiveConfig()
xml_root = ET.fromstring(robot_config.xml)
worldbody = xml_root.find('worldbody')


def _add_trajectory_sites(name_prefix, points, rgba_at):
    """Add one sphere site per trajectory point to the world body.

    `rgba_at` receives the point's progress along the trajectory (0 for the
    first point, 1 for the last) and returns the site's rgba string.
    """
    last_index = max(len(points) - 1, 1)
    for i, kp_pos in enumerate(points):
        ET.SubElement(worldbody, 'site', {
            'name': f'{name_prefix}_{i}', 'type': 'sphere', 'size': '0.015',
            'pos': f'{kp_pos[0]} {kp_pos[1]} {kp_pos[2]}', 'rgba': rgba_at(i / last_index)
        })


# GT trajectory spheres (green to blue)
_add_trajectory_sites('gt_kp', trajectory_gt_3d,
                      lambda frac: f'0 {1.0 - frac} {frac} 0.8')

# GT Lifted trajectory spheres (green to blue, half the colour range) - from GT 2D groundplane + GT height
if trajectory_lifted_3d is not None and len(trajectory_lifted_3d) > 0:
    _add_trajectory_sites('gt_lifted_kp', trajectory_lifted_3d,
                          lambda frac: f'0 {1.0 - frac * 0.5} {frac * 0.5} 0.8')

# Predicted Lifted trajectory spheres (red to yellow) - from PREDICTED 2D groundplane + PREDICTED height
if trajectory_pred_lifted_3d is not None and len(trajectory_pred_lifted_3d) > 0:
    _add_trajectory_sites('pred_lifted_kp', trajectory_pred_lifted_3d,
                          lambda frac: f'{1.0 - frac * 0.5} {frac * 0.5} 0 0.8')

# Build the MuJoCo model and its data from the augmented XML
augmented_xml = ET.tostring(xml_root, encoding='unicode')
mj_model = mujoco.MjModel.from_xml_string(augmented_xml)
mj_data = mujoco.MjData(mj_model)

# Estimate robot state from image: detect link poses in the full-resolution
# frame, solve for a joint configuration, and load it into the simulation.
print("Estimating robot state from image...")
# Only the link poses are used from this call; the estimated camera pose is
# overwritten below and the estimated intrinsics are never read afterwards.
link_poses, camera_pose_world, cam_K_estimated, _, _, _ = detect_and_set_link_poses(rgb_resized, mj_model, mj_data, robot_config)
configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=55)
# Copy the solved joint values into the state and (truncated to the number of
# controls) into the control targets.
mj_data.qpos[:] = configuration.q
mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
mujoco.mj_forward(mj_model, mj_data)
# Forward kinematics is recomputed once more after the exoskeleton meshes are positioned
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)

# Use saved camera pose and intrinsics (more accurate than estimated)
camera_pose_world = camera_pose
cam_K_for_render = cam_K

# Setup IK configuration, initialised from the current simulation joint positions
ik_configuration = mink.Configuration(mj_model)
ik_configuration.update(mj_data.qpos)
# IK function to follow keypoint trajectory
def ik_to_keypoint(target_pos, configuration, robot_config, mj_model, mj_data, iterations=50, dt=0.01):
    """Iteratively move the "virtual_gripper_keypoint" body towards `target_pos`.

    Each iteration refreshes the exoskeleton meshes, re-targets a frame task at
    `target_pos` with the body's *current* orientation (so only the position
    is actively driven), solves one differential-IK step, integrates it, and
    writes the result into `mj_data` before stepping the simulation.

    Args:
        target_pos: Desired position of the keypoint body (3 values).
        configuration: mink configuration that is updated and integrated in place.
        robot_config: Robot configuration passed to the exoskeleton helpers.
        mj_model: MuJoCo model.
        mj_data: MuJoCo data; qpos and ctrl are overwritten each iteration.
        iterations: Number of IK/simulation iterations (default 50).
        dt: Step size used for both the IK solve and the integration (default 0.01).

    Raises:
        ValueError: If the model has no body named "virtual_gripper_keypoint".
    """
    # mj_name2id returns -1 for an unknown name; indexing xquat with -1 would
    # silently read the last body instead of failing.
    kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
    if kp_body_id < 0:
        raise ValueError('Body "virtual_gripper_keypoint" not found in the MuJoCo model')

    # Loop-invariant objects: only their targets change between iterations.
    kp_task = mink.FrameTask("virtual_gripper_keypoint", "body", position_cost=1.0, orientation_cost=0.1)
    posture_task = mink.PostureTask(mj_model, cost=1e-3)
    limits = [mink.ConfigurationLimit(model=mj_model)]

    for _ in range(iterations):
        link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
        position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
        mujoco.mj_forward(mj_model, mj_data)
        configuration.update(mj_data.qpos)

        # Current body orientation: MuJoCo stores wxyz, SciPy expects xyzw.
        # The matrix round trip is kept so the numbers match the original.
        kp_rot = R.from_quat(mj_data.xquat[kp_body_id][[1, 2, 3, 0]]).as_matrix()
        target_quat = R.from_matrix(kp_rot).as_quat()  # xyzw
        target_wxyz = [target_quat[3], target_quat[0], target_quat[1], target_quat[2]]
        kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([target_wxyz, target_pos])))

        # Weak regulariser that keeps the joints close to where they are now
        posture_task.set_target(mj_data.qpos)

        vel = mink.solve_ik(configuration, [kp_task, posture_task], dt, "daqp", limits=limits)
        configuration.integrate_inplace(vel, dt)
        mj_data.qpos[:] = configuration.q
        mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
        mujoco.mj_step(mj_model, mj_data)

# Render robot through PREDICTED LIFTED trajectory (from PREDICTED 2D groundplane + PREDICTED height)
if trajectory_pred_lifted_3d is not None and len(trajectory_pred_lifted_3d) > 0:
    num_targets = len(trajectory_pred_lifted_3d)
    print(f"Rendering robot through PREDICTED LIFTED trajectory ({num_targets} points from predicted 2D groundplane + predicted height)...")

    # Loop-invariant work, done once: keypoint body lookup and the reference
    # image at the loaded resolution.
    kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
    if kp_body_id < 0:
        # -1 means the body does not exist; indexing xpos with it would read the wrong body
        print('Body "virtual_gripper_keypoint" not found in the MuJoCo model')
        sys.exit(1)
    if rgb_resized.shape[:2] != (H_loaded, W_loaded):
        rgb_display = cv2.resize(rgb_resized, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)
    else:
        rgb_display = rgb_resized

    for i, target_pos in enumerate(trajectory_pred_lifted_3d):
        ik_to_keypoint(target_pos, ik_configuration, robot_config, mj_model, mj_data)

        # Verify IK result: distance between the achieved keypoint position and the target
        achieved_kp_pos = mj_data.xpos[kp_body_id].copy()
        ik_error = np.linalg.norm(achieved_kp_pos - target_pos)
        if i < 3:  # Print first few for debugging
            print(f"  t={i+1}: Target=[{target_pos[0]:.4f}, {target_pos[1]:.4f}, {target_pos[2]:.4f}], "
                  f"Achieved=[{achieved_kp_pos[0]:.4f}, {achieved_kp_pos[1]:.4f}, {achieved_kp_pos[2]:.4f}], "
                  f"IK Error={ik_error:.4f}m")

        # Render at original resolution with cam_K (calibrated for H_orig x W_orig)
        rendered = render_from_camera_pose(mj_model, mj_data, camera_pose_world, cam_K_for_render, H_orig, W_orig)
        # Resize rendered to match loaded image size for display
        rendered_resized = cv2.resize(rendered, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)
        blended = (rgb_display * 0.5 + rendered_resized * 0.5).astype(np.uint8)

        # Display results: original, render, and 50/50 blend side by side
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, img in zip(axes, [rgb_display, rendered_resized, blended]):
            ax.imshow(img)
            ax.axis('off')
        axes[0].set_title('Original RGB')
        axes[1].set_title(f'Rendered Robot (Pred Lifted t={i+1}/{num_targets})')
        axes[2].set_title('Overlay')
        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise one figure per timestep stays alive
        # for the rest of the run.
        plt.close(fig)
else:
    print("⚠ No predicted lifted trajectory available for MuJoCo rendering")

print("✓ MuJoCo visualization complete!")
