"""Run inference on a single image: estimate robot state, compute DINO features, and predict trajectory."""
import sys
import os
from pathlib import Path
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from PIL import Image
from torchvision.transforms import functional as TF
import pickle
import argparse

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import mujoco
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes, render_from_camera_pose
from model import TrajectoryPredictor, MIN_HEIGHT, MAX_HEIGHT
from utils import project_3d_to_2d, rescale_coords, post_process_predictions
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX, unproject_patch_to_groundplane

# DINO configuration.
# The DINOv3 checkout and checkpoint default to the original author's machine;
# set DINOV3_REPO_DIR / DINOV3_WEIGHTS_PATH to run the script elsewhere.
REPO_DIR = os.environ.get("DINOV3_REPO_DIR", "/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
WEIGHTS_PATH = os.environ.get(
    "DINOV3_WEIGHTS_PATH",
    os.path.join(REPO_DIR, "weights", "dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"),
)
PATCH_SIZE = 16  # patch side in pixels; resized image dims are multiples of this
IMAGE_SIZE = 768  # target height (pixels) of the image fed to DINO
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
N_LAYERS = 12  # number of layers requested from get_intermediate_layers

# Model configuration
MAX_TIMESTEPS = 50  # number of future timesteps the predictor outputs
GROUNDPLANE_RANGE = 1.0
RES_LOW = 224  # side (pixels) of the square canvases used by the 2D visualization

# Prefer Apple MPS, then CUDA, then CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Command-line interface. The three --vis_* switches are plain on/off flags.
parser = argparse.ArgumentParser(description='Run inference on a single image')
parser.add_argument('--image_path', '-i', type=str, required=True, help='Path to input image')
for _vis_flag, _vis_help in (
    ('--vis_2d', 'Show 2D visualization (like vis_eval_2d.py)'),
    ('--vis_mujoco', 'Show MuJoCo visualization (like test_ik.py)'),
    ('--vis_dino', 'Show DINO feature visualization'),
):
    parser.add_argument(_vis_flag, action='store_true', help=_vis_help)
parser.add_argument('--camera_pose', type=str, default=None, help='Path to saved camera pose .npy file (optional, will use ArUco detection if not provided)')
parser.add_argument('--cam_K', type=str, default=None, help='Path to saved camera intrinsics .npy file (optional, will use ArUco detection if not provided)')
args = parser.parse_args()

# With no visualization requested at all, fall back to the 2D view.
if not any((args.vis_2d, args.vis_mujoco, args.vis_dino)):
    args.vis_2d = True
    print("No visualization flags specified, showing 2D visualization by default")

image_path = Path(args.image_path)
if not image_path.exists():
    print(f"Error: Image not found at {image_path}")
    sys.exit(1)

print(f"Loading image: {image_path}")

# Load RGB image. cv2.imread returns None (instead of raising) when the file
# cannot be decoded, so check before converting the colour order.
bgr = cv2.imread(str(image_path))
if bgr is None:
    print(f"Error: Could not decode image at {image_path}")
    sys.exit(1)
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
# Rescale only genuinely floating-point [0, 1] data. A value-only check would
# also fire on near-black uint8 frames (max 0 or 1) and blow them up to 255.
if np.issubdtype(rgb.dtype, np.floating) and rgb.max() <= 1.0:
    rgb = (rgb * 255).astype(np.uint8)
H_loaded, W_loaded = rgb.shape[:2]
print(f"Image resolution: {W_loaded}x{H_loaded}")

# --- Robot model ------------------------------------------------------------
robot_config = SO100AdhesiveConfig()
mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
mj_data = mujoco.MjData(mj_model)

# Original (pre-downsampling) image resolution that saved cam_K files are
# calibrated for; hardcoded to match vis_eval_2d.py.
H_orig, W_orig = 1080, 1920

# --- Auto-discovery of saved camera data -------------------------------------
# Episode directories may hold "<frame>_camera_pose.npy" / "<frame>_cam_K.npy"
# pairs next to the image. Try the image's own frame first (when its stem is a
# frame number), then fall back to frame 000000. Both files must exist.
image_dir = image_path.parent
image_stem = image_path.stem
camera_pose_path = None
cam_K_path = None

_camera_candidates = []
if image_stem.isdigit():
    frame_str = f"{int(image_stem):06d}"
    _camera_candidates.append((frame_str, "Found saved camera data in same directory"))
_camera_candidates.append(("000000", "Found saved camera data (first frame)"))

for _cand_frame, _found_msg in _camera_candidates:
    _pose_file = image_dir / f"{_cand_frame}_camera_pose.npy"
    _K_file = image_dir / f"{_cand_frame}_cam_K.npy"
    if _pose_file.exists() and _K_file.exists():
        camera_pose_path, cam_K_path = _pose_file, _K_file
        print(f"{_found_msg}: {camera_pose_path.name}, {cam_K_path.name}")
        break

# Load or estimate camera pose and intrinsics.
# Priority: explicit --camera_pose/--cam_K, then files auto-discovered next to
# the image, then ArUco detection. In every case the robot state is fitted to
# the image so the current EEF position is available.


def _fit_robot_to_image(known_cam_K=None):
    """Fit the MuJoCo robot to ``rgb`` and leave the fitted state in ``mj_data``.

    Detects link poses with ``detect_and_set_link_poses``, solves IK with
    ``estimate_robot_state`` (55 iterations), copies the joint configuration
    into ``qpos``/``ctrl``, and positions the exoskeleton meshes.

    Args:
        known_cam_K: Intrinsics handed to the detector. When ``None`` the
            detector is called without a ``cam_K`` argument and must estimate
            the camera itself (ArUco path).

    Returns:
        The ``(camera_pose, cam_K)`` pair returned by the detector (its second
        and third outputs).
    """
    # Omit the keyword entirely (rather than passing None) to keep the
    # detector's own default for the ArUco path.
    detect_kwargs = {} if known_cam_K is None else {"cam_K": known_cam_K}
    link_poses, detected_pose, detected_K, _, _, _ = detect_and_set_link_poses(
        rgb, mj_model, mj_data, robot_config, **detect_kwargs
    )
    configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=55)
    mj_data.qpos[:] = configuration.q
    mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
    mujoco.mj_forward(mj_model, mj_data)
    position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
    mujoco.mj_forward(mj_model, mj_data)
    return detected_pose, detected_K


