"""Visualize trajectory predictions on the real dataset (2D: GT vs predicted trajectory).

Uses the volume model: volume_logits + gripper_logits; extracts pred 2D/height/gripper
and plots GT trajectory vs predicted trajectory for a given episode and start frame.

Paper figure mode (--paper): 1) imshow RGB (show then close); 2) image grid of heatmaps
for the last predicted timestep (all height channels, min/max normalized); 3) 3D volume
of last timestep heatmaps with transparency proportional to intensity.
"""

import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import torch
from mpl_toolkits.axes_grid1 import ImageGrid
from mpl_toolkits.mplot3d import Axes3D  # register 3d projection

_VDT_DIR = os.path.dirname(__file__)
_REPO_ROOT = os.path.join(_VDT_DIR, "..")
sys.path.insert(0, _VDT_DIR)
sys.path.insert(0, _REPO_ROOT)

from data import RealTrajectoryDataset, N_WINDOW
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS
import model as model_module
from utils import recover_3d_from_direct_keypoint_and_height

# Default checkpoint: volume_dino_tracks (volume + per-pixel gripper heads)
IMAGE_SIZE = 448
_SCRIPT_DIR = Path(__file__).resolve().parent
DEFAULT_CHECKPOINT_PATH = str(_SCRIPT_DIR / "checkpoints" / "best.pth")


def _denorm_rgb(rgb_bchw: torch.Tensor) -> np.ndarray:
    """(1,3,H,W) normalized -> (H,W,3) float [0,1]."""
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_bchw.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=rgb_bchw.device).view(1, 3, 1, 1)
    rgb_denorm = (rgb_bchw * std + mean).detach().cpu().numpy()[0]
    return np.clip(rgb_denorm.transpose(1, 2, 0), 0, 1)


def extract_pred_2d_and_height_from_volume(volume_logits):
    """Decode a heatmap volume (B, N_WINDOW, N_HEIGHT_BINS, H, W) into predictions.

    For each timestep: the 2D pixel is the argmax of the volume after max-pooling
    over height bins; the height bin is the argmax along the height axis at that
    pixel, mapped linearly into [MIN_HEIGHT, MAX_HEIGHT].

    Returns:
        pred_2d: (B, N, 2) float tensor of (x, y) pixel coordinates.
        pred_height: (B, N) float tensor of metric heights.
    """
    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device
    pred_2d = torch.zeros(B, N, 2, device=device, dtype=torch.float32)
    pred_height_bins = torch.zeros(B, N, device=device, dtype=torch.long)
    batch_range = torch.arange(B, device=device)
    for t in range(N):
        vol_t = volume_logits[:, t]  # (B, Nh, H, W)
        # Collapse the height axis, then locate the hottest pixel in the 2D map.
        collapsed = vol_t.max(dim=1).values
        flat = collapsed.reshape(B, -1).argmax(dim=1)
        row, col = flat // W, flat % W
        pred_2d[:, t, 0] = col.to(torch.float32)
        pred_2d[:, t, 1] = row.to(torch.float32)
        # At that pixel, pick the most likely height bin.
        pred_height_bins[:, t] = vol_t[batch_range, :, row, col].argmax(dim=1)
    # Bin index -> normalized [0, 1] -> metric height in [MIN_HEIGHT, MAX_HEIGHT].
    centers = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=device)
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    pred_height = centers[pred_height_bins] * (hi - lo) + lo
    return pred_2d, pred_height


def extract_gripper_logits_at_pixels(gripper_logits, pixel_2d):
    """Gather per-pixel gripper logits at given (x, y) locations for each timestep.

    Args:
        gripper_logits: (B, N, Ng, H, W) logits.
        pixel_2d: (B, N, 2) tensor of (x, y) coordinates; clamped to the image.

    Returns:
        (B, N, Ng) logits at the requested pixels.
    """
    B, N, Ng, H, W = gripper_logits.shape
    device = gripper_logits.device
    xs = pixel_2d[..., 0].long().clamp(0, W - 1)
    ys = pixel_2d[..., 1].long().clamp(0, H - 1)
    # Advanced indexing: broadcast batch/time index grids against (ys, xs).
    b_idx = torch.arange(B, device=device).unsqueeze(1).expand(B, N)
    t_idx = torch.arange(N, device=device).unsqueeze(0).expand(B, N)
    return gripper_logits[b_idx, t_idx, :, ys, xs]


