"""eval.py — Evaluate a trained PARA checkpoint in the LIBERO simulation environment.

PARA predicts the next N_WINDOW absolute EEF 3D positions from a single RGB image.
This script runs closed-loop rollouts: at each env step, re-run the model on the
current observation, decode the first predicted position into a delta OSC_POSE action,
and step the sim. Success is the LIBERO binary predicate check (env returns done=True).

Usage:
    python libero/eval.py \
        --checkpoint libero/checkpoints/para_libero_spatial_t0/best.pth \
        --benchmark libero_spatial \
        --task_id 0 \
        --n_episodes 20

Action format: 7D OSC_POSE [delta_pos (3), delta_rot (3), gripper (1)]; delta_rot is zero when the model has no rotation head
"""

import argparse
import json
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

sys.path.insert(0, os.path.dirname(__file__))

import model as model_module
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS, PRED_SIZE
from utils import recover_3d_from_direct_keypoint_and_height


def _reposition_camera(sim, camera_name, theta_deg, phi_deg):
    """Orbit a camera around the point where its optical axis meets the table plane.

    The camera's current viewing ray is intersected with the plane z = 0.90 to
    obtain a look-at point. The camera is then placed on the sphere centred on
    that point (radius = current distance): ``theta_deg`` tilts it away from the
    default direction and ``phi_deg`` selects the tilt direction within the
    plane orthogonal to it. With theta = 0 the offset direction equals the
    default direction, so the position is reproduced. The camera is re-aimed at
    the look-at point and the result is written to ``sim.model.cam_pos`` /
    ``sim.model.cam_quat``, followed by ``sim.forward()``.
    """
    from scipy.spatial.transform import Rotation as ScipyR

    table_z = 0.90
    world_z = np.array([0, 0, 1.0])

    cam_id = sim.model.camera_name2id(camera_name)
    origin = sim.data.cam_xpos[cam_id].copy()
    # Viewing direction is the negated third column of the camera rotation matrix.
    view_dir = -sim.data.cam_xmat[cam_id].reshape(3, 3)[:, 2]

    # Ray / table-plane intersection gives the look-at target.
    ray_len = (table_z - origin[2]) / (view_dir[2] + 1e-8)
    target = origin + ray_len * view_dir

    # Orthonormal frame (side, cap_up, axis_dir) spanning the spherical cap.
    radius = np.linalg.norm(origin - target)
    axis_dir = (origin - target) / radius
    # Swap the helper axis when the default direction is nearly vertical.
    helper_up = np.array([1, 0, 0.0]) if abs(np.dot(axis_dir, world_z)) > 0.99 else world_z
    side = np.cross(axis_dir, helper_up)
    side /= np.linalg.norm(side)
    cap_up = np.cross(side, axis_dir)

    theta = np.radians(theta_deg)
    phi = np.radians(phi_deg)
    direction = (np.sin(theta) * np.cos(phi) * side +
                 np.sin(theta) * np.sin(phi) * cap_up +
                 np.cos(theta) * axis_dir)
    new_pos = target + radius * direction

    # Build the look-at rotation: the camera z axis points away from the target.
    to_target = target - new_pos
    to_target = to_target / (np.linalg.norm(to_target) + 1e-12)
    z_axis = -to_target
    if abs(np.dot(to_target, np.array([0.0, 0.0, 1.0]))) > 0.99:
        hint = np.array([0.0, 1.0, 0.0])
    else:
        hint = np.array([0.0, 0.0, 1.0])
    x_axis = np.cross(hint, z_axis)
    x_axis = x_axis / (np.linalg.norm(x_axis) + 1e-12)
    y_axis = np.cross(z_axis, x_axis)
    rot = np.column_stack((x_axis, y_axis, z_axis))

    quat_xyzw = ScipyR.from_matrix(rot).as_quat()  # scipy returns (x, y, z, w)
    quat_wxyz = np.roll(quat_xyzw, 1)              # reorder to (w, x, y, z) for MuJoCo

    sim.model.cam_pos[cam_id] = new_pos
    sim.model.cam_quat[cam_id] = quat_wxyz
    sim.forward()


def get_model_class(model_type):
    """Return the predictor class registered under ``model_type``.

    Every class except the default "para" predictor is imported lazily, so only
    the module backing the requested model type is loaded.

    Raises:
        ValueError: if ``model_type`` is not one of the known names.
    """
    def _para():
        return TrajectoryHeatmapPredictor

    def _act():
        from model_act import ACTPredictor
        return ACTPredictor

    def _da3():
        from model_da3 import DA3Predictor
        return DA3Predictor

    def _moge():
        from model_moge import MoGePredictor
        return MoGePredictor

    def _dino_vla():
        from model_dino_vla import DinoVLAPredictor
        return DinoVLAPredictor

    def _internvl():
        from model_vla_internvl import InternVLAPredictor
        return InternVLAPredictor

    def _internvl_act():
        from model_vla_internvl_act import InternVLACTPredictor
        return InternVLACTPredictor

    def _dual_da3():
        from model_dual_da3 import DualDA3Predictor
        return DualDA3Predictor

    def _dual_para():
        from model_dual_para import DualParaPredictor
        return DualParaPredictor

    def _cost_volume():
        from model_cost_volume import CostVolumePredictor
        return CostVolumePredictor

    # Ordered (name, loader) pairs; compared with == so any model_type value is accepted.
    loaders = (
        ("para", _para),
        ("act", _act),
        ("da3", _da3),
        ("moge", _moge),
        ("dino_vla", _dino_vla),
        ("internvl", _internvl),
        ("internvl_act", _internvl_act),
        ("dual_da3", _dual_da3),
        ("dual_para", _dual_para),
        ("cost_volume", _cost_volume),
    )
    for known_name, load in loaders:
        if model_type == known_name:
            return load()
    raise ValueError(f"Unknown model_type: {model_type}")

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    project_points_from_world_to_camera,
)