if bool(args.camera_pose) != bool(args.cam_K):
    # Both files are required; otherwise a lone flag would be dropped silently
    # in favour of auto-discovery / ArUco detection.
    print("WARNING: --camera_pose and --cam_K must be given together; ignoring the one provided")

if args.camera_pose and args.cam_K:
    saved_pose_file, saved_K_file = args.camera_pose, args.cam_K
elif camera_pose_path and cam_K_path:
    saved_pose_file, saved_K_file = camera_pose_path, cam_K_path
else:
    saved_pose_file = saved_K_file = None

if saved_pose_file is not None:
    print(f"Loading camera pose from {saved_pose_file}")
    camera_pose_world = np.load(saved_pose_file)
    print(f"Loading camera intrinsics from {saved_K_file}")
    cam_K = np.load(saved_K_file)
    # Saved cam_K is calibrated for the original H_orig x W_orig resolution and
    # is used as-is for all 3D operations (like vis_eval_2d.py), for
    # consistency with the training data.
    print(f"  Using saved cam_K as-is (calibrated for {W_orig}x{H_orig})")
    print("Estimating robot state from image (for current EEF position)...")
    success_msg = "✓ Estimated robot state from image"
else:
    print("Estimating robot state and camera pose from image (ArUco detection)...")
    print("  (No saved camera pose/intrinsics found - using ArUco detection)")
    success_msg = "✓ Estimated robot state and camera pose from image"

try:
    detected_pose, detected_K = _fit_robot_to_image(None if saved_pose_file is None else cam_K)
except Exception as e:  # script boundary: report and abort
    print(f"Error estimating robot state: {e}")
    sys.exit(1)
if saved_pose_file is None:
    # No saved calibration: adopt the detector's camera estimate.
    camera_pose_world, cam_K = detected_pose, detected_K
print(success_msg)

# Sanity-check cam_K. It is expected to be calibrated for the original
# W_orig x H_orig resolution, so its principal point should lie near the centre
# of that frame.
# NOTE(review): when cam_K came from ArUco detection on the loaded image,
# confirm it is expressed at W_orig x H_orig rather than W_loaded x H_loaded.
fx, fy = cam_K[0, 0], cam_K[1, 1]
cx, cy = cam_K[0, 2], cam_K[1, 2]
expected_cx, expected_cy = W_orig / 2, H_orig / 2
print("\ncam_K validation:")
print(f"  Focal lengths: fx={fx:.1f}, fy={fy:.1f}")
print(f"  Principal point: cx={cx:.1f}, cy={cy:.1f}")
print(f"  Calibrated for: {W_orig}x{H_orig}")
print(f"  Expected principal point: ~({expected_cx:.1f}, {expected_cy:.1f})")

# More than 20% off-centre on either axis might indicate a calibration issue.
cx_off_centre = abs(cx - expected_cx) > W_orig * 0.2
cy_off_centre = abs(cy - expected_cy) > H_orig * 0.2
if cx_off_centre or cy_off_centre:
    print("  WARNING: Principal point is far from image center - cam_K might be miscalibrated!")

# All downstream 3D operations use cam_K at the original resolution (like vis_eval_2d.py).
camera_pose = camera_pose_world

# Current gripper pose. Prefer the saved per-frame gripper pose (like
# vis_eval_2d.py), then the episode's first-frame pose, and only fall back to
# the IK-estimated pose when neither can be loaded.
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]


def _estimated_gripper_pose():
    """Return the 4x4 world-frame pose of the "Fixed_Jaw" body from the fitted MuJoCo state."""
    from scipy.spatial.transform import Rotation as R
    body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "Fixed_Jaw")
    pose = np.eye(4)
    # MuJoCo stores quaternions as (w, x, y, z); SciPy expects (x, y, z, w).
    pose[:3, :3] = R.from_quat(mj_data.xquat[body_id][[1, 2, 3, 0]]).as_matrix()
    pose[:3, 3] = mj_data.xpos[body_id]
    return pose


def _load_saved_gripper_pose():
    """Load the saved gripper pose for this image's frame, or for frame 000000.

    Returns ``None`` (after printing why) when no usable file is found, so the
    caller can fall back to the estimated pose.
    """
    if not image_stem.isdigit():
        print("Could not load saved gripper pose (Image stem is not a digit), using estimated pose")
        return None
    try:
        frame_pose_path = image_dir / f"{int(image_stem):06d}_gripper_pose.npy"
        first_pose_path = image_dir / "000000_gripper_pose.npy"
        if frame_pose_path.exists():
            pose = np.load(frame_pose_path)
            message = f"Using saved gripper pose from {frame_pose_path.name}"
        elif first_pose_path.exists():
            pose = np.load(first_pose_path)
            message = f"Using saved gripper pose from first frame: {first_pose_path.name}"
        else:
            print("No saved gripper pose found, using estimated pose")
            return None
    except Exception as e:  # best-effort: an unreadable file falls back to the estimate
        print(f"Could not load saved gripper pose ({e}), using estimated pose")
        return None
    print(message)
    return pose


start_pose = _load_saved_gripper_pose()
if start_pose is None:
    start_pose = _estimated_gripper_pose()

# Current EEF keypoint in 3D (like vis_eval_2d.py) and its height, normalized
# to [0, 1] over [MIN_HEIGHT, MAX_HEIGHT] and clipped.
start_rot = start_pose[:3, :3]
start_pos = start_pose[:3, 3]
current_kp_3d = start_rot @ kp_local + start_pos
current_eef_height_norm = float(np.clip((current_kp_3d[2] - MIN_HEIGHT) / (MAX_HEIGHT - MIN_HEIGHT), 0.0, 1.0))