def decode_gripper_bins(bin_logits):
    """Argmax-decode gripper bin logits (..., N_GRIPPER_BINS) into continuous
    values in [MIN_GRIPPER, MAX_GRIPPER] via linearly spaced bin centers."""
    lo = model_module.MIN_GRIPPER
    hi = model_module.MAX_GRIPPER
    centers = torch.linspace(0.0, 1.0, N_GRIPPER_BINS, device=bin_logits.device)
    normalized = centers[bin_logits.argmax(dim=-1)]
    return lo + normalized * (hi - lo)


def run_volume_model(model, rgb, start_keypoint_2d, device):
    """Forward the volume model once and decode all predictions.

    Returns (pred_2d, pred_height, pred_gripper, volume_logits); the raw volume
    is returned for downstream heatmap visualization. `device` is kept in the
    signature for interface compatibility (the model's own device is used).
    """
    with torch.no_grad():
        volume_logits, gripper_logits = model(
            rgb,
            gt_target_heatmap=None,
            training=False,
            start_keypoint_2d=start_keypoint_2d,
            current_height=None,
            current_gripper=None,
        )
    pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)
    pred_gripper = decode_gripper_bins(
        extract_gripper_logits_at_pixels(gripper_logits, pred_2d)
    )
    return pred_2d, pred_height, pred_gripper, volume_logits


# Max episode columns per GT/Pred row-pair in visualize_gt_vs_pred_2d.
MAX_EPISODES_PER_ROW = 5