# Per-channel RGB mean / std applied by preprocess_obs to images already scaled to [0, 1].
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Square side length in pixels: default `image_size` for the helpers below, and used by
# run_eval for the env camera resolution and the model's target_size.
IMAGE_SIZE = 448


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def preprocess_obs(rgb_obs, image_size=IMAGE_SIZE):
    """Turn an HxWx3 uint8 observation into a (1, 3, image_size, image_size) float tensor.

    Steps, in order: scale to [0, 1], flip vertically (the training images are
    flipud(obs), so the same flip is applied here), bilinear resize to
    image_size x image_size, then normalise with IMAGENET_MEAN / IMAGENET_STD.
    """
    # Vertical flip to match the training image convention.
    flipped = np.flipud(rgb_obs).astype(np.float32) / 255.0
    resized = cv2.resize(flipped, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    normalised = (resized - IMAGENET_MEAN) / IMAGENET_STD
    # HWC → CHW, then add the batch dimension.
    chw = torch.from_numpy(normalised).permute(2, 0, 1).float()
    return chw[None]


def get_camera_params(sim, camera_name, image_size=IMAGE_SIZE):
    """Return the camera matrices needed for projection and 3D recovery.

    Args:
        sim:         simulator handle passed straight to the robosuite camera utilities.
        camera_name: name of the camera to query.
        image_size:  square image side length (pixels) the matrices are computed for.

    Returns:
        world_to_camera: (4,4) world→camera transform  → for project_points_from_world_to_camera
        camera_pose:     (4,4) camera→world transform   → for recover_3d (ray unprojection)
        cam_K:           (3,3) intrinsic at image_size  → for recover_3d

    world_to_camera and camera_pose are two different matrices; using the wrong
    one for 3D recovery gives bad targets.
    """
    world_to_camera = get_camera_transform_matrix(sim, camera_name, image_size, image_size)
    camera_pose = get_camera_extrinsic_matrix(sim, camera_name)  # camera→world
    # The intrinsic matrix is already expressed at image_size; the previous
    # divide-by-size / multiply-by-size round trip was a no-op apart from
    # floating-point rounding, so it is returned as-is.
    cam_K = get_camera_intrinsic_matrix(sim, camera_name, image_size, image_size)
    return world_to_camera, camera_pose, cam_K


def eef_to_start_kp(eef_pos, world_to_camera, image_size=IMAGE_SIZE):
    """Project the EEF world position to a (u, v) pixel and return it as a float32 tensor.

    The projection helper's first output component is treated as the row and
    the second as the column, so the result is ordered [col, row] = [u, v].
    """
    world_pt = eef_pos.reshape(1, 3).astype(np.float64)
    projected = project_points_from_world_to_camera(
        points=world_pt,
        world_to_camera_transform=world_to_camera,
        camera_height=image_size,
        camera_width=image_size,
    )
    first = projected[0]
    row = float(first[0])
    col = float(first[1])
    return torch.tensor([col, row], dtype=torch.float32)


def render_eval_frame(rgb_obs, volume_logits, current_eef_pos,
                      world_to_camera, cam_K, image_size=IMAGE_SIZE, step_idx=0, success=None):
    """Render a single eval step: RGB + heatmap overlay + predicted pixel + GT EEF dot.

    Args:
        rgb_obs:         (H, W, 3) uint8 observation; flipped vertically for display.
        volume_logits:   tensor indexed as [0, 0] → (Nh, pred_size, pred_size); only the
                         first predicted timestep is visualised.
        current_eef_pos: (3,) current EEF world position (projected and drawn as a white dot).
        world_to_camera: world→camera transform for project_points_from_world_to_camera.
        cam_K:           unused here; kept so existing call sites keep working.
        image_size:      side length (pixels) of the rendered square image.
        step_idx:        step number written into the label.
        success:         None → no status text; truthy → "SUCCESS"; falsy → "running".

    Returns:
        (image_size, image_size, 3) uint8 RGB image.
    """
    pred_size = volume_logits.shape[-1]
    scale = image_size / pred_size

    # Preprocess frame for display (flipud to match training convention)
    frame = rgb_obs.astype(np.float32) / 255.0
    frame = np.flipud(frame).copy()
    frame = cv2.resize(frame, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # Heatmap from first predicted timestep: softmax over the whole volume,
    # then max over the height axis to get a 2D map.
    vol_t = volume_logits[0, 0]                                      # (Nh, pred_size, pred_size)
    vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape)
    heat_small = vol_probs.max(dim=0)[0].cpu().numpy()               # (pred_size, pred_size)
    heat = cv2.resize(heat_small, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    # Min–max normalise to [0, 1]. The denominator must be the value range
    # (max - min); dividing by max alone leaves the peak below 1 whenever min > 0.
    heat_min = heat.min()
    heat = (heat - heat_min) / (heat.max() - heat_min + 1e-8)
    heat_rgb = np.zeros_like(frame)
    heat_rgb[..., 0] = heat                                          # red channel only
    overlay = np.clip(frame * 0.55 + heat_rgb * 0.45, 0, 1)
    vis = (overlay * 255.0).astype(np.uint8)

    # Predicted pixel (green crosshair) at the centre of the arg-max heatmap cell
    flat_idx = heat_small.argmax()
    py, px = flat_idx // pred_size, flat_idx % pred_size
    px_full = int((px + 0.5) * scale)
    py_full = int((py + 0.5) * scale)
    cv2.drawMarker(vis, (px_full, py_full), (0, 255, 0), cv2.MARKER_CROSS, 18, 2, cv2.LINE_AA)

    # GT EEF projection (white dot); the helper's output is read as (row, col)
    pix_rc = project_points_from_world_to_camera(
        current_eef_pos.reshape(1, 3).astype(np.float64),
        world_to_camera, image_size, image_size,
    )[0]
    u, v = int(round(float(pix_rc[1]))), int(round(float(pix_rc[0])))
    if 0 <= u < image_size and 0 <= v < image_size:
        cv2.circle(vis, (u, v), 6, (255, 255, 255), -1)

    # Step counter + success indicator (white outline under dark text for legibility)
    label = f"step {step_idx}"
    if success is not None:
        label += "  SUCCESS" if success else "  running"
    cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (20, 20, 20), 1, cv2.LINE_AA)

    return vis  # (H, W, 3) uint8 RGB


def render_window_strip(rgb_obs, volume_logits, current_eef_pos,
                        world_to_camera, cam_K, image_size=IMAGE_SIZE, step_idx=0):
    """Render a horizontal strip showing heatmaps for all N_WINDOW predicted timesteps.

    Each tile: RGB + heatmap overlay (red) + predicted pixel (green cross) + GT EEF (white dot).

    Args:
        rgb_obs:         (H, W, 3) uint8 observation; flipped vertically for display.
        volume_logits:   tensor indexed as [0, t] → (Nh, pred_size, pred_size) per timestep;
                         the number of tiles is volume_logits.shape[1].
        current_eef_pos: (3,) current EEF world position (same dot on every tile).
        world_to_camera: world→camera transform for project_points_from_world_to_camera.
        cam_K:           unused here; kept so existing call sites keep working.
        image_size:      side length (pixels) of each square tile.
        step_idx:        step number written into each tile label.

    Returns:
        A single wide image: (image_size, image_size * n_window, 3) uint8 RGB.
    """
    n_window = volume_logits.shape[1]
    pred_size = volume_logits.shape[-1]
    scale = image_size / pred_size

    # Preprocess frame for display (flipud to match training convention)
    frame = rgb_obs.astype(np.float32) / 255.0
    frame = np.flipud(frame).copy()
    frame = cv2.resize(frame, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # GT EEF pixel (same for all timesteps — current position)
    pix_rc = project_points_from_world_to_camera(
        current_eef_pos.reshape(1, 3).astype(np.float64),
        world_to_camera, image_size, image_size,
    )[0]
    eef_u, eef_v = int(round(float(pix_rc[1]))), int(round(float(pix_rc[0])))

    tiles = []
    for t in range(n_window):
        # Softmax over the whole volume, then max over the height axis → 2D map.
        vol_t = volume_logits[0, t]  # (Nh, pred_size, pred_size)
        vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape)
        heat_small = vol_probs.max(dim=0)[0].cpu().numpy()
        heat = cv2.resize(heat_small, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
        # Min–max normalise to [0, 1]. The denominator must be the value range
        # (max - min); dividing by max alone leaves the peak below 1 whenever min > 0.
        heat_min = heat.min()
        heat = (heat - heat_min) / (heat.max() - heat_min + 1e-8)

        heat_rgb = np.zeros_like(frame)
        heat_rgb[..., 0] = heat  # red channel only
        overlay = np.clip(frame * 0.55 + heat_rgb * 0.45, 0, 1)
        vis = (overlay * 255.0).astype(np.uint8)

        # Predicted pixel (green crosshair) at the centre of the arg-max heatmap cell
        flat_idx = heat_small.argmax()
        py, px = flat_idx // pred_size, flat_idx % pred_size
        px_full = int((px + 0.5) * scale)
        py_full = int((py + 0.5) * scale)
        cv2.drawMarker(vis, (px_full, py_full), (0, 255, 0), cv2.MARKER_CROSS, 14, 2, cv2.LINE_AA)

        # GT EEF (white dot)
        if 0 <= eef_u < image_size and 0 <= eef_v < image_size:
            cv2.circle(vis, (eef_u, eef_v), 5, (255, 255, 255), -1)

        # Timestep label (white outline under dark text for legibility)
        label = f"step {step_idx} t+{t}"
        cv2.putText(vis, label, (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(vis, label, (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (20, 20, 20), 1, cv2.LINE_AA)
        tiles.append(vis)

    return np.concatenate(tiles, axis=1)  # (H, W*n_window, 3)


def decode_window_actions(volume_logits, model, feats, camera_pose, cam_K,
                          current_eef_pos, current_eef_quat, image_size=IMAGE_SIZE, max_delta=0.05):
    """Decode all N_WINDOW predicted timesteps into OSC_POSE delta actions.

    For every timestep the arg-max pixel of the height-maxed volume and the
    arg-max height bin at that pixel are turned into a 3D target. Gripper and
    rotation are predicted by indexing features at those pixels and passing
    through the model's heads via ``model.predict_at_pixels``.

    Every delta is computed relative to ``current_eef_pos`` / ``current_eef_quat``
    (the pose at the start of the window), not relative to the previous target.

    Args:
        volume_logits:    (1, N_WINDOW, N_HEIGHT_BINS, pred_size, pred_size)
        model:            predictor exposing predict_at_pixels
        feats:            (1, D, pred_size, pred_size) feature map from forward()
        camera_pose:      (4,4) camera→world extrinsic  ← get_camera_extrinsic_matrix()
        cam_K:            (3,3) intrinsic at image_size
        current_eef_pos:  (3,) numpy EEF position at start of window
        current_eef_quat: (4,) EEF orientation, passed to scipy Rotation.from_quat
                          (scalar-last x, y, z, w)
        image_size:       side length of the full-resolution image the pixels are scaled to
        max_delta:        max position delta magnitude in metres before OSC normalisation

    Returns:
        actions:         list of N_WINDOW (7,) numpy [delta_pos(3), delta_rot_axisangle(3), gripper(1)]
        pred_3d_targets: list of N_WINDOW (3,) absolute EEF targets (for debug)
    """
    from scipy.spatial.transform import Rotation as ScipyR
    OSC_POS_SCALE = 0.05  # robosuite OSC_POSE: input [-1,1] → output [-0.05, 0.05] m
    OSC_ROT_SCALE = 0.5   # robosuite OSC_POSE: input [-1,1] → output [-0.5, 0.5] rad

    n_window  = volume_logits.shape[1]
    pred_size = volume_logits.shape[-1]
    scale     = image_size / pred_size
    min_h, max_h = model_module.MIN_HEIGHT, model_module.MAX_HEIGHT
    min_r = np.array(model_module.MIN_ROT, dtype=np.float64)
    max_r = np.array(model_module.MAX_ROT, dtype=np.float64)

    # --- Pass 1: collect predicted pixel + height bin for each timestep ---
    pred_px_list = []  # list of (px, py) in pred_size space
    pred_hbin_list = []
    for t in range(n_window):
        vol_t      = volume_logits[0, t]          # (Nh, pred_size, pred_size)
        max_over_h = vol_t.max(dim=0)[0]
        flat_idx   = max_over_h.reshape(-1).argmax().item()
        py = flat_idx // pred_size
        px = flat_idx % pred_size
        h_bin = vol_t[:, py, px].argmax().item()
        pred_px_list.append((px, py))
        pred_hbin_list.append(h_bin)

    # --- Batch gripper/rotation prediction at all predicted pixels ---
    pred_pixels = torch.tensor(
        [[px, py] for px, py in pred_px_list], dtype=torch.float32, device=feats.device
    ).unsqueeze(0)  # (1, N_WINDOW, 2)
    pred_hbins = torch.tensor(pred_hbin_list, dtype=torch.long, device=feats.device).unsqueeze(0)  # (1, N_WINDOW)
    with torch.no_grad():
        # Only the cost-volume model takes the height bins; getattr keeps models
        # without a `model_type` attribute on the two-argument path.
        if hasattr(model, 'gripper_mlp') and getattr(model, 'model_type', None) == "cost_volume":
            gripper_logits, rotation_logits = model.predict_at_pixels(feats, pred_pixels, pred_hbins)
        else:
            gripper_logits, rotation_logits = model.predict_at_pixels(feats, pred_pixels)
    # gripper_logits:  indexed [0, t] and arg-maxed below, i.e. treated as per-timestep
    #                  class scores (presumably (1, N_WINDOW, 2) — TODO confirm against
    #                  predict_at_pixels)
    # rotation_logits: None, (1, N_WINDOW, 3) direct values, or (1, N_WINDOW, 3, N_ROT_BINS)

    # Rotations that do not depend on the timestep are built once.
    if rotation_logits is not None:
        R_ref = ScipyR.from_quat(np.array(model_module.REF_ROTATION_QUAT, dtype=np.float64))
        R_current = ScipyR.from_quat(current_eef_quat)

    # --- Pass 2: decode actions ---
    actions         = []
    pred_3d_targets = []
    ref_pos         = current_eef_pos.copy()

    for t, ((px, py), h_bin) in enumerate(zip(pred_px_list, pred_hbin_list)):
        # Cell centre in full-resolution pixel coordinates
        px_full = (px + 0.5) * scale
        py_full = (py + 0.5) * scale

        # Height bin (from pass 1) → metric height → 3D position
        height = (h_bin / max(N_HEIGHT_BINS - 1, 1)) * (max_h - min_h) + min_h
        pred_3d = recover_3d_from_direct_keypoint_and_height(
            np.array([px_full, py_full], dtype=np.float64), height, camera_pose, cam_K,
        )
        if pred_3d is None:
            # Recovery failed: reuse the previous target, or stay in place for the first step.
            pred_3d = pred_3d_targets[-1] if pred_3d_targets else ref_pos.copy()
        pred_3d_targets.append(pred_3d)

        # Delta position → clamp magnitude → normalised OSC input
        delta_pos = pred_3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

        # Rotation decode: predicted rotvec is applied on top of the reference
        # rotation, then expressed as a delta from the current orientation.
        if rotation_logits is None:
            delta_rot_norm = np.zeros(3, dtype=np.float64)
        else:
            if rotation_logits.dim() == 3:
                # Direct per-axis values in [0, 1] scaled to [min_r, max_r]
                rot_sigmoid = rotation_logits[0, t].cpu().numpy().astype(np.float64)
                delta_rotvec_pred = rot_sigmoid * (max_r - min_r) + min_r
            else:
                # Per-axis classification over N_ROT_BINS bins
                delta_rotvec_pred = np.array([
                    (rotation_logits[0, t, axis, :].argmax().item() / max(N_ROT_BINS - 1, 1))
                    * (max_r[axis] - min_r[axis]) + min_r[axis]
                    for axis in range(3)
                ])
            R_pred = R_ref * ScipyR.from_rotvec(delta_rotvec_pred)
            R_delta = R_pred * R_current.inv()
            delta_rot_norm = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # Gripper: 2-class argmax → class 0 (open, -1) or class 1 (close, +1)
        g_class = int(gripper_logits[0, t].argmax().item())
        gripper_cmd = 1.0 if g_class == 1 else -1.0

        action      = np.zeros(7, dtype=np.float32)
        action[:3]  = delta_norm
        action[3:6] = delta_rot_norm
        action[6]   = gripper_cmd
        actions.append(action)

    return actions, pred_3d_targets


def decode_dual_window_actions(agent_vol, wrist_vol, model, agent_feats, wrist_feats,
                                agent_cam_pose, agent_cam_K, wrist_cam_pose, wrist_cam_K,
                                wrist_w2c,
                                current_eef_pos, current_eef_quat, image_size=IMAGE_SIZE, max_delta=0.05):
    """Decode dual-camera predictions with agentview-guided wrist selection.

    Strategy: use agentview for coarse 3D prediction, then check if that 3D point
    projects into the wrist camera image (with a pixel margin). If yes, use the
    wrist view's prediction; if no, fall back to agentview.

    Args:
        agent_vol, wrist_vol:     per-view volumes indexed as [0, t] → (Nh, H, W)
        model:                    predictor exposing predict_at_pixels(feats, pixels, view_name=...)
        agent_feats, wrist_feats: per-view feature maps handed to predict_at_pixels
        agent_cam_pose, agent_cam_K, wrist_cam_pose, wrist_cam_K:
                                  camera→world pose and intrinsics for 3D recovery per view
        wrist_w2c:                world→camera transform of the wrist camera (for projection)
        current_eef_pos:          (3,) numpy EEF position; every delta is relative to it
        current_eef_quat:         (4,) EEF orientation for scipy Rotation.from_quat
        image_size:               side length of the full-resolution image
        max_delta:                max position delta magnitude (m) before OSC normalisation

    Returns:
        actions:         list of N_WINDOW (7,) numpy [delta_pos(3), delta_rot(3), gripper(1)]
        pred_3d_targets: list of N_WINDOW (3,) absolute EEF targets
    """
    from scipy.spatial.transform import Rotation as ScipyR
    OSC_POS_SCALE = 0.05
    OSC_ROT_SCALE = 0.5
    WRIST_MARGIN_PX = 20  # projected point must be this far inside the wrist image

    n_window  = agent_vol.shape[1]
    pred_size = agent_vol.shape[-1]
    scale     = image_size / pred_size
    min_h, max_h = model_module.MIN_HEIGHT, model_module.MAX_HEIGHT
    min_r = np.array(model_module.MIN_ROT, dtype=np.float64)
    max_r = np.array(model_module.MAX_ROT, dtype=np.float64)

    # Current orientation does not change within the window; build it once.
    R_current = ScipyR.from_quat(current_eef_quat)

    actions = []
    pred_3d_targets = []
    ref_pos = current_eef_pos.copy()

    for t in range(n_window):
        a_vol_t = agent_vol[0, t]  # (Nh, H, W)
        w_vol_t = wrist_vol[0, t]

        # Step 1: get coarse 3D from agentview
        a_max_over_h = a_vol_t.max(dim=0)[0]
        a_flat = a_max_over_h.reshape(-1).argmax().item()
        a_py = a_flat // pred_size
        a_px = a_flat % pred_size
        a_px_full = (a_px + 0.5) * scale
        a_py_full = (a_py + 0.5) * scale
        a_h_bin = a_vol_t[:, a_py, a_px].argmax().item()
        a_height = (a_h_bin / max(N_HEIGHT_BINS - 1, 1)) * (max_h - min_h) + min_h
        agent_3d = recover_3d_from_direct_keypoint_and_height(
            np.array([a_px_full, a_py_full], dtype=np.float64), a_height, agent_cam_pose, agent_cam_K
        )

        # Step 2: project agentview's 3D prediction onto wrist camera
        use_wrist = False
        if agent_3d is not None:
            wrist_pix_rc = project_points_from_world_to_camera(
                agent_3d.reshape(1, 3), wrist_w2c, image_size, image_size
            )[0]
            wu, wv = float(wrist_pix_rc[1]), float(wrist_pix_rc[0])
            if (WRIST_MARGIN_PX <= wu < image_size - WRIST_MARGIN_PX
                    and WRIST_MARGIN_PX <= wv < image_size - WRIST_MARGIN_PX):
                use_wrist = True

        # Step 3: decode from chosen view
        if use_wrist:
            vol_t = w_vol_t
            cam_pose = wrist_cam_pose
            cam_K = wrist_cam_K
            view_name = 'wrist'
            feats = wrist_feats
        else:
            vol_t = a_vol_t
            cam_pose = agent_cam_pose
            cam_K = agent_cam_K
            view_name = 'agent'
            feats = agent_feats

        # 2D prediction from chosen view
        max_over_h = vol_t.max(dim=0)[0]
        flat_idx = max_over_h.reshape(-1).argmax().item()
        py = flat_idx // pred_size
        px = flat_idx % pred_size
        px_full = (px + 0.5) * scale
        py_full = (py + 0.5) * scale

        # Height → 3D
        h_bin = vol_t[:, py, px].argmax().item()
        height = (h_bin / max(N_HEIGHT_BINS - 1, 1)) * (max_h - min_h) + min_h
        pred_3d = recover_3d_from_direct_keypoint_and_height(
            np.array([px_full, py_full], dtype=np.float64), height, cam_pose, cam_K
        )
        if pred_3d is None:
            # Fallback order: agentview target → previous target → stay in place.
            pred_3d = agent_3d if agent_3d is not None else (pred_3d_targets[-1] if pred_3d_targets else ref_pos.copy())
        pred_3d_targets.append(pred_3d)

        # Delta position → clamp magnitude → normalised OSC input
        delta_pos = pred_3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

        # Rotation + gripper from chosen view
        pred_pixels = torch.tensor([[px, py]], dtype=torch.float32, device=feats.device).unsqueeze(0)
        with torch.no_grad():
            grip_logits, rot_logits = model.predict_at_pixels(feats, pred_pixels, view_name=view_name)

        if rot_logits is None:
            # No rotation head: command zero rotation, as decode_window_actions does.
            delta_rot_norm = np.zeros(3, dtype=np.float64)
        else:
            # NOTE(review): the bins are decoded here as absolute 'xyz' Euler angles,
            # whereas decode_window_actions decodes them as a rotvec relative to
            # REF_ROTATION_QUAT. Confirm which convention the dual models were trained with.
            euler_pred = np.array([
                (rot_logits[0, 0, axis, :].argmax().item() / max(N_ROT_BINS - 1, 1))
                * (max_r[axis] - min_r[axis]) + min_r[axis]
                for axis in range(3)
            ])
            R_pred = ScipyR.from_euler('xyz', euler_pred)
            R_delta = R_pred * R_current.inv()
            delta_rot_norm = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # Gripper: argmax class 1 → close (+1), otherwise open (-1)
        g_class = int(grip_logits[0, 0].argmax().item())
        gripper_cmd = 1.0 if g_class == 1 else -1.0

        action = np.zeros(7, dtype=np.float32)
        action[:3] = delta_norm
        action[3:6] = delta_rot_norm
        action[6] = gripper_cmd
        actions.append(action)

    return actions, pred_3d_targets


def decode_act_actions(pos_pred, rot_pred, gripper_pred, current_eef_pos, current_eef_quat,
                       max_delta=0.05):
    """Decode ACT predictions into OSC_POSE delta actions.

    Position and rotation outputs are treated as normalised to [0,1] and are
    denormalised with the dataset min/max stored on ``model_module``; deltas
    are then computed relative to the current EEF pose for the OSC_POSE controller.

    Args:
        pos_pred:         (1, N_WINDOW, 3) normalized [0,1] positions (tensor)
        rot_pred:         (1, N_WINDOW, 3) normalized [0,1] rotations (tensor),
                          denormalised and interpreted as 'xyz' Euler angles
        gripper_pred:     (1, N_WINDOW) gripper output (tensor), thresholded at 0,
                          i.e. treated as a raw logit.
                          NOTE(review): this was previously documented as a [0,1]
                          sigmoid output; if that is what the model emits, the
                          threshold should be 0.5 — confirm against the ACT head.
        current_eef_pos:  (3,) numpy
        current_eef_quat: (4,) numpy, passed to scipy Rotation.from_quat
        max_delta:        max position delta magnitude in metres before OSC
                          normalisation (default matches the previous hard-coded 0.05)

    Returns:
        actions: list of N_WINDOW (7,) numpy [delta_pos(3), delta_rot(3), gripper(1)]
    """
    from scipy.spatial.transform import Rotation as ScipyR
    OSC_POS_SCALE = 0.05
    OSC_ROT_SCALE = 0.5

    min_pos = np.array(model_module.MIN_POS, dtype=np.float64)
    max_pos = np.array(model_module.MAX_POS, dtype=np.float64)
    min_rot = np.array(model_module.MIN_ROT, dtype=np.float64)
    max_rot = np.array(model_module.MAX_ROT, dtype=np.float64)

    n_window = pos_pred.shape[1]
    actions  = []
    ref_pos  = current_eef_pos.copy()
    # Current orientation does not change within the window; build it once.
    R_current = ScipyR.from_quat(current_eef_quat)

    for t in range(n_window):
        # Denormalize from [0,1] to original scale
        pos_norm = pos_pred[0, t].cpu().numpy().astype(np.float64)
        pred_3d = pos_norm * (max_pos - min_pos) + min_pos

        rot_norm = rot_pred[0, t].cpu().numpy().astype(np.float64)
        euler_pred = rot_norm * (max_rot - min_rot) + min_rot

        # Delta position → clamp magnitude → normalised OSC input
        delta_pos = pred_3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

        # Delta rotation from current orientation to the predicted one
        R_pred    = ScipyR.from_euler('xyz', euler_pred)
        R_delta   = R_pred * R_current.inv()
        delta_rot = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # Gripper: value > 0 → close (+1), otherwise open (-1)
        g_logit = float(gripper_pred[0, t].cpu())
        gripper_cmd = 1.0 if g_logit > 0.0 else -1.0  # logit > 0 means P(close) > 0.5

        action      = np.zeros(7, dtype=np.float32)
        action[:3]  = delta_norm
        action[3:6] = delta_rot
        action[6]   = gripper_cmd
        actions.append(action)

    return actions


# ---------------------------------------------------------------------------
# Main eval loop
# ---------------------------------------------------------------------------

def run_eval(args):
    device = torch.device(
        "cuda" if torch.cuda.is_available() else
        "mps"  if torch.backends.mps.is_available() else
        "cpu"
    )
    print(f"Device: {device}")

    # --- Load checkpoint ---
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Restore height/gripper range from checkpoint
    model_module.MIN_HEIGHT  = float(ckpt.get("min_height",  model_module.MIN_HEIGHT))
    model_module.MAX_HEIGHT  = float(ckpt.get("max_height",  model_module.MAX_HEIGHT))
    model_module.MIN_GRIPPER = float(ckpt.get("min_gripper", model_module.MIN_GRIPPER))
    model_module.MAX_GRIPPER = float(ckpt.get("max_gripper", model_module.MAX_GRIPPER))
    if "min_rot" in ckpt:
        model_module.MIN_ROT = ckpt["min_rot"]
        model_module.MAX_ROT = ckpt["max_rot"]
    if "min_pos" in ckpt:
        model_module.MIN_POS = ckpt["min_pos"]
        model_module.MAX_POS = ckpt["max_pos"]
    if "ref_rotation_quat" in ckpt:
        model_module.REF_ROTATION_QUAT = ckpt["ref_rotation_quat"]
    print(f"Height  range: [{model_module.MIN_HEIGHT:.4f}, {model_module.MAX_HEIGHT:.4f}]")
    print(f"Gripper range: [{model_module.MIN_GRIPPER:.4f}, {model_module.MAX_GRIPPER:.4f}]")
    print(f"Rot     range: {[f'{v:.3f}' for v in model_module.MIN_ROT]} .. {[f'{v:.3f}' for v in model_module.MAX_ROT]}")
    print(f"Ref rot:       {[f'{v:.4f}' for v in model_module.REF_ROTATION_QUAT]}")
    print(f"Pos     range: {[f'{v:.3f}' for v in model_module.MIN_POS]} .. {[f'{v:.3f}' for v in model_module.MAX_POS]}")

    ModelClass = get_model_class(args.model_type)
    if args.model_type == "para":
        model = ModelClass(target_size=IMAGE_SIZE, pred_size=PRED_SIZE)
    elif args.model_type in ("internvl", "internvl_act"):
        model = ModelClass(target_size=IMAGE_SIZE, model_name=args.model_name)
    else:
        model = ModelClass(target_size=IMAGE_SIZE)
    model.load_state_dict(ckpt["model_state_dict"], strict=False)
    model = model.to(device)
    model.eval()
    print(f"Loaded model ({args.model_type}) from {ckpt_path}")

    # --- LIBERO benchmark + task ---
    bench     = bm.get_benchmark_dict()[args.benchmark]()
    task      = bench.get_task(args.task_id)
    task_name = task.name
    print(f"Task: [{args.benchmark}] {task_name}")

    # Load initial states from demo HDF5 (frame 0 of each demo)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        init_states = [f[f"data/{k}/states"][0] for k in demo_keys]  # first frame per demo
    n_episodes = min(args.n_episodes, len(init_states))
    print(f"Running {n_episodes} / {len(init_states)} episodes...")

    # --- Build env (default OSC_POSE controller) ---
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE,
        camera_widths=IMAGE_SIZE,
        camera_names=[args.camera] + (["robot0_eye_in_hand"] if args.model_type in ("dual_da3", "dual_para", "cost_volume") else []),
    )
    env.seed(args.seed)
    env.reset()

    # Clean scene: remove distractors and furniture (matches OOD objpos training data)
    if args.clean_scene:
        sim = env.env.sim
        # Hide furniture underground
        for fname in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
            try:
                bid = sim.model.body_name2id(fname)
                sim.model.body_pos[bid] = np.array([0, 0, -5.0])
            except Exception:
                pass
        sim.forward()
        # Hide distractors (make invisible)
        distractor_names = ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]
        distractor_bodies = set()
        for dname in distractor_names:
            try:
                distractor_bodies.add(sim.model.body_name2id(dname))
            except Exception:
                pass
        for geom_id in range(sim.model.ngeom):
            if sim.model.geom_bodyid[geom_id] in distractor_bodies:
                sim.model.geom_rgba[geom_id][3] = 0.0
        print("✓ Clean scene: distractors hidden, furniture removed")

    print("Environment ready.")

    # Load CLIP embedding for models with task conditioning
    clip_embedding = None
    if args.model_type in ("dino_vla", "act"):
        clip_path = os.path.join(args.clip_embeddings_dir, args.benchmark, f"task_{args.task_id}_clip.pt")
        if os.path.exists(clip_path):
            clip_embedding = torch.load(clip_path, map_location=device).unsqueeze(0)  # (1, D_clip)
            print(f"Loaded CLIP embedding: {clip_path}")
        else:
            print(f"WARNING: CLIP embedding not found at {clip_path}, using zeros")
            clip_embedding = torch.zeros(1, 512, device=device)

    # Load task description for internvl VLA
    task_text_for_eval = None
    if args.model_type in ("internvl", "internvl_act"):
        task_text_for_eval = [task_name.replace("_", " ")]
        print(f"Task text for VLA: {task_text_for_eval[0]}")

    successes  = []
    step_counts = []

    for ep_idx in tqdm(range(n_episodes), desc="Episodes"):
        # Reset to recorded initial state
        env.reset()
        # Re-apply clean scene after reset (reset recreates sim)
        if args.clean_scene:
            sim = env.env.sim
            for fname in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
                try:
                    sim.model.body_pos[sim.model.body_name2id(fname)] = np.array([0, 0, -5.0])
                except Exception: pass
            sim.forward()
            dist_bodies = set()
            for dn in ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]:
                try: dist_bodies.add(sim.model.body_name2id(dn))
                except Exception: pass
            for gid in range(sim.model.ngeom):
                if sim.model.geom_bodyid[gid] in dist_bodies:
                    sim.model.geom_rgba[gid][3] = 0.0

        init_state = init_states[ep_idx].copy()
        # Shift objects if requested
        if args.shift_dx != 0 or args.shift_dy != 0:
            # State layout: qpos_offset=1, bowl=qpos[9], plate=qpos[37]
            for qps in [9, 37]:  # pick and place objects
                si = qps + 1  # +1 for state prefix
                init_state[si] += args.shift_dx
                init_state[si + 1] += args.shift_dy
            # Move distractors off-screen
            for qps in [16, 23, 30]:  # distractor objects
                si = qps + 1
                init_state[si:si+3] = [10.0, 10.0, 0.9]
        obs = env.set_init_state(init_state)

        # Physics settlement: let the sim settle with zero action
        for _ in range(5):
            obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

        # Reposition camera if viewpoint shift requested
        if args.cam_theta != 0 or args.cam_phi != 0:
            _reposition_camera(env.sim, args.camera, args.cam_theta, args.cam_phi)

        # Camera params (static for agentview; recompute each step if using wrist cam)
        world_to_camera, camera_pose, cam_K = get_camera_params(env.sim, args.camera, IMAGE_SIZE)

        done    = False
        success = False
        frames  = [] if args.save_video else None
        vis_strips = []  # per-replan-step visualization strips
        step_idx = 0
        replan_idx = 0
        current_gripper_cmd = -1.0  # track gripper state across entire episode

        while step_idx < args.max_steps and not done:
            current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
            rgb_obs         = obs[f"{args.camera}_image"]   # (H, W, 3) uint8

            # Start keypoint: current EEF projected into training image space
            start_kp   = eef_to_start_kp(current_eef_pos, world_to_camera, IMAGE_SIZE).to(device)
            img_tensor = preprocess_obs(rgb_obs, IMAGE_SIZE).to(device)

            current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
            pred_3d_targets = None  # set by heatmap decoders, not ACT

            if args.model_type in ("act", "internvl_act"):
                # ACT-style: direct regression → decode to delta actions
                # Normalize proprioception to [0,1]
                min_pos = np.array(model_module.MIN_POS, dtype=np.float64)
                max_pos = np.array(model_module.MAX_POS, dtype=np.float64)
                eef_norm = torch.tensor(
                    (current_eef_pos - min_pos) / (max_pos - min_pos + 1e-8),
                    dtype=torch.float32, device=device
                ).clamp(0, 1).unsqueeze(0)
                # Get current gripper state from obs
                grip_state = float(obs.get("robot0_gripper_qpos", [0, 0])[0])  # first finger pos
                grip_norm = torch.tensor(
                    [(grip_state - model_module.MIN_GRIPPER) / (model_module.MAX_GRIPPER - model_module.MIN_GRIPPER + 1e-8)],
                    dtype=torch.float32, device=device
                ).clamp(0, 1).unsqueeze(0)
                act_extra = {}
                if clip_embedding is not None:
                    act_extra['clip_embedding'] = clip_embedding
                if task_text_for_eval is not None:
                    act_extra['task_text'] = task_text_for_eval
                with torch.no_grad():
                    pos_pred, rot_pred, gripper_pred = model(
                        img_tensor, start_kp,
                        current_eef_pos=eef_norm,
                        current_gripper=grip_norm,
                        **act_extra,
                    )
                window_actions = decode_act_actions(
                    pos_pred, rot_pred, gripper_pred,
                    current_eef_pos, current_eef_quat,
                )
                # Extract absolute 3D targets for teleport mode
                pos_denorm = pos_pred[0].cpu().numpy() * (max_pos - min_pos) + min_pos
                pred_3d_targets = [pos_denorm[t] for t in range(pos_denorm.shape[0])]
                volume_logits = None
            elif args.model_type in ("dual_da3", "dual_para"):
                # Dual-camera: process both views
                wrist_obs = obs["robot0_eye_in_hand_image"]
                wrist_tensor = preprocess_obs(wrist_obs, IMAGE_SIZE).to(device)
                # Wrist camera params (re-get each step since wrist moves)
                wrist_w2c, wrist_cam_pose, wrist_cam_K = get_camera_params(
                    env.sim, "robot0_eye_in_hand", IMAGE_SIZE)
                with torch.no_grad():
                    out = model(img_tensor, wrist_tensor, start_keypoint_2d=start_kp)
                window_actions, _ = decode_dual_window_actions(
                    out['agent_volume'], out['wrist_volume'],
                    model, out['agent_feats'], out['wrist_feats'],
                    camera_pose, cam_K, wrist_cam_pose, wrist_cam_K,
                    wrist_w2c,
                    current_eef_pos, current_eef_quat,
                    image_size=IMAGE_SIZE,
                )
                volume_logits = out['agent_volume']  # for video rendering
            elif args.model_type == "cost_volume":
                # Cost volume: needs wrist image + camera params
                wrist_obs = obs["robot0_eye_in_hand_image"]
                wrist_tensor = preprocess_obs(wrist_obs, IMAGE_SIZE).to(device)
                wrist_w2c_mat, wrist_cam_pose_mat, wrist_cam_K_mat = get_camera_params(
                    env.sim, "robot0_eye_in_hand", IMAGE_SIZE)
                # Normalize intrinsics
                agent_K_norm = cam_K.copy()
                agent_K_norm[0] /= IMAGE_SIZE
                agent_K_norm[1] /= IMAGE_SIZE
                wrist_K_norm = wrist_cam_K_mat.copy()
                wrist_K_norm[0] /= IMAGE_SIZE
                wrist_K_norm[1] /= IMAGE_SIZE
                with torch.no_grad():
                    volume_logits, _, _, feats = model(
                        img_tensor, wrist_tensor, start_keypoint_2d=start_kp,
                        agent_cam_pose=torch.from_numpy(camera_pose).float().unsqueeze(0).to(device),
                        agent_cam_K_norm=torch.from_numpy(agent_K_norm).float().unsqueeze(0).to(device),
                        wrist_cam_pose=torch.from_numpy(wrist_cam_pose_mat).float().unsqueeze(0).to(device),
                        wrist_cam_K_norm=torch.from_numpy(wrist_K_norm).float().unsqueeze(0).to(device),
                    )
                window_actions, pred_3d_targets = decode_window_actions(
                    volume_logits, model, feats,
                    camera_pose, cam_K, current_eef_pos, current_eef_quat,
                    image_size=IMAGE_SIZE,
                )
            else:
                # Single-camera heatmap models (PARA, DA3, MoGe, DinoVLA, InternVL)
                extra_kwargs = {}
                if clip_embedding is not None:
                    extra_kwargs['clip_embedding'] = clip_embedding
                if task_text_for_eval is not None:
                    extra_kwargs['task_text'] = task_text_for_eval
                with torch.no_grad():
                    volume_logits, _, _, feats = model(img_tensor, start_kp, **extra_kwargs)
                window_actions, pred_3d_targets = decode_window_actions(
                    volume_logits, model, feats,
                    camera_pose, cam_K, current_eef_pos, current_eef_quat,
                    image_size=IMAGE_SIZE,
                )

            # Save per-replan visualization strip (all N_WINDOW heatmaps side by side)
            save_vis = getattr(args, 'save_vis', False)
            if save_vis and volume_logits is not None:
                strip = render_window_strip(
                    rgb_obs, volume_logits, current_eef_pos,
                    world_to_camera, cam_K, IMAGE_SIZE, step_idx,
                )
                vis_strips.append((replan_idx, step_idx, strip))
            replan_idx += 1

            # Execute the window
            for t, action in enumerate(window_actions):
                if args.zero_rotation:
                    action[3:6] = 0.0

                if frames is not None:
                    if volume_logits is not None:
                        frames.append(render_eval_frame(
                            obs[f"{args.camera}_image"], volume_logits, current_eef_pos,
                            world_to_camera, cam_K, IMAGE_SIZE, step_idx, success=None,
                        ))
                    else:
                        frame = obs[f"{args.camera}_image"].astype(np.float32) / 255.0
                        frame = np.flipud(frame).copy()
                        frame = cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR)
                        vis = (frame * 255.0).astype(np.uint8)
                        label = f"step {step_idx}"
                        cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
                        cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (20, 20, 20), 1, cv2.LINE_AA)
                        frames.append(vis)

                if args.teleport and pred_3d_targets is not None:
                    # Closed-loop servo to predicted 3D target with rotation, then apply gripper
                    target_pos = pred_3d_targets[t].astype(np.float64)
                    pred_rot_delta = action[3:6].copy()  # predicted rotation delta
                    new_gripper = action[6]
                    servo_steps = 0
                    max_servo = 25
                    threshold = 0.005  # 5mm

                    # Phase 1: servo to target position + rotation, holding current gripper
                    while servo_steps < max_servo:
                        cur_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
                        delta = target_pos - cur_pos
                        dist = np.linalg.norm(delta)
                        if dist < threshold:
                            break
                        delta_clipped = np.clip(delta / 0.05, -1.0, 1.0)
                        servo_action = np.zeros(7, dtype=np.float32)
                        servo_action[:3] = delta_clipped
                        if not args.zero_rotation:
                            servo_action[3:6] = pred_rot_delta  # apply predicted rotation
                        servo_action[6] = current_gripper_cmd
                        obs, _, done, _ = env.step(servo_action)
                        step_idx += 1
                        servo_steps += 1
                        if done:
                            success = True
                            break
                        if step_idx >= args.max_steps:
                            break

                    # Phase 2: at target position, apply new gripper
                    if not done and step_idx < args.max_steps:
                        grip_action = np.zeros(7, dtype=np.float32)
                        grip_action[6] = new_gripper
                        obs, _, done, _ = env.step(grip_action)
                        step_idx += 1
                        current_gripper_cmd = new_gripper
                elif args.move_then_grip:
                    # Step 1: same action but with previous gripper (move, hold gripper)
                    move_action = action.copy()
                    move_action[6] = current_gripper_cmd
                    obs, _, done, _ = env.step(move_action)
                    step_idx += 1
                    if done:
                        success = True
                        break
                    if step_idx >= args.max_steps:
                        break
                    # Step 2: same action but with new gripper (correction + grip)
                    obs, _, done, _ = env.step(action)
                    current_gripper_cmd = action[6]
                    step_idx += 1
                elif args.duplicate_actions:
                    # Sanity check: execute each action twice (should match normal)
                    obs, _, done, _ = env.step(action)
                    current_gripper_cmd = action[6]
                    step_idx += 1
                    if done:
                        success = True
                        break
                    if step_idx >= args.max_steps:
                        break
                    obs, _, done, _ = env.step(action)
                    step_idx += 1
                else:
                    # Simultaneous move + gripper (default)
                    obs, _, done, _ = env.step(action)
                    current_gripper_cmd = action[6]
                    step_idx += 1

                if done:
                    success = True
                    break
                if step_idx >= args.max_steps:
                    break

        # Annotate last frame with success/failure
        if frames:
            if volume_logits is not None:
                frames[-1] = render_eval_frame(
                    obs[f"{args.camera}_image"], volume_logits,
                    np.array(obs["robot0_eef_pos"], dtype=np.float64),
                    world_to_camera, cam_K, IMAGE_SIZE, step_idx, success=success,
                )
            else:
                # ACT: annotate last raw frame
                tag_text = "SUCCESS" if success else "FAILURE"
                cv2.putText(frames[-1], tag_text, (10, 44), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                            (0, 255, 0) if success else (0, 0, 255), 2, cv2.LINE_AA)
            video_dir = Path(args.out_dir) / "videos" / f"task_{args.task_id}"
            video_dir.mkdir(parents=True, exist_ok=True)
            video_path = video_dir / f"ep{ep_idx:03d}_{'success' if success else 'fail'}.mp4"
            writer = cv2.VideoWriter(
                str(video_path),
                cv2.VideoWriter_fourcc(*"mp4v"),
                args.video_fps,
                (IMAGE_SIZE, IMAGE_SIZE),
            )
            for f in frames:
                writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
            writer.release()

        # Save per-replan visualization strips as PNGs
        if vis_strips:
            tag = "success" if success else "fail"
            vis_dir = Path(args.out_dir) / "vis" / f"task_{args.task_id}" / f"ep{ep_idx:03d}_{tag}"
            vis_dir.mkdir(parents=True, exist_ok=True)
            for replan_i, step_i, strip_img in vis_strips:
                vis_path = vis_dir / f"replan{replan_i:03d}_step{step_i:04d}.png"
                cv2.imwrite(str(vis_path), cv2.cvtColor(strip_img, cv2.COLOR_RGB2BGR))

        successes.append(float(success))
        step_counts.append(step_idx + 1)
        tqdm.write(
            f"  Ep {ep_idx+1:3d}: {'✓ SUCCESS' if success else '✗ FAILURE'}"
            f"  steps={step_idx+1}"
        )

    env.close()

    success_rate = float(np.mean(successes))
    avg_steps    = float(np.mean(step_counts))
    print(f"\n{'='*52}")
    print(f"  Benchmark:    {args.benchmark}")
    print(f"  Task {args.task_id}:      {task_name}")
    print(f"  Episodes:     {n_episodes}")
    print(f"  Successes:    {int(sum(successes))} / {n_episodes}")
    print(f"  Success Rate: {success_rate * 100:.1f}%")
    print(f"  Avg steps:    {avg_steps:.1f} / {args.max_steps}")
    print(f"{'='*52}")

    # --- Save results ---
    results = {
        "benchmark":    args.benchmark,
        "task_id":      args.task_id,
        "task_name":    task_name,
        "checkpoint":   str(ckpt_path),
        "n_episodes":   n_episodes,
        "success_rate": success_rate,
        "successes":    successes,
        "step_counts":  step_counts,
        "avg_steps":    avg_steps,
        "max_steps":    args.max_steps,
        "max_delta":    0.05,
    }
    out_dir  = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"eval_{args.benchmark}_task{args.task_id}.json"
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Results saved → {out_path}")
    return results


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Script entry point: build the CLI from an ordered option table, parse it,
    # and hand the resulting namespace to run_eval().

    # Architectures accepted by --model_type.
    _MODEL_TYPE_CHOICES = [
        "para", "act", "da3", "moge", "dino_vla", "internvl", "internvl_act",
        "dual_da3", "dual_para", "wrist_only", "cost_volume",
    ]

    # One (flag, add_argument keyword arguments) pair per option. Options are
    # registered in list order, which is also the order shown by --help.
    _CLI_OPTIONS = [
        # --- model / checkpoint selection ---
        ("--model_type", {
            "type": str, "default": "para", "choices": _MODEL_TYPE_CHOICES,
            "help": "Model architecture to evaluate",
        }),
        ("--model_name", {
            "type": str, "default": "OpenGVLab/InternVL2_5-1B",
            "help": "HuggingFace model name (used by internvl model_type)",
        }),
        ("--checkpoint", {
            "type": str, "required": True,
            "help": "Path to .pth checkpoint (e.g. libero/checkpoints/para_libero_spatial_t0/best.pth)",
        }),
        # --- task / rollout configuration ---
        ("--benchmark", {
            "type": str, "default": "libero_spatial",
            "help": "LIBERO benchmark name (libero_spatial, libero_goal, libero_object, libero_10)",
        }),
        ("--task_id", {"type": int, "default": 0}),
        ("--camera", {"type": str, "default": "agentview"}),
        ("--n_episodes", {
            "type": int, "default": 20,
            "help": "Number of rollout episodes to evaluate",
        }),
        ("--max_steps", {
            "type": int, "default": 300,
            "help": "Max env steps per episode before failure",
        }),
        ("--seed", {"type": int, "default": 0}),
        # --- outputs ---
        ("--out_dir", {"type": str, "default": "libero/out/eval"}),
        ("--save_video", {
            "action": "store_true",
            "help": "Save per-episode MP4 with heatmap overlay to out_dir/videos/",
        }),
        ("--save_vis", {
            "action": "store_true",
            "help": "Save per-replan-step visualization strips (all N_WINDOW heatmaps) as PNGs",
        }),
        ("--video_fps", {
            "type": int, "default": 10,
            "help": "FPS for saved videos (default 10)",
        }),
        ("--clip_embeddings_dir", {
            "type": str, "default": "/data/libero/parsed_libero",
            "help": "Directory containing precomputed CLIP embeddings (for dino_vla)",
        }),
        # --- diagnostics / scene and viewpoint perturbations ---
        ("--zero_rotation", {
            "action": "store_true",
            "help": "Zero out rotation deltas (position-only control, for diagnosing rotation issues)",
        }),
        ("--clean_scene", {
            "action": "store_true",
            "help": "Remove distractors and furniture (match OOD objpos training setup)",
        }),
        ("--shift_dx", {
            "type": float, "default": 0.0,
            "help": "Shift pick/place objects by dx in world X",
        }),
        ("--shift_dy", {
            "type": float, "default": 0.0,
            "help": "Shift pick/place objects by dy in world Y",
        }),
        ("--cam_theta", {
            "type": float, "default": 0.0,
            "help": "Camera viewpoint polar angle from default (degrees)",
        }),
        ("--cam_phi", {
            "type": float, "default": 0.0,
            "help": "Camera viewpoint azimuthal angle (degrees)",
        }),
        # --- action execution modes (the rollout loop tests them in this order) ---
        ("--teleport", {
            "action": "store_true",
            "help": "Servo to predicted 3D targets with closed-loop control (bypasses open-loop delta execution)",
        }),
        ("--move_then_grip", {
            "action": "store_true",
            "help": "Execute EEF move and gripper as separate steps (move first, then grip)",
        }),
        ("--duplicate_actions", {
            "action": "store_true",
            "help": "Execute each action twice [a1,a1,a2,a2,...] for sanity checking (should match normal)",
        }),
    ]

    parser = argparse.ArgumentParser(description="Evaluate PARA in LIBERO simulation")
    for _flag, _spec in _CLI_OPTIONS:
        parser.add_argument(_flag, **_spec)

    args = parser.parse_args()
    run_eval(args)
