"""eval.py — Evaluate a trained PARA checkpoint in the LIBERO simulation environment.

PARA predicts the next N_WINDOW absolute EEF 3D positions from a single RGB image.
This script runs receding-horizon rollouts: at each replan step it runs the model on
the current observation, decodes all N_WINDOW predicted timesteps into delta OSC_POSE
actions, and executes that window open-loop before re-planning. Success is the LIBERO
binary predicate check (env returns done=True). Baseline architectures (ACT, DA3, MoGe,
DinoVLA, InternVL) can be evaluated with --model_type.

Usage:
    python libero/eval.py \
        --checkpoint libero/checkpoints/para_libero_spatial_t0/best.pth \
        --benchmark libero_spatial \
        --task_id 0 \
        --n_episodes 20

Action format: 7D OSC_POSE [delta_pos (3), delta_rot as axis-angle (3), gripper (1) in {-1, +1}]
"""

import argparse
import json
import os
import sys
from pathlib import Path

import cv2
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

sys.path.insert(0, os.path.dirname(__file__))

import model as model_module
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS, PRED_SIZE
from utils import recover_3d_from_direct_keypoint_and_height


def get_model_class(model_type):
    """Map a ``--model_type`` string to the predictor class that implements it.

    The PARA predictor is imported at module load; every other architecture
    lives in its own module and is imported only when it is requested.

    Raises:
        ValueError: if ``model_type`` is not a known architecture name.
    """
    if model_type == "para":
        return TrajectoryHeatmapPredictor

    if model_type == "act":
        from model_act import ACTPredictor as predictor_cls
    elif model_type == "da3":
        from model_da3 import DA3Predictor as predictor_cls
    elif model_type == "moge":
        from model_moge import MoGePredictor as predictor_cls
    elif model_type == "dino_vla":
        from model_dino_vla import DinoVLAPredictor as predictor_cls
    elif model_type == "internvl":
        from model_vla_internvl import InternVLAPredictor as predictor_cls
    elif model_type == "internvl_act":
        from model_vla_internvl_act import InternVLACTPredictor as predictor_cls
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    return predictor_cls

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    project_points_from_world_to_camera,
)