def generate_paper_figure(model, dataset, sample_idx, save_dir, close_after_show=True):
    """Generate paper figures for one dataset sample.

    Sections (1 and 2 are currently disabled via `if 0`):
      1) Input RGB imshow — saved, shown, then closed.
      2) Per-height-bin heatmaps of the last predicted timestep, saved
         individually and as an ImageGrid, min/max normalized over the volume.
      GT) MuJoCo render of the GT robot at the window's last timestep, overlaid
         on the original-resolution GT frame and on a white background
         (skipped when no joint-state .npy exists for that frame).
      3) 3D scatter of the last-timestep heatmap volume, unprojected to world
         coordinates via camera pose + intrinsics, alpha ~ intensity.

    NOTE(review): `close_after_show` is accepted but never read in this body.
    """
    device = next(model.parameters()).device
    model.eval()
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Fetch the sample and run the model once; only the heatmap volume is used here.
    sample = dataset[sample_idx]
    rgb = sample["rgb"].unsqueeze(0).to(device)
    trajectory_2d = sample["trajectory_2d"].cpu().numpy()
    start_keypoint_2d = torch.tensor(trajectory_2d[0], device=device, dtype=torch.float32)

    with torch.no_grad():
        volume_logits, _ = model(
            rgb,
            gt_target_heatmap=None,
            training=False,
            start_keypoint_2d=start_keypoint_2d,
            current_height=None,
            current_gripper=None,
        )
    # volume_logits: (1, N_WINDOW, N_HEIGHT_BINS, H, W)
    vol_last = volume_logits[0, -1].cpu().numpy()  # (N_HEIGHT_BINS, H, W)
    rgb_vis = _denorm_rgb(rgb)  # NOTE(review): only used inside the disabled `if 0` block below


    if 0: # skipping first 2 for now

        # --- 1) RGB imshow; show then close ---
        fig1, ax1 = plt.subplots(1, 1, figsize=(6, 6))
        ax1.imshow(rgb_vis)
        ax1.set_title("Input RGB")
        ax1.axis("off")
        plt.tight_layout()
        plt.savefig(save_dir / "paper_fig_rgb.png", dpi=150, bbox_inches="tight")
        print(f"Saved: {save_dir / 'paper_fig_rgb.png'}")
        plt.show()
        plt.close()
        # --- 2) Image grid of heatmaps for last timestep (all height channels), min/max over all values ---
        vmin = float(vol_last.min())
        vmax = float(vol_last.max())
        if vmax <= vmin:
            vmax = vmin + 1e-8
        n_slices = vol_last.shape[0]  # N_HEIGHT_BINS
        # One standalone figure per height slice, sharing the global color range.
        for h in range(n_slices):
            fig_h, ax_h = plt.subplots(1, 1, figsize=(5, 5))
            ax_h.imshow(vol_last[h], cmap="hot", vmin=vmin, vmax=vmax)
            ax_h.set_title(f"height bin {h}")
            ax_h.axis("off")
            plt.tight_layout()
            plt.savefig(save_dir / f"heatmap_{h:03d}.png", dpi=150, bbox_inches="tight")
            plt.close(fig_h)
        print(f"Saved {n_slices} heatmaps: heatmap_000.png .. heatmap_{n_slices-1:03d}.png")

        # Combined grid of all slices (8 columns, as many rows as needed).
        n_cols = 8
        n_rows = (n_slices + n_cols - 1) // n_cols
        fig2 = plt.figure(figsize=(2 * n_cols, 2 * n_rows))
        grid = ImageGrid(fig2, 111, nrows_ncols=(n_rows, n_cols), axes_pad=0.15, share_all=True)
        for h in range(n_slices):
            grid[h].imshow(vol_last[h], cmap="hot", vmin=vmin, vmax=vmax)
            grid[h].set_title(f"h={h}", fontsize=8)
            grid[h].axis("off")
        # Blank out unused grid cells.
        for j in range(n_slices, len(grid)):
            grid[j].axis("off")
        plt.suptitle("Last timestep: heatmaps per height channel (min/max normalized)", fontsize=11)
        plt.savefig(save_dir / "paper_fig_heatmap_grid.png", dpi=150, bbox_inches="tight")
        print(f"Saved: {save_dir / 'paper_fig_heatmap_grid.png'}")
        plt.show()
        plt.close()

    # --- GT robot overlay: render last-timestep joint state at GT image resolution, then overlay ---
    episode_dir, start_frame_idx = dataset.samples[sample_idx]
    frame_str = f"{start_frame_idx:06d}"
    rgb_gt_path = episode_dir / f"{frame_str}.png"
    # Load GT image at original resolution (same convention as dataset: fx/cx scaled by W, fy/cy by H)
    rgb_gt_orig = plt.imread(rgb_gt_path)[..., :3]
    H_orig, W_orig = rgb_gt_orig.shape[:2]
    # Normalize to float [0, 1] if the image was loaded as 0-255.
    if rgb_gt_orig.max() > 1.0:
        rgb_gt_orig = rgb_gt_orig.astype(np.float32) / 255.0
    else:
        rgb_gt_orig = rgb_gt_orig.astype(np.float32)

    # Map (start frame + window length) to an actual frame index, clamped to episode end.
    frame_files = sorted([f for f in episode_dir.glob("*.png") if f.stem.isdigit()])
    frame_indices = [int(f.stem) for f in frame_files]
    last_timestep_frame_idx = frame_indices[min(start_frame_idx + N_WINDOW - 1, len(frame_indices) - 1)]
    joint_state_path = episode_dir / f"{last_timestep_frame_idx:06d}.npy"
    if joint_state_path.exists():
        # Deferred imports: MuJoCo and the robot config are only needed for this section.
        import mujoco
        from ExoConfigs.umi_so100 import UMI_SO100_CONFIG
        from exo_utils import render_from_camera_pose

        joint_state = np.load(joint_state_path)
        robot_config = UMI_SO100_CONFIG
        mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
        mj_data = mujoco.MjData(mj_model)
        nq = mj_model.nq
        # Copy as many joint values as both the file and the model support.
        mj_data.qpos[: min(len(joint_state), nq)] = joint_state[:nq]
        mujoco.mj_forward(mj_model, mj_data)

        camera_pose = sample["camera_pose"].numpy()
        cam_K_norm = sample["cam_K_norm"].numpy()
        # Intrinsics at GT image resolution (dataset: row0 *= W_orig, row1 *= H_orig)
        cam_K_orig = np.eye(3)
        cam_K_orig[0, 0] = cam_K_norm[0, 0] * W_orig
        cam_K_orig[0, 2] = cam_K_norm[0, 2] * W_orig
        cam_K_orig[1, 1] = cam_K_norm[1, 1] * H_orig
        cam_K_orig[1, 2] = cam_K_norm[1, 2] * H_orig

        # Full RGB render (for overlay)
        rendered_full = render_from_camera_pose(
            mj_model, mj_data, camera_pose, cam_K_orig, H_orig, W_orig, segmentation=False
        )
        # NOTE(review): assumes the renderer returns uint8 0-255 — confirm against exo_utils.
        rendered_full_float = (rendered_full / 255.0).astype(np.float32)
        if rendered_full_float.ndim == 2:
            rendered_full_float = np.stack([rendered_full_float] * 3, axis=-1)

        # Segmentation render (for robot mask)
        rendered_seg = render_from_camera_pose(
            mj_model, mj_data, camera_pose, cam_K_orig, H_orig, W_orig, segmentation=True
        )
        # MuJoCo segmentation returns (H, W) with object IDs; 0 = background
        # NOTE(review): the comment above says (H, W) but this indexes channel 0 of a
        # 3rd axis — the seg render appears to be (H, W, C); confirm against exo_utils.
        seg_mask = (rendered_seg[:,:,0] > 0)[...,None]

        # 1) Overlay on GT image
        rgb_with_robot = np.clip(0.5 * rgb_gt_orig + 0.5 * rendered_full_float, 0, 1)
        fig_robot, ax_robot = plt.subplots(1, 1, figsize=(6, 6))
        ax_robot.imshow(rgb_with_robot)
        ax_robot.set_title("GT robot (last timestep)")
        ax_robot.axis("off")
        plt.tight_layout()
        plt.savefig(save_dir / "paper_fig_gt_robot.png", dpi=150, bbox_inches="tight")
        print(f"Saved: {save_dir / 'paper_fig_gt_robot.png'} ({H_orig}x{W_orig})")
        plt.close(fig_robot)

        # 2) Robot on white background only (no RGB overlay)
        white_bg = np.ones((H_orig, W_orig, 3), dtype=np.float32)
        robot_on_white = rendered_full_float * seg_mask + white_bg * (1 - seg_mask)
        fig_white, ax_white = plt.subplots(1, 1, figsize=(6, 6))
        ax_white.imshow(robot_on_white)
        ax_white.axis("off")
        plt.tight_layout()
        plt.savefig(save_dir / "paper_fig_gt_robot_white.png", dpi=150, bbox_inches="tight")
        print(f"Saved: {save_dir / 'paper_fig_gt_robot_white.png'} ({H_orig}x{W_orig})")
        plt.show()
        plt.close(fig_white)
    else:
        print(f"GT robot overlay skipped: no joint state at {joint_state_path}")

    # --- 3) 3D volume: unproject (u, v, height) to world X,Y,Z using camera intrinsics and pose ---
    H, W = vol_last.shape[1], vol_last.shape[2]
    intensity = np.maximum(vol_last, 0)  # (N_HEIGHT_BINS, H, W)
    norm = intensity.max()
    if norm <= 0:
        norm = 1.0
    # NOTE(review): empirical clamp — caps the normalizer so strong peaks saturate
    # alpha earlier; tune for figure aesthetics.
    if norm >= 0.8:
        norm = 0.75
    alpha = intensity / norm
    # Small floor so faint voxels remain faintly visible.
    alpha += 0.015
    alpha = np.clip(alpha, 0, 1)

    # Camera: get pose and intrinsics for IMAGE_SIZE
    camera_pose = sample["camera_pose"].numpy()
    cam_K_norm = sample["cam_K_norm"].numpy()
    cam_K = np.eye(3)
    cam_K[0, 0] = cam_K_norm[0, 0] * IMAGE_SIZE
    cam_K[0, 2] = cam_K_norm[0, 2] * IMAGE_SIZE
    cam_K[1, 1] = cam_K_norm[1, 1] * IMAGE_SIZE
    cam_K[1, 2] = cam_K_norm[1, 2] * IMAGE_SIZE

    # Height (m) per bin index
    bin_centers = np.linspace(0.0, 1.0, N_HEIGHT_BINS)
    min_h = model_module.MIN_HEIGHT
    max_h = model_module.MAX_HEIGHT
    height_values = bin_centers * (max_h - min_h) + min_h  # (N_HEIGHT_BINS,)

    # Downsample for 3D scatter
    step = max(1, min(H, W) // 32)
    zz, yy, xx = np.meshgrid(
        np.arange(vol_last.shape[0]),
        np.arange(0, H, step),
        np.arange(0, W, step),
        indexing="ij",
    )
    zz = zz.ravel()
    yy = yy.ravel()
    xx = xx.ravel()
    aa = alpha[zz, yy, xx]
    # Drop near-transparent voxels before the (slow) per-point unprojection loop.
    thresh = 0.005
    mask = aa >= thresh
    xx, yy, zz, aa = xx[mask], yy[mask], zz[mask], aa[mask]

    # Unproject each (u, v, height_bin) to world (X, Y, Z)
    pts_3d = []
    alphas = []
    for i in range(len(xx)):
        u, v = float(xx[i]), float(yy[i])
        h_bin = int(zz[i])
        height_m = float(height_values[h_bin])
        pt = recover_3d_from_direct_keypoint_and_height(
            np.array([u, v]), height_m, camera_pose, cam_K
        )
        # Helper may return None (e.g. ray does not hit the height plane) — skip those.
        if pt is not None:
            pts_3d.append(pt)
            alphas.append(aa[i])
    if not pts_3d:
        pts_3d = np.zeros((0, 3))
        alphas = np.array([])
    else:
        pts_3d = np.array(pts_3d)
        alphas = np.array(alphas)

    

    # Per-point RGBA: fixed hue, alpha from (floored) intensity.
    base_rgba = np.array(mcolors.to_rgba("orangered"))
    colors = np.tile(base_rgba, (len(pts_3d), 1))
    colors[:, 3] = np.clip(alphas+.02, 0, 1)

    # Minimal-chrome 3D scatter: no ticks, labels, grid, or panes.
    fig3 = plt.figure(figsize=(8, 6))
    fig3.patch.set_facecolor("white")
    ax3 = fig3.add_subplot(111, projection="3d")
    if len(pts_3d) > 0:
        ax3.scatter(pts_3d[:, 0], pts_3d[:, 1], pts_3d[:, 2], c=colors, s=4, edgecolors="none")
    ax3.set_xticks([])
    ax3.set_yticks([])
    ax3.set_zticks([])
    ax3.set_xlabel("")
    ax3.set_ylabel("")
    ax3.set_zlabel("")
    ax3.set_title("")
    ax3.grid(False)
    ax3.xaxis.pane.fill = False
    ax3.yaxis.pane.fill = False
    ax3.zaxis.pane.fill = False
    ax3.xaxis.pane.set_edgecolor("none")
    ax3.yaxis.pane.set_edgecolor("none")
    ax3.zaxis.pane.set_edgecolor("none")
    ax3.set_facecolor("white")
    plt.savefig(save_dir / "paper_fig_3d_volume.png", dpi=150, bbox_inches="tight", facecolor="white")
    print(f"Saved: {save_dir / 'paper_fig_3d_volume.png'}")
    plt.show()
    plt.close()


def visualize_gt_vs_pred_2d(model, dataset, indices, title, save_path=None):
    """Plot GT vs predicted 2D trajectories for the given sample indices.

    Layout: x = episode (max MAX_EPISODES_PER_ROW per row), y = GT vs Pred.
    Episodes wrap into new (GT row, Pred row) pairs after every row block.
    Saves to `save_path` when given, otherwise shows interactively.
    """
    device = next(model.parameters()).device
    model.eval()

    n_samples = len(indices)
    n_cols = min(MAX_EPISODES_PER_ROW, n_samples)
    n_row_blocks = (n_samples + n_cols - 1) // n_cols  # ceil(n_samples / n_cols)
    n_rows = n_row_blocks * 2  # each block = one GT row + one Pred row

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
    # Normalize `axes` to 2D so indexing below is uniform.
    # NOTE(review): n_rows is always n_row_blocks * 2 >= 2, so this branch is dead.
    if n_rows == 1:
        axes = axes[np.newaxis, :]
    if n_cols == 1:
        axes = axes[:, np.newaxis]

    for i, idx in enumerate(indices):
        sample = dataset[idx]
        rgb = sample["rgb"].unsqueeze(0).to(device)
        trajectory_2d = sample["trajectory_2d"].cpu().numpy()  # (N_WINDOW, 2)
        start_keypoint_2d = torch.tensor(trajectory_2d[0], device=device, dtype=torch.float32)

        pred_2d, _, _, _ = run_volume_model(model, rgb, start_keypoint_2d, device)
        pred_2d_np = pred_2d[0].cpu().numpy()  # (N_WINDOW, 2)
        rgb_vis = _denorm_rgb(rgb)

        # Map flat sample index -> (row block, column) in the grid.
        block = i // n_cols
        col = i % n_cols
        row_gt = block * 2
        row_pred = block * 2 + 1

        # Top row of the block: GT trajectory in white.
        ax_gt = axes[row_gt, col]
        ax_gt.imshow(rgb_vis)
        ax_gt.plot(trajectory_2d[:, 0], trajectory_2d[:, 1], "w-", linewidth=2, alpha=0.9, label="GT")
        ax_gt.scatter(trajectory_2d[:, 0], trajectory_2d[:, 1], c="white", s=60, marker="o", edgecolors="black", linewidths=1.5, zorder=10)
        ax_gt.set_title(f"Ep {i}", fontsize=11, fontweight="bold")
        # NOTE(review): ylabel after axis("off") is not rendered by matplotlib.
        if col == 0:
            ax_gt.set_ylabel("GT", fontsize=10, fontweight="bold")
        ax_gt.axis("off")
        ax_gt.legend(loc="upper right", fontsize=8)

        # Bottom row of the block: predicted trajectory in lime, with mean pixel error.
        ax_pred = axes[row_pred, col]
        ax_pred.imshow(rgb_vis)
        ax_pred.plot(pred_2d_np[:, 0], pred_2d_np[:, 1], color="lime", linewidth=2, alpha=0.9, label="Pred")
        ax_pred.scatter(pred_2d_np[:, 0], pred_2d_np[:, 1], c="lime", s=60, marker="x", linewidths=2, zorder=10)
        mean_px_err = float(np.mean([np.linalg.norm(trajectory_2d[t] - pred_2d_np[t]) for t in range(N_WINDOW)]))
        ax_pred.set_title(f"Ep {i} (err: {mean_px_err:.1f}px)", fontsize=11, fontweight="bold")
        if col == 0:
            ax_pred.set_ylabel("Pred", fontsize=10, fontweight="bold")
        ax_pred.axis("off")
        ax_pred.legend(loc="upper right", fontsize=8)

    # Hide empty cells in the last block
    for j in range(n_samples, n_cols * n_row_blocks):
        block = j // n_cols
        col = j % n_cols
        axes[block * 2, col].axis("off")
        axes[block * 2 + 1, col].axis("off")

    plt.suptitle(title, fontsize=14, fontweight="bold")
    plt.subplots_adjust(hspace=0.08, wspace=0.04, top=0.94)
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()


def compute_metrics(model, dataset, indices=None):
    """Average per-sample pixel / height / gripper errors.

    Args:
        model: volume model; its parameter device is used for inputs.
        dataset: indexable dataset of sample dicts.
        indices: optional subset of sample indices; all samples when None.

    Returns:
        dict with n_samples, avg_pixel_error_px, avg_height_error_mm,
        avg_gripper_abs_error (all zero when no samples were evaluated).
    """
    device = next(model.parameters()).device
    model.eval()
    sample_ids = list(range(len(dataset))) if indices is None else indices

    sum_pixel = 0.0
    sum_height = 0.0
    sum_gripper = 0.0
    count = 0
    with torch.no_grad():
        for sample_id in sample_ids:
            sample = dataset[sample_id]
            rgb = sample["rgb"].unsqueeze(0).to(device)
            gt_2d = sample["trajectory_2d"]
            gt_3d = sample["trajectory_3d"]
            gt_gripper = sample["trajectory_gripper"]
            start_kp = gt_2d[0].to(device)

            pred_2d, pred_height, pred_gripper, _ = run_volume_model(model, rgb, start_kp, device)
            # Per-timestep errors, averaged within the window before accumulating.
            px_errs, h_errs, g_errs = [], [], []
            for t in range(N_WINDOW):
                px_errs.append(torch.norm(pred_2d[0, t].cpu() - gt_2d[t]).item())
                h_errs.append(torch.abs(pred_height[0, t].cpu() - gt_3d[t, 2]).item())
                g_errs.append(torch.abs(pred_gripper[0, t].cpu() - gt_gripper[t]).item())
            sum_pixel += np.mean(px_errs)
            sum_height += np.mean(h_errs)
            sum_gripper += np.mean(g_errs)
            count += 1

    return {
        "n_samples": count,
        "avg_pixel_error_px": sum_pixel / count if count else 0,
        "avg_height_error_mm": (sum_height / count * 1000) if count else 0,
        "avg_gripper_abs_error": (sum_gripper / count) if count else 0,
    }


def main():
    """CLI entry point: load the volume model checkpoint and real dataset,
    print error metrics, then render visualizations (grid or paper figures).

    Fix: use the module-level DEFAULT_CHECKPOINT_PATH (script-relative) as the
    --checkpoint default — it was defined and documented as the default but
    never wired in; the old hard-coded default was cwd-relative and fragile.
    """
    parser = argparse.ArgumentParser(description="Visualize GT vs predicted 2D trajectory (volume model)")
    parser.add_argument("--checkpoint", type=str, default=DEFAULT_CHECKPOINT_PATH, help="Path to model checkpoint")
    parser.add_argument("--dataset_root", type=str, default="scratch/parsed_school_long_recap", help="Root of real dataset")
    parser.add_argument("--episode", type=str, default=None, help="Episode name (e.g. episode_001)")
    parser.add_argument("--start_frame", type=int, default=None, help="Start frame index (e.g. 0). With --episode, show only this sample.")
    parser.add_argument("--max_episodes", type=int, default=1, help="If no episode/start_frame, load first K episodes")
    parser.add_argument("--num_samples", type=int, default=6, help="Number of samples to visualize when not using episode+start_frame")
    parser.add_argument("--save_dir", type=str, default="scratch/real_eval_vis", help="Directory to save figures")
    parser.add_argument("--metrics_only", action="store_true", help="Only print metrics, no plot")
    parser.add_argument("--paper", action="store_true", help="Generate paper figure: RGB, heatmap grid (last timestep), 3D volume with transparency")
    parser.add_argument("--debug", action="store_true", help="Drop into pdb before viz")
    args = parser.parse_args()

    # Prefer Apple MPS, then CUDA, then CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print(f"\nLoading model from {args.checkpoint}...")
    model = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    # Restore the normalization ranges this checkpoint was trained with; the
    # decode helpers read these module-level globals.
    if "min_height" in checkpoint and "max_height" in checkpoint:
        model_module.MIN_HEIGHT = float(checkpoint["min_height"])
        model_module.MAX_HEIGHT = float(checkpoint["max_height"])
        print(f"Height range: [{model_module.MIN_HEIGHT*1000:.2f}, {model_module.MAX_HEIGHT*1000:.2f}] mm")
    if "min_gripper" in checkpoint and "max_gripper" in checkpoint:
        model_module.MIN_GRIPPER = float(checkpoint["min_gripper"])
        model_module.MAX_GRIPPER = float(checkpoint["max_gripper"])
        print(f"Gripper range: [{model_module.MIN_GRIPPER:.3f}, {model_module.MAX_GRIPPER:.3f}]")
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model = model.to(device)
    model.eval()
    print(f"Loaded epoch {checkpoint.get('epoch', '?')}")

    print(f"\nLoading dataset: {args.dataset_root}")
    dataset = RealTrajectoryDataset(
        dataset_root=args.dataset_root,
        image_size=IMAGE_SIZE,
        episode=args.episode,
        max_episodes=None if args.episode else args.max_episodes,
    )
    print(f"Total samples: {len(dataset)}")

    # Sample selection: a single (episode, start_frame) pair, or an even sweep.
    if args.episode and args.start_frame is not None:
        idx = None
        for i, (ep_dir, frame_idx) in enumerate(dataset.samples):
            if ep_dir.name == args.episode and frame_idx == args.start_frame:
                idx = i
                break
        if idx is None:
            raise ValueError(f"No sample with episode={args.episode} and start_frame={args.start_frame}")
        indices = [idx]
        title = f"GT vs Pred 2D — {args.episode} frame {args.start_frame}"
    else:
        num_samples = min(args.num_samples, len(dataset))
        indices = np.linspace(0, len(dataset) - 1, num_samples, dtype=int).tolist()
        title = "GT vs Pred 2D — sample sweep"

    print("\nComputing metrics...")
    metrics = compute_metrics(model, dataset, indices)
    print(f"Samples: {metrics['n_samples']}")
    print(f"Avg Pixel Error: {metrics['avg_pixel_error_px']:.2f} px")
    print(f"Avg Height Error: {metrics['avg_height_error_mm']:.3f} mm")
    print(f"Avg Gripper Abs Error: {metrics['avg_gripper_abs_error']:.4f}")

    if not args.metrics_only:
        if args.debug:
            import pdb; pdb.set_trace()
        save_dir = Path(args.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        if args.paper:
            sample_idx = indices[0]
            generate_paper_figure(model, dataset, sample_idx, save_dir, close_after_show=True)
        else:
            # Default: show interactively; save only for a specific episode+frame.
            save_path = None
            if args.episode and args.start_frame is not None:
                save_path = save_dir / f"gt_vs_pred_{args.episode}_frame{args.start_frame}.png"
            visualize_gt_vs_pred_2d(model, dataset, indices, title, save_path=save_path)


if __name__ == "__main__":
    main()