# Load the PCA embedder first so a missing file fails fast, before the
# comparatively slow DINOv3 backbone load.
pca_path = Path("scratch/dino_pca_embedder.pkl")  # resolved relative to the current working directory
if not pca_path.exists():
    print(f"PCA embedder not found at {pca_path}")
    sys.exit(1)
# NOTE: pickle.load can execute arbitrary code; only point this at a trusted embedder file.
with open(pca_path, 'rb') as f:
    pca_data = pickle.load(f)
pca_embedder = pca_data['pca']

# Compute DINO features
print("Computing DINO features...")
sys.path.append(REPO_DIR)
dinov3_model = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH).to(device)
dinov3_model.eval()

# Process image for DINO
def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Resize ``img`` so both sides are whole multiples of ``patch_size``.

    The output height is ``image_size`` rounded down to a multiple of
    ``patch_size``. The width is scaled to approximately preserve the aspect
    ratio, also rounded down to a multiple of ``patch_size``. Both sides are
    clamped to at least one patch.

    Args:
        img: Input PIL image.
        image_size: Target height in pixels.
        patch_size: Patch side in pixels.

    Returns:
        The ``(C, H, W)`` float tensor produced by ``TF.to_tensor``.
    """
    w, h = img.size
    # Integer arithmetic avoids float rounding. The clamp keeps extremely tall
    # images (or image_size < patch_size) from producing a zero-sized resize.
    h_patches = max(1, image_size // patch_size)
    w_patches = max(1, (w * image_size) // (h * patch_size))
    return TF.to_tensor(TF.resize(img, (h_patches * patch_size, w_patches * patch_size)))

img_pil = Image.fromarray(rgb).convert("RGB")
image_resized = resize_transform(img_pil)  # (3, H, W), both multiples of PATCH_SIZE
image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)
# Patch grid of the resized image.
h_patches, w_patches = (d // PATCH_SIZE for d in image_resized.shape[1:])

# Extract DINO features from the last requested layer and project them with PCA.
# No torch.autocast here: autocast only lowers precision (float16/bfloat16).
# With dtype=float32 it just disables itself with a warning, and the old
# device_type ('cpu' on CUDA) did not even match the model's device.
with torch.no_grad():
    feats = dinov3_model.get_intermediate_layers(
        image_resized_norm.unsqueeze(0).to(device),
        n=range(N_LAYERS),
        reshape=True,
        norm=True
    )
# Index the batch axis explicitly: squeeze() would also drop a patch axis of size 1.
x = feats[-1][0].detach().cpu()  # (D, H_patches, W_patches)
dim = x.shape[0]
x = x.view(dim, -1).permute(1, 0).numpy()  # (H_patches * W_patches, D)

# Apply PCA
pca_features_all = pca_embedder.transform(x)  # (H_patches * W_patches, n_components)
pca_features_patches = pca_features_all.reshape(h_patches, w_patches, -1)  # (H_patches, W_patches, n_components)

print(f"✓ Computed DINO features: {h_patches}x{w_patches} patches")

# Model token inputs: PCA features flattened row-major over the patch grid.
dino_features = torch.from_numpy(pca_features_patches).float()  # (H_patches, W_patches, 32)
dino_tokens_flat = dino_features.view(h_patches * w_patches, 32).float()


def _patch_groundplane_xz(patch_y, patch_x):
    """Ground-plane (x, z) for one patch, or [0.0, 0.0] when unprojection fails."""
    # H_orig/W_orig because cam_K is calibrated for the original resolution (like vis_eval_2d.py).
    x_gp, z_gp = unproject_patch_to_groundplane(
        patch_x, patch_y, h_patches, w_patches, H_orig, W_orig,
        camera_pose, cam_K, ground_y=0.0
    )
    if x_gp is None or z_gp is None:
        return [0.0, 0.0]
    return [x_gp, z_gp]


# Same row-major patch order as the tokens above.
_patch_rows, _patch_cols = np.indices((h_patches, w_patches)).reshape(2, -1)
groundplane_coords_list = [
    _patch_groundplane_xz(patch_y, patch_x)
    for patch_y, patch_x in zip(_patch_rows, _patch_cols)
]
groundplane_coords = torch.from_numpy(np.array(groundplane_coords_list, dtype=np.float32))

# Current EEF keypoint in image and patch coordinates. Projection uses cam_K at
# its calibrated H_orig x W_orig resolution (like vis_eval_2d.py).
current_kp_2d_image = project_3d_to_2d(current_kp_3d, camera_pose, cam_K)
if current_kp_2d_image is None:
    print("Current EEF projection failed")
    sys.exit(1)

# Rescale from image coordinates (H_orig x W_orig) to patch coordinates.
current_kp_2d_patches = rescale_coords(
    current_kp_2d_image.reshape(1, 2),
    H_orig, W_orig,
    h_patches, w_patches
)[0]

# Snap to the nearest in-grid patch and flatten row-major, like the tokens.
# NOTE(review): this rounds rather than floors; confirm it matches the patch
# indexing used when the training targets were built.
current_eef_patch_x = int(np.round(np.clip(current_kp_2d_patches[0], 0, w_patches - 1)))
current_eef_patch_y = int(np.round(np.clip(current_kp_2d_patches[1], 0, h_patches - 1)))
current_eef_patch_idx = current_eef_patch_y * w_patches + current_eef_patch_x

# Load the trajectory predictor. The constructor hyperparameters must match the
# checkpoint, otherwise load_state_dict (strict by default) raises.
model_path = Path("groundplane_testing/tmpstorage/model.pt")  # resolved relative to the current working directory
if not model_path.exists():
    print(f"Model not found at {model_path}")
    sys.exit(1)

model = TrajectoryPredictor(
    dino_feat_dim=32,
    max_timesteps=MAX_TIMESTEPS,
    num_layers=3,
    num_heads=4,
    hidden_dim=128,
    num_pos_bands=4,
    groundplane_range=GROUNDPLANE_RANGE
).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
print(f"✓ Loaded model from {model_path}")
model.eval()

# Run the predictor on a batch containing this single image.
print("Running inference...")
with torch.no_grad():
    # Every input gets a leading batch dimension of 1.
    dino_tokens_batch = dino_tokens_flat.unsqueeze(0).to(device)  # (1, num_patches, 32)
    groundplane_coords_batch = groundplane_coords.unsqueeze(0).to(device)  # (1, num_patches, 2)
    current_eef_patch_idx_batch = torch.tensor([current_eef_patch_idx], dtype=torch.long).to(device)
    current_eef_height_batch = torch.tensor([current_eef_height_norm], dtype=torch.float32).to(device)

    model_outputs = model(
        dino_tokens_batch,
        groundplane_coords_batch,
        current_eef_patch_idx_batch,
        current_eef_height_batch,
    )
    # Drop the batch axis and move to NumPy:
    # attention_scores -> (max_timesteps, num_patches), heights_pred -> (max_timesteps,)
    attention_scores, heights_pred = (out[0].cpu().numpy() for out in model_outputs)

print("✓ Inference complete")

# Lift per-timestep attention maps + heights to a 3D trajectory.
# IMPORTANT: everything uses the original H_orig x W_orig resolution because that
# is what cam_K is calibrated for (like vis_eval_2d.py), so pred_image_coords come
# back in H_orig x W_orig pixel coordinates.
trajectory_pred_3d, pred_image_coords, heights_pred_denorm = post_process_predictions(
    attention_scores, heights_pred,
    h_patches, w_patches,
    H_orig, W_orig,
    camera_pose, cam_K,
)

num_pred_points = len(trajectory_pred_3d)
print(f"✓ Predicted {num_pred_points} trajectory points")

# Debug: report the patch chosen by the first timestep's attention.
if len(trajectory_pred_3d) > 0:
    pred_patch_idx = attention_scores[0].argmax()
    # Row-major flat index -> (row, col) in the patch grid.
    pred_py, pred_px = divmod(pred_patch_idx, w_patches)
    patch_w_px = W_orig / w_patches
    patch_h_px = H_orig / h_patches
    print("\nFirst prediction patch selection:")
    print(f"  Selected patch index: {pred_patch_idx}")
    print(f"  Selected patch coordinates: px={pred_px}, py={pred_py} (out of {w_patches}x{h_patches} patches)")
    print(f"  Patch center in image coords: x={pred_px * patch_w_px:.1f}, y={pred_py * patch_h_px:.1f}")

# Debug: sanity-check the 3D lifting of the first predicted point.
# Prints the predicted 3D point and height. When a saved gripper pose exists for
# the next frame (t+1), it also compares against that ground truth. It then
# re-projects the 3D point to 2D and re-solves the ray/height intersection,
# treating both index 2 and index 1 as the height axis. None of this affects
# the predictions; it only prints diagnostics.
if len(trajectory_pred_3d) > 0:
    print("\n3D Lifting Verification:")
    print(f"First predicted 3D point: {trajectory_pred_3d[0]}")
    print(f"  Index 0 (X): {trajectory_pred_3d[0][0]:.4f}")
    print(f"  Index 1 (Y): {trajectory_pred_3d[0][1]:.4f}")
    print(f"  Index 2 (Z): {trajectory_pred_3d[0][2]:.4f}")
    print(f"First predicted 2D point: {pred_image_coords[0]}")
    print(f"First predicted height: {heights_pred_denorm[0]:.4f}m")
    # |predicted height - coordinate| for each candidate height axis; presumably
    # ~0 for the axis post_process_predictions actually solved along.
    print(f"  Height matches index 2: {abs(heights_pred_denorm[0] - trajectory_pred_3d[0][2]):.6f}m")
    print(f"  Height matches index 1: {abs(heights_pred_denorm[0] - trajectory_pred_3d[0][1]):.6f}m")
    
    # Try to load GT for comparison (first prediction is for t+1, so compare with next frame)
    try:
        image_stem = image_path.stem
        if image_stem.isdigit():
            current_frame = int(image_stem)
            # First prediction is for t+1, so compare with next frame
            next_frame_str = f"{current_frame + 1:06d}"
            gt_pose_path = image_dir / f"{next_frame_str}_gripper_pose.npy"
            if gt_pose_path.exists():
                # GT keypoint = saved gripper rotation [:3, :3] applied to kp_local, plus translation [:3, 3].
                gt_pose = np.load(gt_pose_path)
                kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
                gt_kp_3d = gt_pose[:3, :3] @ kp_local + gt_pose[:3, 3]
                print(f"\nGT 3D keypoint (from next frame {next_frame_str} gripper pose): {gt_kp_3d}")
                print(f"  Index 0 (X): {gt_kp_3d[0]:.4f}")
                print(f"  Index 1 (Y): {gt_kp_3d[1]:.4f}")
                print(f"  Index 2 (Z): {gt_kp_3d[2]:.4f}")
                print(f"  GT height (index 2): {gt_kp_3d[2]:.4f}m")
                print(f"  GT height (index 1): {gt_kp_3d[1]:.4f}m")
                print(f"  3D error (all dims): {np.linalg.norm(trajectory_pred_3d[0] - gt_kp_3d):.4f}m")
                print(f"  X error: {abs(trajectory_pred_3d[0][0] - gt_kp_3d[0]):.4f}m")
                print(f"  Y error: {abs(trajectory_pred_3d[0][1] - gt_kp_3d[1]):.4f}m")
                print(f"  Z error: {abs(trajectory_pred_3d[0][2] - gt_kp_3d[2]):.4f}m")
                
                # Also verify: can we recover the GT 3D point's 2D projection?
                gt_kp_2d = project_3d_to_2d(gt_kp_3d, camera_pose, cam_K)
                if gt_kp_2d is not None:
                    print(f"\n  GT 3D -> 2D projection: {gt_kp_2d}")
                    print(f"  Predicted 2D: {pred_image_coords[0]}")
                    print(f"  2D difference: {np.linalg.norm(gt_kp_2d - pred_image_coords[0]):.2f} pixels")
                    print(f"  2D X difference: {abs(gt_kp_2d[0] - pred_image_coords[0][0]):.2f} pixels")
                    print(f"  2D Y difference: {abs(gt_kp_2d[1] - pred_image_coords[0][1]):.2f} pixels")
                    
                    # Convert GT 2D to patch coordinates for comparison
                    # (continuous patch coords vs. the integer patch indices pred_px/pred_py).
                    gt_patch_coords = rescale_coords(gt_kp_2d.reshape(1, 2), H_orig, W_orig, h_patches, w_patches)[0]
                    print(f"  GT 2D in patch coords: px={gt_patch_coords[0]:.2f}, py={gt_patch_coords[1]:.2f}")
                    print(f"  Predicted patch coords: px={pred_px}, py={pred_py}")
                    print(f"  Patch difference: dx={abs(gt_patch_coords[0] - pred_px):.2f}, dy={abs(gt_patch_coords[1] - pred_py):.2f}")
                    
                    # Try reverse: lift GT 2D projection back to 3D using GT height
                    # (index 2 of the GT keypoint is used as its height here).
                    from utils import recover_3d_from_keypoint_and_height
                    gt_height = gt_kp_3d[2]
                    gt_kp_3d_recovered = recover_3d_from_keypoint_and_height(gt_kp_2d, gt_height, camera_pose, cam_K)
                    if gt_kp_3d_recovered is not None:
                        print(f"  GT 2D + GT height -> 3D (recovered): {gt_kp_3d_recovered}")
                        print(f"  Recovery error: {np.linalg.norm(gt_kp_3d_recovered - gt_kp_3d):.6f}m")
                    
                    # Also try: lift GT 2D with PREDICTED height to see if that helps
                    # (isolates the height prediction's contribution to the 3D error).
                    pred_height = heights_pred_denorm[0]
                    gt_kp_3d_with_pred_height = recover_3d_from_keypoint_and_height(gt_kp_2d, pred_height, camera_pose, cam_K)
                    if gt_kp_3d_with_pred_height is not None:
                        print(f"  GT 2D + PREDICTED height ({pred_height:.4f}m) -> 3D: {gt_kp_3d_with_pred_height}")
                        print(f"  Error vs GT 3D: {np.linalg.norm(gt_kp_3d_with_pred_height - gt_kp_3d):.4f}m")
                    
                    # Try: lift PREDICTED 2D with GT height
                    # (isolates the 2D keypoint prediction's contribution to the 3D error).
                    pred_kp_3d_with_gt_height = recover_3d_from_keypoint_and_height(pred_image_coords[0], gt_height, camera_pose, cam_K)
                    if pred_kp_3d_with_gt_height is not None:
                        print(f"  PREDICTED 2D + GT height ({gt_height:.4f}m) -> 3D: {pred_kp_3d_with_gt_height}")
                        print(f"  Error vs GT 3D: {np.linalg.norm(pred_kp_3d_with_gt_height - gt_kp_3d):.4f}m")
                        print(f"  Error vs Predicted 3D: {np.linalg.norm(pred_kp_3d_with_gt_height - trajectory_pred_3d[0]):.4f}m")
            else:
                # Fallback to current frame for comparison
                # (only the GT point itself is printed; no error metrics in this path).
                frame_str = f"{current_frame:06d}"
                gt_pose_path = image_dir / f"{frame_str}_gripper_pose.npy"
                if gt_pose_path.exists():
                    gt_pose = np.load(gt_pose_path)
                    kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
                    gt_kp_3d = gt_pose[:3, :3] @ kp_local + gt_pose[:3, 3]
                    print(f"\nGT 3D keypoint (from current frame {frame_str} gripper pose - NOTE: prediction is for t+1): {gt_kp_3d}")
                    print(f"  Index 0 (X): {gt_kp_3d[0]:.4f}")
                    print(f"  Index 1 (Y): {gt_kp_3d[1]:.4f}")
                    print(f"  Index 2 (Z): {gt_kp_3d[2]:.4f}")
    except Exception as e:
        # Best-effort diagnostics: any failure is reported and otherwise ignored.
        print(f"  (Could not load GT for comparison: {e})")
    
    # Project the 3D point back to 2D to verify consistency
    kp_3d_backproj = project_3d_to_2d(trajectory_pred_3d[0], camera_pose, cam_K)
    if kp_3d_backproj is not None:
        print(f"\nBack-projected 2D from 3D: {kp_3d_backproj}")
        print(f"2D error: {np.linalg.norm(kp_3d_backproj - pred_image_coords[0]):.2f} pixels")
    
    # Check camera position and ray
    # NOTE(review): the "(robot frame)" labels below assume unproject_2d_to_ray
    # returns the camera centre and ray direction in the robot frame - confirm.
    from utils import unproject_2d_to_ray
    cam_pos, ray = unproject_2d_to_ray(pred_image_coords[0], camera_pose, cam_K)
    print(f"\nCamera position (robot frame): {cam_pos}")
    print(f"Ray direction (robot frame): {ray}")
    print(f"  Ray index 0 (X): {ray[0]:.4f}")
    print(f"  Ray index 1 (Y): {ray[1]:.4f}")
    print(f"  Ray index 2 (Z): {ray[2]:.4f}")
    print(f"  Using index 2 for height solve: {abs(ray[2]):.4f} (should be non-zero)")
    print(f"  Using index 1 for height solve: {abs(ray[1]):.4f}")
    
    # Debug: Check camera_pose structure
    print(f"\nCamera pose matrix (first 3 rows):")
    print(f"  Translation (camera in robot frame?): {camera_pose[:3, 3]}")
    print(f"  Rotation matrix (first row): {camera_pose[0, :3]}")
    print(f"  Rotation matrix (second row): {camera_pose[1, :3]}")
    print(f"  Rotation matrix (third row): {camera_pose[2, :3]}")
    print(f"  Camera pose inverse translation: {np.linalg.inv(camera_pose)[:3, 3]}")
    
    # Verify height solving: intersect the ray with the plane where index 2 equals
    # the predicted height, i.e. solve cam_pos[2] + t * ray[2] == height for t.
    if abs(ray[2]) > 1e-6:
        t = (heights_pred_denorm[0] - cam_pos[2]) / ray[2]
        reconstructed_3d = cam_pos + t * ray
        print(f"\nReconstructed 3D (using index 2): {reconstructed_3d}")
        print(f"3D error (index 2): {np.linalg.norm(reconstructed_3d - trajectory_pred_3d[0]):.6f}m")
    
    # Try using index 1 instead (same intersection, treating index 1 as height)
    if abs(ray[1]) > 1e-6:
        t_alt = (heights_pred_denorm[0] - cam_pos[1]) / ray[1]
        reconstructed_3d_alt = cam_pos + t_alt * ray
        print(f"Reconstructed 3D (using index 1): {reconstructed_3d_alt}")
        print(f"3D error (index 1): {np.linalg.norm(reconstructed_3d_alt - trajectory_pred_3d[0]):.6f}m")

# Optional: show the first three PCA components of the DINO features as an RGB image.
if args.vis_dino:
    print("Visualizing DINO features...")
    # Squash the first three components into (0, 1) with sigmoid(2x), then
    # upsample bilinearly to the loaded image size.
    pca_rgb = pca_features_patches[:, :, :3]
    pca_rgb_normalized = torch.sigmoid(torch.from_numpy(pca_rgb) * 2.0).numpy()
    pca_rgb_chw = torch.from_numpy(pca_rgb_normalized).permute(2, 0, 1).float()
    pca_rgb_upsampled = TF.resize(
        pca_rgb_chw, (H_loaded, W_loaded), interpolation=TF.InterpolationMode.BILINEAR
    ).permute(1, 2, 0).numpy()

    # 50/50 blend of the image (scaled to [0, 1]) and the PCA colours.
    img_normalized = rgb.astype(float) / 255.0
    overlay = img_normalized * 0.5 + pca_rgb_upsampled * 0.5

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    _dino_panels = (
        (rgb, 'Original Image'),
        (pca_rgb_upsampled, 'DINO PCA (first 3 components)'),
        (overlay, 'Overlay: RGB + DINO'),
    )
    for _panel_ax, (_panel_img, _panel_title) in zip(axes, _dino_panels):
        _panel_ax.imshow(_panel_img)
        _panel_ax.set_title(_panel_title)
        _panel_ax.axis('off')

    plt.tight_layout()
    plt.show()

# 2D visualization (like vis_eval_2d.py)
if args.vis_2d:
    print("Creating 2D visualization...")

    def _draw_trajectory(ax, points, marker, line_kw, marker_kw):
        """Plot ``points`` (N, 2) as a red polyline plus one plasma-coloured marker per point."""
        ax.plot(points[:, 0], points[:, 1], 'r-', **line_kw)
        for i, (x, y) in enumerate(points):
            ax.plot(x, y, marker, color=plt.cm.plasma(i / len(points)), **marker_kw)

    # All image panels are drawn on RES_LOW x RES_LOW square canvases.
    rgb_lowres = cv2.resize(rgb, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

    # DINO vis: first three PCA components, min-max normalized per channel.
    # .copy() matters: Tensor.numpy() shares memory with dino_features (and
    # therefore with dino_tokens_flat, the model input). The in-place
    # normalization below would otherwise overwrite those features.
    dino_vis = dino_features[:, :, :3].numpy().copy()
    for i in range(3):
        channel = dino_vis[:, :, i]
        min_val, max_val = channel.min(), channel.max()
        if max_val > min_val:
            dino_vis[:, :, i] = (channel - min_val) / (max_val - min_val)
        else:
            dino_vis[:, :, i] = 0.5
    dino_vis = np.clip(dino_vis, 0, 1)
    dino_vis_upscaled = cv2.resize(dino_vis, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

    # pred_image_coords and current_kp_2d_image are in H_orig x W_orig pixels;
    # rescale them to the low-res canvas and to patch coordinates.
    predicted_trajectory_lowres = rescale_coords(pred_image_coords, H_orig, W_orig, RES_LOW, RES_LOW)
    predicted_trajectory_patches = rescale_coords(pred_image_coords, H_orig, W_orig, h_patches, w_patches)
    current_kp_2d_lowres = rescale_coords(current_kp_2d_image.reshape(1, 2), H_orig, W_orig, RES_LOW, RES_LOW)[0]

    # Layout: four image panels on top, height chart spanning the bottom row.
    fig = plt.figure(figsize=(24, 10))
    gs = GridSpec(2, 4, figure=fig, hspace=0.3, wspace=0.3, height_ratios=[1, 0.4])
    ax_rgb = fig.add_subplot(gs[0, 0])
    ax_dino = fig.add_subplot(gs[0, 1])
    ax_attn = fig.add_subplot(gs[0, 2])
    ax_traj = fig.add_subplot(gs[0, 3])
    ax_height = fig.add_subplot(gs[1, :])

    # RGB with trajectory
    ax_rgb.imshow(rgb_lowres)
    if len(predicted_trajectory_lowres) > 0:
        _draw_trajectory(ax_rgb, predicted_trajectory_lowres, 'x',
                         dict(linewidth=2, alpha=0.7, label='Pred Trajectory'),
                         dict(markersize=6, markeredgewidth=1))
    ax_rgb.plot(current_kp_2d_lowres[0], current_kp_2d_lowres[1],
                'go', markersize=8, markeredgecolor='white', markeredgewidth=1, label='Current EEF', zorder=10)
    ax_rgb.set_title("RGB with Predicted Trajectory")
    ax_rgb.legend(loc='upper right', fontsize=9)
    ax_rgb.axis('off')

    # DINO with trajectory (patch coords scaled up to the low-res canvas)
    ax_dino.imshow(dino_vis_upscaled)
    if len(predicted_trajectory_patches) > 0:
        pred_patches_lowres = predicted_trajectory_patches * (RES_LOW / np.array([w_patches, h_patches]))
        _draw_trajectory(ax_dino, pred_patches_lowres, 'x',
                         dict(linewidth=2, alpha=0.7, label='Pred Trajectory'),
                         dict(markersize=6, markeredgewidth=1))
    ax_dino.plot(current_kp_2d_patches[0] * (RES_LOW / w_patches), current_kp_2d_patches[1] * (RES_LOW / h_patches),
                 'go', markersize=8, markeredgecolor='white', markeredgewidth=1, label='Current EEF', zorder=10)
    ax_dino.set_title("DINO Features with Predicted Trajectory")
    ax_dino.legend(loc='upper right', fontsize=9)
    ax_dino.axis('off')

    # Attention map (first timestep), min-max normalized for display
    attention_map = attention_scores[0].reshape(h_patches, w_patches)
    attention_map_norm = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
    attention_map_upscaled = cv2.resize(attention_map_norm, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)
    ax_attn.imshow(attention_map_upscaled, cmap='hot', vmin=0, vmax=1)
    ax_attn.set_title("Attention Map (t=0)")
    ax_attn.axis('off')

    # Trajectory overview
    ax_traj.imshow(rgb_lowres)
    if len(predicted_trajectory_lowres) > 0:
        _draw_trajectory(ax_traj, predicted_trajectory_lowres, 'o',
                         dict(linewidth=3, alpha=0.8, label='Predicted'),
                         dict(markersize=8, markeredgecolor='white', markeredgewidth=1))
    ax_traj.plot(current_kp_2d_lowres[0], current_kp_2d_lowres[1],
                 'go', markersize=10, markeredgecolor='white', markeredgewidth=2, label='Start EEF', zorder=10)
    ax_traj.set_title("Full Predicted Trajectory")
    ax_traj.legend(loc='upper right', fontsize=9)
    ax_traj.axis('off')

    # Height chart (timesteps are 1-based: the first prediction is for t+1)
    timesteps = np.arange(1, len(heights_pred_denorm) + 1)
    ax_height.bar(timesteps - 0.2, heights_pred_denorm, width=0.4, alpha=0.6, color='red', label='Pred Height')
    ax_height.set_xlabel('Timestep', fontsize=10)
    ax_height.set_ylabel('Height (m)', fontsize=10)
    ax_height.set_title('Predicted Height Trajectory', fontsize=12)
    ax_height.legend(fontsize=9)
    ax_height.grid(alpha=0.3)

    plt.tight_layout()
    output_path = Path(f'groundplane_testing/test_scripts/image_inference_2d_{image_path.stem}.png')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved 2D visualization to {output_path}")
    plt.show()

# MuJoCo visualization (like test_ik.py): solve IK along the predicted keypoint
# trajectory and render the robot from the estimated camera pose at each step.
if args.vis_mujoco:
    print("Creating MuJoCo visualization...")
    import xml.etree.ElementTree as ET
    import mink
    from exo_utils import get_link_poses_from_robot
    from utils import ik_to_keypoint_and_rotation

    image_stem = image_path.stem
    # Frame-indexed companion files (GT gripper poses, joint states) can only be
    # located when the image is named by its (zero-padded) frame number.
    current_frame = int(image_stem) if image_stem.isdigit() else None

    # Load GT trajectory 3D keypoints if available (for comparison): one keypoint
    # per subsequent frame, stopping at the first frame with no saved gripper pose.
    trajectory_gt_3d = []
    if current_frame is not None:
        try:
            kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]  # loop-invariant
            for offset in range(1, min(MAX_TIMESTEPS + 1, 100)):  # Limit to reasonable number
                gt_pose_path = image_dir / f"{current_frame + offset:06d}_gripper_pose.npy"
                if not gt_pose_path.exists():
                    break
                gt_pose = np.load(gt_pose_path)
                # Transform the gripper-local keypoint into world coordinates.
                trajectory_gt_3d.append(gt_pose[:3, :3] @ kp_local + gt_pose[:3, 3])
        except (OSError, ValueError, IndexError, KeyError) as e:
            # Best effort: the GT overlay is optional, so keep whatever loaded.
            print(f"  (Could not load GT trajectory: {e})")

    # Add predicted and GT trajectory spheres to XML
    xml_root = ET.fromstring(robot_config.xml)
    worldbody = xml_root.find('worldbody')

    # GT trajectory: starts green (0, 1, 0) and fades halfway toward yellow,
    # ending at (0.5, 0.5, 0).
    n_gt = len(trajectory_gt_3d)
    for i, kp_pos in enumerate(trajectory_gt_3d):
        frac = i / max(n_gt - 1, 1)
        red = frac * 0.5
        green = 1.0 - frac * 0.5
        ET.SubElement(worldbody, 'site', {
            'name': f'gt_kp_{i}', 'type': 'sphere', 'size': '0.015',
            'pos': f'{kp_pos[0]} {kp_pos[1]} {kp_pos[2]}', 'rgba': f'{red} {green} 0 0.8'
        })
    if n_gt > 0:
        print(f"  Added {n_gt} GT trajectory spheres")

    # Predicted trajectory: starts red (1, 0, 0) and fades halfway toward yellow,
    # ending at (0.5, 0.5, 0) -- the same end color as the GT gradient.
    n_pred = len(trajectory_pred_3d)
    for i, kp_pos in enumerate(trajectory_pred_3d):
        frac = i / max(n_pred - 1, 1)
        red = 1.0 - frac * 0.5
        green = frac * 0.5
        ET.SubElement(worldbody, 'site', {
            'name': f'pred_kp_{i}', 'type': 'sphere', 'size': '0.015',
            'pos': f'{kp_pos[0]} {kp_pos[1]} {kp_pos[2]}', 'rgba': f'{red} {green} 0 0.8'
        })
    print(f"  Added {n_pred} predicted trajectory spheres")

    # Debug: print the first few predicted vs GT 3D positions. Index i pairs
    # prediction step i+1 with the GT pose from frame offset i+1.
    if n_pred > 0 and n_gt > 0:
        print(f"\n  First few predicted vs GT 3D positions:")
        for i in range(min(3, n_pred, n_gt)):
            pred_pos = trajectory_pred_3d[i]
            gt_pos = trajectory_gt_3d[i]
            error = np.linalg.norm(pred_pos - gt_pos)
            print(f"    t={i+1}: Pred=[{pred_pos[0]:.4f}, {pred_pos[1]:.4f}, {pred_pos[2]:.4f}], "
                  f"GT=[{gt_pos[0]:.4f}, {gt_pos[1]:.4f}, {gt_pos[2]:.4f}], Error={error:.4f}m")

    mj_model_viz = mujoco.MjModel.from_xml_string(ET.tostring(xml_root, encoding='unicode'))
    mj_data_viz = mujoco.MjData(mj_model_viz)

    # Initial joint state. Preference order: this frame's saved joint state, then
    # the first frame's, then the state estimated from the image (mj_data).
    loaded_saved_state = False
    if current_frame is None:
        print("Could not load saved joint state (Image stem is not a digit), using estimated state")
    else:
        frame_joint_state = image_dir / f"{current_frame:06d}.npy"
        first_frame_joint_state = image_dir / "000000.npy"
        if frame_joint_state.exists():
            joint_state_path = frame_joint_state
            source_msg = f"Using saved joint state from {frame_joint_state.name}"
        elif first_frame_joint_state.exists():
            joint_state_path = first_frame_joint_state
            source_msg = f"Using saved joint state from first frame: {first_frame_joint_state.name}"
        else:
            joint_state_path = None
            print("No saved joint state found, using estimated state")

        if joint_state_path is not None:
            try:
                joint_state = np.load(joint_state_path)
                print(source_msg)
                mj_data_viz.qpos[:] = joint_state
                mj_data_viz.ctrl[:] = joint_state[:len(mj_data_viz.ctrl)]
                loaded_saved_state = True
            except (OSError, ValueError) as e:
                # Unreadable file or a shape mismatch with the model's qpos/ctrl.
                print(f"Could not load saved joint state ({e}), using estimated state")

    if not loaded_saved_state:
        mj_data_viz.qpos[:] = mj_data.qpos
        mj_data_viz.ctrl[:] = mj_data.ctrl

    mujoco.mj_forward(mj_model_viz, mj_data_viz)
    position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses)
    mujoco.mj_forward(mj_model_viz, mj_data_viz)

    # Setup IK configuration
    ik_configuration = mink.Configuration(mj_model_viz)
    ik_configuration.update(mj_data_viz.qpos)

    # Hardcoded median rotation (like baseline_training/test_scripts/test_ik.py),
    # used as the gripper orientation target for every trajectory point.
    median_dataset_rotation = np.array([[-0.99912433, -0.03007201, -0.02909046],
                                        [-0.04176828,  0.67620482,  0.73552869],
                                        [-0.00244771,  0.73609967, -0.67686874]])
    target_gripper_rot = median_dataset_rotation

    # mj_name2id returns -1 for an unknown name, and xpos[-1] would silently read
    # the last body, so fail loudly instead.
    kp_body_id = mujoco.mj_name2id(mj_model_viz, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
    if kp_body_id < 0:
        raise ValueError("Body 'virtual_gripper_keypoint' not found in the MuJoCo model")

    # The displayed RGB is identical for every timestep, so resize it once.
    rgb_resized = rgb if rgb.shape[:2] == (H_loaded, W_loaded) else cv2.resize(rgb, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)

    # NOTE(review): rendering passes (H_orig, W_orig) together with cam_K. An
    # earlier comment claimed cam_K is calibrated for (H_loaded, W_loaded);
    # confirm which resolution cam_K actually matches.
    print("Rendering robot through predicted trajectory from camera pose")

    for i, target_kp_pos in enumerate(trajectory_pred_3d):
        # Update IK configuration from current state before solving
        link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
        position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
        mujoco.mj_forward(mj_model_viz, mj_data_viz)
        ik_configuration.update(mj_data_viz.qpos)

        # Solve IK with median rotation constraint (like baseline_training/test_scripts/test_ik.py)
        ik_to_keypoint_and_rotation(target_kp_pos, target_gripper_rot, ik_configuration, robot_config, mj_model_viz, mj_data_viz)

        # Update link poses and forward kinematics after IK
        link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
        position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
        mujoco.mj_forward(mj_model_viz, mj_data_viz)

        # Verify IK result
        achieved_kp_pos = mj_data_viz.xpos[kp_body_id].copy()
        ik_error = np.linalg.norm(achieved_kp_pos - target_kp_pos)
        if i < 3:  # Print first few for debugging
            print(f"  t={i+1}: Target=[{target_kp_pos[0]:.4f}, {target_kp_pos[1]:.4f}, {target_kp_pos[2]:.4f}], "
                  f"Achieved=[{achieved_kp_pos[0]:.4f}, {achieved_kp_pos[1]:.4f}, {achieved_kp_pos[2]:.4f}], "
                  f"IK Error={ik_error:.4f}m")

        # Render at (H_orig, W_orig), then resize to the loaded image size for display.
        rendered = render_from_camera_pose(mj_model_viz, mj_data_viz, camera_pose, cam_K, H_orig, W_orig)
        rendered_resized = cv2.resize(rendered, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)
        overlay = (rgb_resized * 0.5 + rendered_resized * 0.5).astype(np.uint8)

        # Display results
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, img in zip(axes, [rgb_resized, rendered_resized, overlay]):
            ax.imshow(img)
            ax.axis('off')
        axes[0].set_title('Original RGB')
        axes[1].set_title(f'Rendered Robot (t={i+1})')
        axes[2].set_title('Overlay')
        plt.tight_layout()
        plt.show()
        # Release the figure. Non-interactive backends do not close it in show(),
        # so one figure per timestep would otherwise accumulate.
        plt.close(fig)

print("✓ Inference complete!")