# Per-channel ImageNet statistics used by preprocess_obs to normalise model input.
IMAGENET_MEAN = np.asarray([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.asarray([0.229, 0.224, 0.225], dtype=np.float32)
# Side length (pixels) of the square camera renders and of the model input.
IMAGE_SIZE = 448


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def preprocess_obs(rgb_obs, image_size=IMAGE_SIZE):
    """Turn one HxWx3 uint8 observation into a normalised (1, 3, S, S) float tensor.

    Steps: scale to [0, 1], flip vertically (the training images use the
    flipped orientation relative to the LIBERO observation), bilinear-resize to
    ``image_size`` x ``image_size``, and normalise with the ImageNet mean/std.
    """
    scaled = rgb_obs.astype(np.float32) / 255.0
    # Row reversal == np.flipud; ascontiguousarray materialises it for cv2.
    upright = np.ascontiguousarray(scaled[::-1])
    resized = cv2.resize(upright, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    normalised = (resized - IMAGENET_MEAN) / IMAGENET_STD
    chw = np.transpose(normalised, (2, 0, 1))   # HWC -> CHW
    return torch.from_numpy(chw).float()[None]  # add batch dim -> (1, 3, S, S)


def get_camera_params(sim, camera_name, image_size=IMAGE_SIZE):
    """Return camera matrices needed for projection and 3D recovery.

    Args:
        sim:         the MuJoCo sim handle (``env.sim``).
        camera_name: camera to query (e.g. "agentview").
        image_size:  side length of the square render the matrices refer to.

    Returns:
        world_to_camera: (4,4) matrix from ``get_camera_transform_matrix``; this is
                         what ``project_points_from_world_to_camera`` consumes
                         (per robosuite it already includes the intrinsics).
        camera_pose:     (4,4) camera->world extrinsic, used for ray unprojection
                         in recover_3d.
        cam_K:           (3,3) intrinsic matrix at ``image_size`` resolution.

    The first two are different matrices; using the wrong one for 3D recovery
    gives bad targets.
    """
    world_to_camera = get_camera_transform_matrix(sim, camera_name, image_size, image_size)
    camera_pose = get_camera_extrinsic_matrix(sim, camera_name)  # camera -> world
    # Intrinsics are requested directly at the render resolution, so no further
    # rescaling is needed; np.array copies so callers get an independent array.
    cam_K = np.array(
        get_camera_intrinsic_matrix(sim, camera_name, image_size, image_size),
        dtype=np.float64,
    )
    return world_to_camera, camera_pose, cam_K


def eef_to_start_kp(eef_pos, world_to_camera, image_size=IMAGE_SIZE):
    """Project the EEF world position into the image as a [u, v] float32 tensor.

    ``project_points_from_world_to_camera`` yields (row, col); this returns them
    swapped as [col, row] = [u, v], the keypoint order used at training time.
    """
    world_pt = eef_pos.reshape(1, 3).astype(np.float64)
    row, col = project_points_from_world_to_camera(
        points=world_pt,
        world_to_camera_transform=world_to_camera,
        camera_height=image_size,
        camera_width=image_size,
    )[0]
    return torch.tensor([float(col), float(row)], dtype=torch.float32)


def render_eval_frame(rgb_obs, volume_logits, current_eef_pos,
                      world_to_camera, cam_K, image_size=IMAGE_SIZE, step_idx=0, success=None):
    """Render a single eval step: RGB + heatmap overlay + predicted pixel + GT EEF dot.

    Only the first predicted timestep (``volume_logits[0, 0]``) is visualised.

    Args:
        rgb_obs:         (H, W, 3) uint8 observation; flipped vertically to the
                         training-image convention before display.
        volume_logits:   tensor whose ``[0, 0]`` slice is (N_height, pred_size, pred_size).
        current_eef_pos: (3,) EEF world position, drawn as a white dot when it
                         projects inside the image.
        world_to_camera: (4,4) matrix for ``project_points_from_world_to_camera``.
        cam_K:           unused; kept for call-site compatibility.
        image_size:      side length of the square output image.
        step_idx:        env step shown in the label.
        success:         None -> label shows only the step; True -> "SUCCESS";
                         False -> "FAILURE" (used on an episode's final frame).

    Returns:
        (image_size, image_size, 3) uint8 RGB image.
    """
    pred_size = volume_logits.shape[-1]
    scale = image_size / pred_size

    # Preprocess frame for display (flipud to match training convention)
    frame = rgb_obs.astype(np.float32) / 255.0
    frame = np.flipud(frame).copy()
    frame = cv2.resize(frame, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # Joint softmax over the whole (height, row, col) volume, then max over height.
    # detach/float so .numpy() also works for grad-tracking or half-precision logits.
    vol_t = volume_logits[0, 0].detach().float()
    vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape)
    heat_small = vol_probs.max(dim=0)[0].cpu().numpy()               # (pred_size, pred_size)
    heat = cv2.resize(heat_small, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    # Min-max normalise to [0, 1]; dividing by the max alone never reaches 1 when min > 0.
    heat_min = heat.min()
    heat = (heat - heat_min) / (heat.max() - heat_min + 1e-8)
    heat_rgb = np.zeros_like(frame)
    heat_rgb[..., 0] = heat                                          # red channel
    overlay = np.clip(frame * 0.55 + heat_rgb * 0.45, 0, 1)
    vis = (overlay * 255.0).astype(np.uint8)

    # Predicted pixel (green crosshair) at the centre of the argmax heatmap cell.
    py, px = divmod(int(heat_small.argmax()), pred_size)
    px_full = int((px + 0.5) * scale)
    py_full = int((py + 0.5) * scale)
    cv2.drawMarker(vis, (px_full, py_full), (0, 255, 0), cv2.MARKER_CROSS, 18, 2, cv2.LINE_AA)

    # GT EEF projection (white dot)
    pix_rc = project_points_from_world_to_camera(
        current_eef_pos.reshape(1, 3).astype(np.float64),
        world_to_camera, image_size, image_size,
    )[0]
    u, v = int(round(float(pix_rc[1]))), int(round(float(pix_rc[0])))
    if 0 <= u < image_size and 0 <= v < image_size:
        cv2.circle(vis, (u, v), 6, (255, 255, 255), -1)

    # Step counter + final outcome
    label = f"step {step_idx}"
    if success is not None:
        label += "  SUCCESS" if success else "  FAILURE"
    # Thick white pass then thin dark pass -> outlined, readable text on any background.
    cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (20, 20, 20), 1, cv2.LINE_AA)

    return vis  # (H, W, 3) uint8 RGB


def render_window_strip(rgb_obs, volume_logits, current_eef_pos,
                        world_to_camera, cam_K, image_size=IMAGE_SIZE, step_idx=0):
    """Render a horizontal strip showing heatmaps for all N_WINDOW predicted timesteps.

    Each tile: RGB + red heatmap overlay (min-max normalised) + green cross at
    the heatmap argmax + white dot at the projected current EEF.

    Args:
        rgb_obs:         (H, W, 3) uint8 observation; flipped vertically to the
                         training-image convention before display.
        volume_logits:   tensor indexed ``[0, t]`` -> (N_height, pred_size, pred_size)
                         for t in range(volume_logits.shape[1]).
        current_eef_pos: (3,) EEF world position (the same dot on every tile).
        world_to_camera: (4,4) matrix for ``project_points_from_world_to_camera``.
        cam_K:           unused; kept for call-site compatibility.
        image_size:      side length of each square tile.
        step_idx:        env step shown in every tile label.

    Returns:
        (image_size, image_size * n_window, 3) uint8 RGB image.
    """
    n_window = volume_logits.shape[1]
    pred_size = volume_logits.shape[-1]
    scale = image_size / pred_size

    # Preprocess frame for display (flipud to match training convention)
    frame = rgb_obs.astype(np.float32) / 255.0
    frame = np.flipud(frame).copy()
    frame = cv2.resize(frame, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # GT EEF pixel (same for all timesteps -- current position); visibility is loop-invariant.
    pix_rc = project_points_from_world_to_camera(
        current_eef_pos.reshape(1, 3).astype(np.float64),
        world_to_camera, image_size, image_size,
    )[0]
    eef_u, eef_v = int(round(float(pix_rc[1]))), int(round(float(pix_rc[0])))
    eef_visible = 0 <= eef_u < image_size and 0 <= eef_v < image_size

    # Per-timestep joint softmax over (height, row, col), then max over height.
    # Done for all timesteps at once; detach/float keeps .numpy() safe.
    logits = volume_logits[0].detach().float()                    # (N_WINDOW, Nh, P, P)
    probs = F.softmax(logits.flatten(1), dim=1).reshape(logits.shape)
    heat_all = probs.max(dim=1)[0].cpu().numpy()                  # (N_WINDOW, P, P)

    tiles = []
    for t in range(n_window):
        heat_small = heat_all[t]
        heat = cv2.resize(heat_small, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
        # Min-max normalise to [0, 1]; dividing by the max alone never reaches 1 when min > 0.
        heat_min = heat.min()
        heat = (heat - heat_min) / (heat.max() - heat_min + 1e-8)

        heat_rgb = np.zeros_like(frame)
        heat_rgb[..., 0] = heat                                   # red channel
        overlay = np.clip(frame * 0.55 + heat_rgb * 0.45, 0, 1)
        vis = (overlay * 255.0).astype(np.uint8)

        # Predicted pixel (green crosshair) at the centre of the argmax heatmap cell.
        py, px = divmod(int(heat_small.argmax()), pred_size)
        px_full = int((px + 0.5) * scale)
        py_full = int((py + 0.5) * scale)
        cv2.drawMarker(vis, (px_full, py_full), (0, 255, 0), cv2.MARKER_CROSS, 14, 2, cv2.LINE_AA)

        # GT EEF (white dot)
        if eef_visible:
            cv2.circle(vis, (eef_u, eef_v), 5, (255, 255, 255), -1)

        # Timestep label (white outline + dark fill for readability)
        label = f"step {step_idx} t+{t}"
        cv2.putText(vis, label, (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(vis, label, (8, 16), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (20, 20, 20), 1, cv2.LINE_AA)
        tiles.append(vis)

    return np.concatenate(tiles, axis=1)  # (H, W*n_window, 3)


def decode_window_actions(volume_logits, model, feats, camera_pose, cam_K,
                          current_eef_pos, current_eef_quat, image_size=IMAGE_SIZE, max_delta=0.05):
    """Decode all N_WINDOW predicted timesteps into OSC_POSE delta actions.

    For each timestep the predicted pixel is the argmax of the height-maxed
    volume logits; the argmax height bin at that pixel plus the camera model
    gives a 3D target. Gripper and rotation come from the model's
    ``predict_at_pixels`` heads evaluated at those pixels (same as the
    training inference path).

    Args:
        volume_logits:    (1, N_WINDOW, N_HEIGHT_BINS, pred_size, pred_size)
        model:            TrajectoryHeatmapPredictor (for predict_at_pixels)
        feats:            (1, D, pred_size, pred_size) feature map from forward()
        camera_pose:      (4,4) camera->world extrinsic  <- get_camera_extrinsic_matrix()
        cam_K:            (3,3) intrinsic at image_size
        current_eef_pos:  (3,) numpy EEF position at start of window
        current_eef_quat: (4,) numpy EEF quaternion, passed to scipy
                          ``Rotation.from_quat`` (scalar-last x, y, z, w order)
        image_size:       resolution the predicted pixels are scaled up to
        max_delta:        max position delta magnitude in metres before OSC normalisation

    Returns:
        actions:         list of N_WINDOW float32 (7,) arrays
                         [delta_pos(3), delta_rot_axisangle(3), gripper(1)],
                         each component in [-1, 1]
        pred_3d_targets: list of N_WINDOW (3,) absolute EEF targets (for debug)

    NOTE(review): every timestep's position and rotation delta is measured from
    the window-start pose, while run_eval executes the whole window open-loop.
    Confirm this matches how the OSC controller applies deltas before relying
    on steps t > 0.
    """
    from scipy.spatial.transform import Rotation as ScipyR
    OSC_POS_SCALE = 0.05  # robosuite OSC_POSE: input [-1,1] -> output [-0.05, 0.05] m
    OSC_ROT_SCALE = 0.5   # robosuite OSC_POSE: input [-1,1] -> output [-0.5, 0.5] rad

    n_window  = volume_logits.shape[1]
    pred_size = volume_logits.shape[-1]
    scale     = image_size / pred_size
    min_h, max_h = model_module.MIN_HEIGHT, model_module.MAX_HEIGHT
    min_r = np.array(model_module.MIN_ROT, dtype=np.float64)[:3]
    max_r = np.array(model_module.MAX_ROT, dtype=np.float64)[:3]
    h_denom = max(N_HEIGHT_BINS - 1, 1)
    r_denom = max(N_ROT_BINS - 1, 1)

    # --- Pass 1: predicted pixel per timestep = argmax of the max-over-height logits ---
    # (the joint softmax used for display is monotonic, so the argmax is the same)
    max_over_h = volume_logits[0].max(dim=1)[0]               # (N_WINDOW, pred_size, pred_size)
    flat_idx = max_over_h.flatten(1).argmax(dim=1).tolist()   # row-major index, first max wins
    pred_px_list = [(idx % pred_size, idx // pred_size) for idx in flat_idx]  # (px, py)

    # --- Batch gripper/rotation prediction at all predicted pixels ---
    pred_pixels = torch.tensor(
        pred_px_list, dtype=torch.float32, device=feats.device
    ).unsqueeze(0)  # (1, N_WINDOW, 2)
    with torch.no_grad():
        gripper_logits, rotation_logits = model.predict_at_pixels(feats, pred_pixels)
    # Expected: gripper_logits (1, N_WINDOW) raw logits;
    #           rotation_logits (1, N_WINDOW, 3, N_ROT_BINS).
    rot_bins = rotation_logits[0, :, :3].argmax(dim=-1).cpu().numpy()   # (N_WINDOW, 3) bin ids
    grip_logits = gripper_logits[0].detach().float().cpu().numpy()      # (N_WINDOW,)

    # The orientation reference is the same for every timestep: invert it once.
    R_current_inv = ScipyR.from_quat(current_eef_quat).inv()

    # --- Pass 2: decode actions ---
    actions         = []
    pred_3d_targets = []
    ref_pos         = current_eef_pos.copy()

    for t, (px, py) in enumerate(pred_px_list):
        # Height bin at the chosen pixel -> metric height -> 3D point on the pixel ray.
        h_bin  = volume_logits[0, t, :, py, px].argmax().item()
        height = (h_bin / h_denom) * (max_h - min_h) + min_h
        pred_3d = recover_3d_from_direct_keypoint_and_height(
            np.array([(px + 0.5) * scale, (py + 0.5) * scale], dtype=np.float64),
            height, camera_pose, cam_K,
        )
        if pred_3d is None:
            # Unprojection failed: hold the previous target (or the current pose).
            pred_3d = pred_3d_targets[-1] if pred_3d_targets else ref_pos.copy()
        pred_3d_targets.append(pred_3d)

        # Delta position -> cap its norm at max_delta -> normalised OSC input
        delta_pos = pred_3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

        # Rotation: bin -> absolute 'xyz' euler -> delta from current EEF orientation
        euler_pred = (rot_bins[t] / r_denom) * (max_r - min_r) + min_r
        R_delta = ScipyR.from_euler('xyz', euler_pred) * R_current_inv
        delta_rot_norm = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # Gripper: logit > 0 (P > 0.5) -> +1, else -1
        gripper_cmd = 1.0 if grip_logits[t] > 0.0 else -1.0

        action      = np.zeros(7, dtype=np.float32)
        action[:3]  = delta_norm
        action[3:6] = delta_rot_norm
        action[6]   = gripper_cmd
        actions.append(action)

    return actions, pred_3d_targets


def decode_act_actions(pos_pred, rot_pred, gripper_pred, current_eef_pos, current_eef_quat,
                       max_delta=0.05):
    """Decode ACT-style regression outputs into OSC_POSE delta actions.

    Position and rotation predictions are treated as normalised to [0, 1] and
    mapped back with the dataset ranges stored on ``model_module``
    (MIN_POS/MAX_POS, MIN_ROT/MAX_ROT). The gripper prediction is thresholded
    at 0 as a raw logit: > 0 -> +1, otherwise -1.

    Args:
        pos_pred:         (1, N_WINDOW, 3) tensor, positions normalised to [0, 1]
        rot_pred:         (1, N_WINDOW, 3) tensor, 'xyz' euler angles normalised to [0, 1]
        gripper_pred:     (1, N_WINDOW) tensor, treated as raw gripper logits
        current_eef_pos:  (3,) numpy EEF position at the start of the window
        current_eef_quat: (4,) numpy EEF quaternion, passed to scipy
                          ``Rotation.from_quat`` (scalar-last x, y, z, w order)
        max_delta:        cap on the position-delta norm in metres (default 0.05)

    Returns:
        actions: list of N_WINDOW float32 (7,) arrays
                 [delta_pos(3), delta_rot(3), gripper(1)], each in [-1, 1]

    NOTE(review): if the ACT gripper head applies a sigmoid (outputs in [0, 1]),
    thresholding at 0 always yields +1 and the threshold should be 0.5 -- confirm
    against the model's output head.
    NOTE(review): all deltas are measured from the window-start pose while
    run_eval executes the window open-loop; confirm this is intended.
    """
    from scipy.spatial.transform import Rotation as ScipyR
    OSC_POS_SCALE = 0.05
    OSC_ROT_SCALE = 0.5

    min_pos = np.array(model_module.MIN_POS, dtype=np.float64)
    max_pos = np.array(model_module.MAX_POS, dtype=np.float64)
    min_rot = np.array(model_module.MIN_ROT, dtype=np.float64)
    max_rot = np.array(model_module.MAX_ROT, dtype=np.float64)

    # Move predictions to numpy once; .float() keeps .numpy() valid for half precision.
    pos_all  = pos_pred[0].detach().float().cpu().numpy().astype(np.float64)      # (N_WINDOW, 3)
    rot_all  = rot_pred[0].detach().float().cpu().numpy().astype(np.float64)      # (N_WINDOW, 3)
    grip_all = gripper_pred[0].detach().float().cpu().numpy()                     # (N_WINDOW,)

    # The orientation reference is the same for every timestep: invert it once.
    R_current_inv = ScipyR.from_quat(current_eef_quat).inv()
    ref_pos = current_eef_pos.copy()

    actions = []
    for t in range(pos_pred.shape[1]):
        # Denormalise from [0, 1] to world units
        pred_3d = pos_all[t] * (max_pos - min_pos) + min_pos
        euler_pred = rot_all[t] * (max_rot - min_rot) + min_rot

        # Delta position, norm capped at max_delta, normalised to OSC input
        delta_pos = pred_3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

        # Delta rotation (world-frame) as a normalised rotation vector
        R_delta   = ScipyR.from_euler('xyz', euler_pred) * R_current_inv
        delta_rot = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # Gripper: logit > 0 means P(close) > 0.5 -> +1, else -1
        gripper_cmd = 1.0 if grip_all[t] > 0.0 else -1.0

        action      = np.zeros(7, dtype=np.float32)
        action[:3]  = delta_norm
        action[3:6] = delta_rot
        action[6]   = gripper_cmd
        actions.append(action)

    return actions


# ---------------------------------------------------------------------------
# Main eval loop
# ---------------------------------------------------------------------------

def _select_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _load_model_from_checkpoint(args, device):
    """Load ``args.checkpoint``, restore its normalisation ranges, and build the model.

    Ranges stored in the checkpoint (height, gripper, rotation, position) are
    written back into ``model_module`` globals, which the decode helpers read.

    Returns:
        (model, ckpt_path): the model in eval mode on ``device`` and the checkpoint Path.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Restore height/gripper/rot/pos ranges from checkpoint
    model_module.MIN_HEIGHT  = float(ckpt.get("min_height",  model_module.MIN_HEIGHT))
    model_module.MAX_HEIGHT  = float(ckpt.get("max_height",  model_module.MAX_HEIGHT))
    model_module.MIN_GRIPPER = float(ckpt.get("min_gripper", model_module.MIN_GRIPPER))
    model_module.MAX_GRIPPER = float(ckpt.get("max_gripper", model_module.MAX_GRIPPER))
    if "min_rot" in ckpt:
        model_module.MIN_ROT = ckpt["min_rot"]
        model_module.MAX_ROT = ckpt["max_rot"]
    if "min_pos" in ckpt:
        model_module.MIN_POS = ckpt["min_pos"]
        model_module.MAX_POS = ckpt["max_pos"]
    print(f"Height  range: [{model_module.MIN_HEIGHT:.4f}, {model_module.MAX_HEIGHT:.4f}]")
    print(f"Gripper range: [{model_module.MIN_GRIPPER:.4f}, {model_module.MAX_GRIPPER:.4f}]")
    print(f"Rot     range: {[f'{v:.3f}' for v in model_module.MIN_ROT]} .. {[f'{v:.3f}' for v in model_module.MAX_ROT]}")
    print(f"Pos     range: {[f'{v:.3f}' for v in model_module.MIN_POS]} .. {[f'{v:.3f}' for v in model_module.MAX_POS]}")

    ModelClass = get_model_class(args.model_type)
    if args.model_type == "para":
        model = ModelClass(target_size=IMAGE_SIZE, pred_size=PRED_SIZE)
    elif args.model_type in ("internvl", "internvl_act"):
        model = ModelClass(target_size=IMAGE_SIZE, model_name=args.model_name)
    else:
        model = ModelClass(target_size=IMAGE_SIZE)
    load_result = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    # strict=False tolerates architecture drift; report what was skipped so a
    # mismatched checkpoint is not silently evaluated with untrained weights.
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing:
        print(f"WARNING: {len(missing)} model keys missing from checkpoint, e.g. {missing[:5]}")
    if unexpected:
        print(f"WARNING: {len(unexpected)} unexpected checkpoint keys ignored, e.g. {unexpected[:5]}")
    model = model.to(device)
    model.eval()
    print(f"Loaded model ({args.model_type}) from {ckpt_path}")
    return model, ckpt_path


def _load_init_states(bench, task_id):
    """Return the first recorded sim state of every ``demo_*`` group in the task's demo HDF5.

    Keys are sorted as strings (so "demo_10" precedes "demo_2"); this fixes the
    episode-index -> initial-state mapping.
    """
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(task_id))
    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted(k for k in f["data"].keys() if k.startswith("demo_"))
        return [f[f"data/{k}/states"][0] for k in demo_keys]  # first frame per demo


def _build_conditioning_kwargs(args, task_name, device):
    """Build the task-conditioning kwargs passed to every model forward call.

    - dino_vla / act: ``clip_embedding`` (1, D) loaded from
      <clip_embeddings_dir>/<benchmark>/task_<id>_clip.pt, or zeros(1, 512) if missing.
    - internvl / internvl_act: ``task_text`` = [task name with "_" -> " "].
    """
    extra_kwargs = {}
    if args.model_type in ("dino_vla", "act"):
        clip_path = os.path.join(args.clip_embeddings_dir, args.benchmark, f"task_{args.task_id}_clip.pt")
        if os.path.exists(clip_path):
            extra_kwargs["clip_embedding"] = torch.load(clip_path, map_location=device).unsqueeze(0)  # (1, D_clip)
            print(f"Loaded CLIP embedding: {clip_path}")
        else:
            print(f"WARNING: CLIP embedding not found at {clip_path}, using zeros")
            extra_kwargs["clip_embedding"] = torch.zeros(1, 512, device=device)
    if args.model_type in ("internvl", "internvl_act"):
        task_text = [task_name.replace("_", " ")]
        extra_kwargs["task_text"] = task_text
        print(f"Task text for VLA: {task_text[0]}")
    return extra_kwargs


def _predict_window(model, model_type, obs, img_tensor, start_kp, current_eef_pos,
                    current_eef_quat, camera_pose, cam_K, extra_kwargs, device):
    """Run one forward pass and decode a window of OSC_POSE actions.

    Returns:
        (window_actions, volume_logits); volume_logits is None for the ACT-style
        models ("act", "internvl_act"), which regress poses directly.
    """
    if model_type in ("act", "internvl_act"):
        # Proprioception normalised to [0, 1] with the checkpoint's ranges.
        min_pos = np.array(model_module.MIN_POS, dtype=np.float64)
        max_pos = np.array(model_module.MAX_POS, dtype=np.float64)
        eef_norm = torch.tensor(
            (current_eef_pos - min_pos) / (max_pos - min_pos + 1e-8),
            dtype=torch.float32, device=device,
        ).clamp(0, 1).unsqueeze(0)
        grip_state = float(obs.get("robot0_gripper_qpos", [0, 0])[0])  # first finger pos
        grip_norm = torch.tensor(
            [(grip_state - model_module.MIN_GRIPPER)
             / (model_module.MAX_GRIPPER - model_module.MIN_GRIPPER + 1e-8)],
            dtype=torch.float32, device=device,
        ).clamp(0, 1).unsqueeze(0)
        with torch.no_grad():
            pos_pred, rot_pred, gripper_pred = model(
                img_tensor, start_kp,
                current_eef_pos=eef_norm,
                current_gripper=grip_norm,
                **extra_kwargs,
            )
        window_actions = decode_act_actions(
            pos_pred, rot_pred, gripper_pred, current_eef_pos, current_eef_quat,
        )
        return window_actions, None

    # Heatmap models (PARA, DA3, MoGe, DinoVLA, InternVL)
    with torch.no_grad():
        volume_logits, _, _, feats = model(img_tensor, start_kp, **extra_kwargs)
    window_actions, _ = decode_window_actions(
        volume_logits, model, feats, camera_pose, cam_K,
        current_eef_pos, current_eef_quat, image_size=IMAGE_SIZE,
    )
    return window_actions, volume_logits


def _render_raw_frame(rgb_obs, step_idx, image_size=IMAGE_SIZE):
    """Flipped, resized uint8 RGB frame with a step label (used when there is no heatmap)."""
    frame = rgb_obs.astype(np.float32) / 255.0
    frame = np.flipud(frame).copy()
    frame = cv2.resize(frame, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    vis = (frame * 255.0).astype(np.uint8)
    label = f"step {step_idx}"
    cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(vis, label, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (20, 20, 20), 1, cv2.LINE_AA)
    return vis


def _write_episode_video(frames, video_path, fps, image_size=IMAGE_SIZE):
    """Write RGB ``frames`` to an mp4v video at ``video_path``; the writer is always released."""
    video_path.parent.mkdir(parents=True, exist_ok=True)
    writer = cv2.VideoWriter(
        str(video_path),
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (image_size, image_size),
    )
    try:
        for frame in frames:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def _save_vis_strips(vis_strips, vis_dir):
    """Write each (replan_idx, step_idx, RGB strip) tuple as a PNG under ``vis_dir``."""
    vis_dir.mkdir(parents=True, exist_ok=True)
    for replan_i, step_i, strip_img in vis_strips:
        vis_path = vis_dir / f"replan{replan_i:03d}_step{step_i:04d}.png"
        cv2.imwrite(str(vis_path), cv2.cvtColor(strip_img, cv2.COLOR_RGB2BGR))


def _run_episode(env, model, args, init_state, ep_idx, extra_kwargs, device):
    """Roll out one episode from ``init_state``; save its video / vis strips if requested.

    Returns:
        (success, n_steps): success is True iff the env returned done=True;
        n_steps is the number of policy env.step calls (the 5 zero-action
        settle steps are not counted).
    """
    env.reset()
    obs = env.set_init_state(init_state)

    # Physics settlement: let the sim settle with zero action
    for _ in range(5):
        obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

    camera_key = f"{args.camera}_image"
    save_vis = getattr(args, "save_vis", False)
    done = False
    success = False
    frames = [] if args.save_video else None
    vis_strips = []  # per-replan-step visualization strips
    step_idx = 0
    replan_idx = 0
    volume_logits = None
    world_to_camera = cam_K = None

    while step_idx < args.max_steps and not done:
        # Re-read camera matrices every replan so a moving (e.g. wrist) camera
        # stays correct; for a static camera this returns the same matrices.
        world_to_camera, camera_pose, cam_K = get_camera_params(env.sim, args.camera, IMAGE_SIZE)

        current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
        rgb_obs = obs[camera_key]   # (H, W, 3) uint8

        # Start keypoint: current EEF projected into training image space
        start_kp = eef_to_start_kp(current_eef_pos, world_to_camera, IMAGE_SIZE).to(device)
        img_tensor = preprocess_obs(rgb_obs, IMAGE_SIZE).to(device)

        window_actions, volume_logits = _predict_window(
            model, args.model_type, obs, img_tensor, start_kp,
            current_eef_pos, current_eef_quat, camera_pose, cam_K,
            extra_kwargs, device,
        )
        if not window_actions:
            # Nothing to execute would otherwise spin this loop forever.
            print("WARNING: model produced an empty action window; ending episode")
            break

        # Save per-replan visualization strip (all N_WINDOW heatmaps side by side)
        if save_vis and volume_logits is not None:
            strip = render_window_strip(
                rgb_obs, volume_logits, current_eef_pos,
                world_to_camera, cam_K, IMAGE_SIZE, step_idx,
            )
            vis_strips.append((replan_idx, step_idx, strip))
        replan_idx += 1

        # Execute the window open-loop, recording a frame before each step if saving video
        for action in window_actions:
            if frames is not None:
                if volume_logits is not None:
                    frames.append(render_eval_frame(
                        obs[camera_key], volume_logits, current_eef_pos,
                        world_to_camera, cam_K, IMAGE_SIZE, step_idx, success=None,
                    ))
                else:
                    frames.append(_render_raw_frame(obs[camera_key], step_idx))

            obs, _, done, _ = env.step(action)
            step_idx += 1

            if done:
                success = True
                break
            if step_idx >= args.max_steps:
                break

    tag = "success" if success else "fail"

    # Replace the last (pre-step) frame with the final observation, annotated with the outcome.
    # `frames` is non-empty only if the loop ran, so the camera matrices are set here.
    if frames:
        final_img = obs[camera_key]
        if volume_logits is not None:
            frames[-1] = render_eval_frame(
                final_img, volume_logits,
                np.array(obs["robot0_eef_pos"], dtype=np.float64),
                world_to_camera, cam_K, IMAGE_SIZE, step_idx, success=success,
            )
        else:
            frames[-1] = _render_raw_frame(final_img, step_idx)
            # Frames are RGB: green (0, 255, 0) for success, red (255, 0, 0) for failure.
            cv2.putText(frames[-1], "SUCCESS" if success else "FAILURE", (10, 44),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                        (0, 255, 0) if success else (255, 0, 0), 2, cv2.LINE_AA)
        video_path = Path(args.out_dir) / "videos" / f"task_{args.task_id}" / f"ep{ep_idx:03d}_{tag}.mp4"
        _write_episode_video(frames, video_path, args.video_fps)

    if vis_strips:
        vis_dir = Path(args.out_dir) / "vis" / f"task_{args.task_id}" / f"ep{ep_idx:03d}_{tag}"
        _save_vis_strips(vis_strips, vis_dir)

    return success, step_idx


def run_eval(args):
    """Evaluate a checkpoint on one LIBERO task and write a JSON summary.

    Loads the model, runs ``min(args.n_episodes, #demos)`` rollouts from the
    demos' recorded initial states, prints a summary, and writes
    ``<out_dir>/eval_<benchmark>_task<id>.json``.

    Returns:
        dict with the results that were written to JSON.
    """
    device = _select_device()
    print(f"Device: {device}")

    model, ckpt_path = _load_model_from_checkpoint(args, device)

    # --- LIBERO benchmark + task ---
    bench     = bm.get_benchmark_dict()[args.benchmark]()
    task      = bench.get_task(args.task_id)
    task_name = task.name
    print(f"Task: [{args.benchmark}] {task_name}")

    init_states = _load_init_states(bench, args.task_id)
    n_episodes = min(args.n_episodes, len(init_states))
    print(f"Running {n_episodes} / {len(init_states)} episodes...")

    # --- Build env (default OSC_POSE controller) ---
    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE,
        camera_widths=IMAGE_SIZE,
        camera_names=[args.camera],
    )

    successes   = []
    step_counts = []
    try:
        env.seed(args.seed)
        env.reset()
        print("Environment ready.")

        # Task conditioning is the same for every forward pass: build it once.
        extra_kwargs = _build_conditioning_kwargs(args, task_name, device)

        for ep_idx in tqdm(range(n_episodes), desc="Episodes"):
            success, n_steps = _run_episode(
                env, model, args, init_states[ep_idx], ep_idx, extra_kwargs, device,
            )
            successes.append(float(success))
            step_counts.append(n_steps)
            tqdm.write(
                f"  Ep {ep_idx+1:3d}: {'✓ SUCCESS' if success else '✗ FAILURE'}"
                f"  steps={n_steps}"
            )
    finally:
        env.close()

    # Guard against zero episodes (np.mean([]) is nan and warns).
    success_rate = float(np.mean(successes)) if successes else 0.0
    avg_steps    = float(np.mean(step_counts)) if step_counts else 0.0
    print(f"\n{'='*52}")
    print(f"  Benchmark:    {args.benchmark}")
    print(f"  Task {args.task_id}:      {task_name}")
    print(f"  Episodes:     {n_episodes}")
    print(f"  Successes:    {int(sum(successes))} / {n_episodes}")
    print(f"  Success Rate: {success_rate * 100:.1f}%")
    print(f"  Avg steps:    {avg_steps:.1f} / {args.max_steps}")
    print(f"{'='*52}")

    # --- Save results ---
    results = {
        "benchmark":    args.benchmark,
        "task_id":      args.task_id,
        "task_name":    task_name,
        "checkpoint":   str(ckpt_path),
        "n_episodes":   n_episodes,
        "success_rate": success_rate,
        "successes":    successes,
        "step_counts":  step_counts,
        "avg_steps":    avg_steps,
        "max_steps":    args.max_steps,
        "max_delta":    0.05,
    }
    out_dir  = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"eval_{args.benchmark}_task{args.task_id}.json"
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Results saved → {out_path}")
    return results


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate PARA in LIBERO simulation")
    parser.add_argument("--model_type",  type=str,   default="para",
                        choices=["para", "act", "da3", "moge", "dino_vla", "internvl", "internvl_act"],
                        help="Model architecture to evaluate")
    parser.add_argument("--model_name",  type=str,   default="OpenGVLab/InternVL2_5-1B",
                        help="HuggingFace model name (used by internvl model_type)")
    parser.add_argument("--checkpoint",  type=str,   required=True,
                        help="Path to .pth checkpoint (e.g. libero/checkpoints/para_libero_spatial_t0/best.pth)")
    parser.add_argument("--benchmark",  type=str,   default="libero_spatial",
                        help="LIBERO benchmark name (libero_spatial, libero_goal, libero_object, libero_10)")
    parser.add_argument("--task_id",    type=int,   default=0)
    parser.add_argument("--camera",     type=str,   default="agentview")
    parser.add_argument("--n_episodes", type=int,   default=20,
                        help="Number of rollout episodes to evaluate")
    parser.add_argument("--max_steps",  type=int,   default=300,
                        help="Max env steps per episode before failure")
    parser.add_argument("--seed",       type=int,   default=0)
    parser.add_argument("--out_dir",    type=str,   default="libero/out/eval")
    parser.add_argument("--save_video", action="store_true",
                        help="Save per-episode MP4 with heatmap overlay to out_dir/videos/")
    parser.add_argument("--save_vis", action="store_true",
                        help="Save per-replan-step visualization strips (all N_WINDOW heatmaps) as PNGs")
    parser.add_argument("--video_fps",  type=int,   default=10,
                        help="FPS for saved videos (default 10)")
    parser.add_argument("--clip_embeddings_dir", type=str, default="/data/libero/parsed_libero",
                        help="Directory containing precomputed CLIP embeddings (for dino_vla)")
    args = parser.parse_args()
    run_eval(args)
