"""Open-loop AR policy eval on demo replay.

For each demo frame t, run the AR policy on cached past-H patches + EEF history and
record the predicted next-EEF pixel. Compute:
  - per-step pixel error vs GT
  - trajectory jerk (mean ‖ΔΔp‖ on the predicted xy stream)
  - free-space coverage: fraction of GT trajectory points within ε pixels of predicted

This is a cheap diagnostic — no env interaction. Validates whether the AR architecture
fixes the failure modes that motivated the refactor (skip-to-target, jitter) on real
demo trajectories, even before the closed-loop eval is wired up.

Usage:
  cd /data/cameron/para/libero
  CUDA_VISIBLE_DEVICES=9 python eval_ar_open_loop.py \
    --checkpoint checkpoints/ar_libero_spatial_t0_h8_v1/best.pth \
    --model_variant v1 \
    --cache_root /data/libero/parsed_libero \
    --benchmark libero_spatial --task_id 0 \
    --n_demos 5 --output_dir out/eval_ar_v1
"""
import argparse, io, os, sys
from pathlib import Path

import cv2
import numpy as np
import torch
from tqdm import tqdm

# Put the directory containing this file at the front of the import path so the
# model modules imported inside the loaders below can be resolved
# (presumably they live alongside this script -- verify against the repo layout).
sys.path.insert(0, os.path.dirname(__file__))


def load_v1_model(ckpt_path, device, history_len=8, grid_size=56):
    """Build the v1 AR policy, restore its weights, and switch it to eval mode.

    The checkpoint may either be a dict carrying a ``"model_state_dict"`` entry
    or a bare state dict; both layouts are accepted. Weights are loaded with
    ``strict=False``, so keys that do not line up between checkpoint and model
    are ignored without any message.

    Args:
        ckpt_path: path handed to ``torch.load``.
        device: device the model is moved to and the checkpoint is mapped onto.
        history_len: forwarded to the model constructor.
        grid_size: forwarded to the model constructor.

    Returns:
        ``(model, history_len, grid_size)`` -- the two hyper-parameters are
        echoed back unchanged.
    """
    from model_autoregressive import ARTransformerPolicy

    policy = ARTransformerPolicy(
        history_len=history_len,
        grid_size=grid_size,
        freeze_backbone=True,
    )
    policy = policy.to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # Accept both wrapped ({"model_state_dict": ...}) and bare state dicts.
    if "model_state_dict" in checkpoint:
        weights = checkpoint["model_state_dict"]
    else:
        weights = checkpoint
    policy.load_state_dict(weights, strict=False)

    policy.eval()
    return policy, history_len, grid_size


def load_v2_model(ckpt_path, device, history_len=8, grid_size=56):
    """Build the v2 AR policy, restore its weights, and switch it to eval mode.

    Mirrors the v1 loader but instantiates ``ARTransformerPolicyV2``. The
    checkpoint may be a dict with a ``"model_state_dict"`` entry or a bare state
    dict. Loading uses ``strict=False``, so mismatched keys are ignored without
    any message.

    Args:
        ckpt_path: path handed to ``torch.load``.
        device: device the model is moved to and the checkpoint is mapped onto.
        history_len: forwarded to the model constructor.
        grid_size: forwarded to the model constructor.

    Returns:
        ``(model, history_len, grid_size)`` -- the two hyper-parameters are
        echoed back unchanged.
    """
    from model_autoregressive_v2 import ARTransformerPolicyV2

    policy = ARTransformerPolicyV2(
        history_len=history_len,
        grid_size=grid_size,
        freeze_backbone=True,
    )
    policy = policy.to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # Accept both wrapped ({"model_state_dict": ...}) and bare state dicts.
    if "model_state_dict" in checkpoint:
        weights = checkpoint["model_state_dict"]
    else:
        weights = checkpoint
    policy.load_state_dict(weights, strict=False)

    policy.eval()
    return policy, history_len, grid_size


def grid_idx_to_pixel_numpy(idx, image_size, grid_size):
    """Convert a flattened grid-cell index into the pixel at that cell's centre.

    The index is interpreted row-major over a ``grid_size`` x ``grid_size``
    grid laid over an ``image_size`` x ``image_size`` image.

    Returns:
        A float64 array ``[x, y]`` (column coordinate first, then row).
    """
    row, col = divmod(idx, grid_size)
    cell_px = image_size / grid_size
    # +0.5 moves from the cell's top-left corner to its centre.
    centre_xy = [(col + 0.5) * cell_px, (row + 0.5) * cell_px]
    return np.array(centre_xy, dtype=np.float64)


def load_demo(demo_dir, image_size=448):
    """Load one cached demo: normalised frames plus the per-frame EEF pixel track.

    Every ``frames/*.png`` under ``demo_dir`` is read in sorted filename order,
    converted BGR -> RGB, scaled to [0, 1], resized to ``image_size`` x
    ``image_size`` when either side differs, and normalised with the per-channel
    mean/std constants below (these match the standard ImageNet statistics).
    ``pix_uv.npy`` is cut to the usable length and clipped to
    ``[0, image_size - 1]``.

    Frames and annotations are kept index-aligned: if one of them is shorter
    than the other, only the common prefix is returned.

    Args:
        demo_dir: ``pathlib.Path`` of the demo directory (must support ``/``).
        image_size: target square side length in pixels.

    Returns:
        ``(imgs, pix_uv, T)`` where ``imgs`` is a float tensor of shape
        ``(T, 3, image_size, image_size)``, ``pix_uv`` is a float32 array with
        ``T`` rows, and ``T`` is the number of usable frames (possibly 0).

    Raises:
        FileNotFoundError: if ``pix_uv.npy`` does not exist (from ``np.load``).
        IOError: if a frame file is present but cannot be decoded.
    """
    frame_paths = sorted((demo_dir / "frames").glob("*.png"))
    pix_uv = np.load(demo_dir / "pix_uv.npy")

    # Use the common prefix so imgs[t] and pix_uv[t] always refer to the same
    # step; a shorter annotation file would otherwise make callers index past
    # the end of pix_uv.
    T = min(len(frame_paths), len(pix_uv))
    frame_paths = frame_paths[:T]

    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

    imgs = []
    for p in frame_paths:
        bgr = cv2.imread(str(p))
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f"could not read frame image: {p}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        # Compare both height and width: a frame with the right height but the
        # wrong width must still be resized.
        if rgb.shape[:2] != (image_size, image_size):
            rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
        rgb_n = (rgb.transpose(2, 0, 1) - mean) / std      # HWC -> CHW, then normalise
        imgs.append(torch.from_numpy(rgb_n).float())

    if imgs:
        imgs = torch.stack(imgs, dim=0)                  # (T, 3, H, W)
    else:
        # torch.stack rejects an empty list; hand back an empty batch so the
        # caller's length check can skip this demo instead of crashing here.
        imgs = torch.zeros((0, 3, image_size, image_size), dtype=torch.float32)

    # NOTE(review): pix_uv is only clipped, never rescaled, when frames are
    # resized -- this assumes pix_uv.npy is already expressed in image_size
    # coordinates. TODO confirm against the cache writer.
    pix_uv = np.clip(pix_uv[:T], 0, image_size - 1).astype(np.float32)
    return imgs, pix_uv, T


@torch.no_grad()
def predict_trajectory_v1(model, imgs, pix_uv, history_len, grid_size, image_size, device):
    """Open-loop rollout of the v1 policy over one demo.

    For every step ``t`` in ``[history_len, T - 1]`` the model receives the
    ``history_len`` frames and EEF pixels immediately preceding ``t`` and its
    arg-max grid cell is converted to a pixel and stored as the prediction for
    ``t``. The history is always taken from ``pix_uv`` (ground truth), never
    from earlier predictions. The first ``history_len`` rows of the result are
    copied from ``pix_uv`` because no prediction exists for them.

    Returns:
        Array shaped and typed like ``pix_uv`` holding the predicted pixels.
    """
    num_frames = imgs.shape[0]
    eef_on_device = torch.from_numpy(pix_uv).float().to(device)

    pred = np.zeros_like(pix_uv)
    pred[:history_len] = pix_uv[:history_len]  # warm-up rows: nothing predicted yet

    for t in range(history_len, num_frames):
        window = slice(t - history_len, t)
        frames_in = imgs[window].unsqueeze(0).to(device)   # add batch dim
        eef_in = eef_on_device[window].unsqueeze(0)
        scores = model(frames_in, eef_in)
        best_cell = int(scores.argmax(dim=-1).item())
        pred[t] = grid_idx_to_pixel_numpy(best_cell, image_size, grid_size)
    return pred


@torch.no_grad()
def predict_trajectory_v2(model, imgs, pix_uv, history_len, grid_size, image_size, device,
                          *, chunk_size=16):
    """Open-loop rollout of the v2 (cached-patch) policy over one demo.

    Stage A encodes every frame exactly once with ``model.patch_encoder``
    (the DINO backbone, per the original author's note), processing
    ``chunk_size`` frames per call to keep memory bounded. Stage B then slides
    a ``history_len`` window over the cached patches and calls
    ``model.ar_head`` per step, using the most recent ground-truth EEF pixel in
    the window as the anchor. The history always comes from ``pix_uv`` (ground
    truth), never from earlier predictions.

    Args:
        model: object exposing ``patch_encoder`` and ``ar_head``.
        imgs: frame tensor, indexed by time along dim 0.
        pix_uv: ground-truth EEF pixels, one row per frame (NumPy array).
        history_len: number of past steps fed to the head.
        grid_size: side length of the output grid the logits index into.
        image_size: image side length in pixels.
        device: device the inputs are moved to.
        chunk_size: frames per ``patch_encoder`` call (keyword-only, default 16,
            the previously hard-coded value).

    Returns:
        Array shaped and typed like ``pix_uv``; the first ``history_len`` rows
        are copied from ``pix_uv`` because no prediction exists for them.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    T = imgs.shape[0]
    pred = np.zeros_like(pix_uv)
    pred[:history_len] = pix_uv[:history_len]

    # Stage A: encode the whole demo once, chunk by chunk.
    all_patches = []
    for s in range(0, T, chunk_size):
        chunk = imgs[s : s + chunk_size].unsqueeze(0).to(device)   # (1, n, 3, H, W)
        # [0] drops the batch dim added above; per the original comment the
        # result is (n, Np, D) -- TODO confirm against patch_encoder.
        all_patches.append(model.patch_encoder(chunk)[0])
    patches = torch.cat(all_patches, dim=0).unsqueeze(0)           # (1, T, ...)

    # Stage B: run only the AR head per step on the cached patches.
    eef_t = torch.from_numpy(pix_uv).float().to(device)
    for t in range(history_len, T):
        hist_p = patches[:, t - history_len : t]
        hist_e = eef_t[t - history_len : t].unsqueeze(0)
        anchor = hist_e[:, -1]                                     # latest known EEF pixel
        logits = model.ar_head(hist_p, hist_e, anchor, image_size)
        idx = int(logits.argmax(dim=-1).item())
        pred[t] = grid_idx_to_pixel_numpy(idx, image_size, grid_size)
    return pred


def _mean_second_diff(xy):
    """Return mean ‖x[t+1] - 2 x[t] + x[t-1]‖ over ``xy``, or NaN with < 3 points."""
    if xy.shape[0] < 3:
        return float("nan")
    second_diff = xy[2:] - 2 * xy[1:-1] + xy[:-2]
    return float(np.linalg.norm(second_diff, axis=-1).mean())


def trajectory_metrics(pred, gt, history_len, eps_px=15.0):
    """Compute pixel error, jerk and coverage on the predicted part of a trajectory.

    Only rows from ``history_len`` onward are scored; earlier rows are the
    warm-up window for which no prediction was made.

    Args:
        pred: predicted pixels, one ``[x, y]`` row per step.
        gt: ground-truth pixels with the same shape as ``pred``.
        history_len: number of leading rows to exclude.
        eps_px: coverage radius in pixels.

    Returns:
        Dict with:
          - ``mean_pix_err`` / ``median_pix_err`` / ``max_pix_err``: statistics
            of the per-step Euclidean distance between ``pred`` and ``gt``.
          - ``pred_jerk`` / ``gt_jerk``: mean norm of the discrete second
            difference of each stream (NaN with fewer than 3 scored steps).
          - ``coverage_eps15``: fraction of GT points whose nearest predicted
            point is closer than ``eps_px``. The key name is fixed for
            callers and does not change with ``eps_px``.
          - ``n_steps``: number of scored steps.
        When there are no scored steps, every float metric is NaN rather than
        raising.
    """
    p = pred[history_len:]
    g = gt[history_len:]
    n_steps = int(p.shape[0])

    if n_steps > 0:
        pix_err = np.linalg.norm(p - g, axis=-1)                         # (T-H,)
        mean_err = float(pix_err.mean())
        median_err = float(np.median(pix_err))
        max_err = float(pix_err.max())
        # Coverage: for each GT point, distance to the nearest predicted point.
        # High coverage means the prediction traces the GT path rather than
        # jumping straight to the endpoint.
        dists = np.linalg.norm(g[:, None, :] - p[None, :, :], axis=-1)   # (N_gt, N_pred)
        coverage = float((dists.min(axis=1) < eps_px).mean())
    else:
        # Reductions such as max() raise on empty arrays; report NaN instead,
        # consistent with the jerk metrics.
        mean_err = median_err = max_err = coverage = float("nan")

    return {
        "mean_pix_err":   mean_err,
        "median_pix_err": median_err,
        "max_pix_err":    max_err,
        "pred_jerk":      _mean_second_diff(p),
        "gt_jerk":        _mean_second_diff(g),
        "coverage_eps15": coverage,
        "n_steps":        n_steps,
    }


def render_trajectory_overlay(imgs, pix_uv, pred, history_len, image_size, out_path):
    """Write a PNG showing the GT and predicted tracks on top of the final frame.

    The last frame of ``imgs`` is de-normalised with the per-channel mean/std
    constants below and used as the background. The full GT track is drawn in
    lime, the predicted track (rows from ``history_len`` onward) in red, plus
    markers for the GT start and end points. ``out_path`` must expose ``.stem``
    (e.g. a ``pathlib.Path``); the stem becomes the plot title.
    """
    import matplotlib
    matplotlib.use("Agg")  # select the non-interactive Agg backend before pyplot is imported
    import matplotlib.pyplot as plt

    chan_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    chan_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

    # Undo the normalisation and go CHW -> HWC for imshow.
    last_frame = imgs[-1].cpu().numpy()
    background = (last_frame * chan_std + chan_mean).transpose(1, 2, 0)
    background = np.clip(background, 0, 1)

    predicted = pred[history_len:]
    start_uv = pix_uv[0]
    end_uv = pix_uv[-1]

    fig, ax = plt.subplots(1, 1, figsize=(7, 7), dpi=100)
    ax.imshow(background)
    ax.plot(pix_uv[:, 0], pix_uv[:, 1], '-',
            color="lime", linewidth=1.5, alpha=0.85, label="GT")
    ax.plot(predicted[:, 0], predicted[:, 1], '-',
            color="red", linewidth=1.5, alpha=0.85, label="pred (AR)")
    ax.scatter(start_uv[0], start_uv[1], s=80, c="white", marker="o", label="start")
    ax.scatter(end_uv[0], end_uv[1], s=120, c="lime", marker="*", label="end (GT)")

    ax.set_xlim(0, image_size)
    ax.set_ylim(image_size, 0)  # y axis flipped so the origin is top-left, as in image coordinates
    ax.set_title(out_path.stem, fontsize=10)
    ax.legend(fontsize=9)
    ax.axis("off")

    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)


def main():
    """CLI entry point: run the open-loop evaluation and print the metrics.

    Parses the command line, loads the requested model variant, evaluates up to
    ``--n_demos`` demos found under
    ``<cache_root>/<benchmark>/task_<task_id>/demo_*``, saves one overlay PNG
    per evaluated demo into ``--output_dir``, and prints per-demo and aggregate
    metrics. Demos with fewer than ``history_len + 3`` frames are skipped.

    Returns:
        List of per-demo metric dicts (each with an added ``"demo"`` name).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint",    type=str, required=True)
    parser.add_argument("--model_variant", type=str, default="v1", choices=["v1", "v2"])
    parser.add_argument("--cache_root",    type=str, required=True)
    parser.add_argument("--benchmark",     type=str, default="libero_spatial")
    parser.add_argument("--task_id",       type=int, default=0)
    parser.add_argument("--n_demos",       type=int, default=5)
    parser.add_argument("--history_len",   type=int, default=8)
    parser.add_argument("--grid_size",     type=int, default=56)
    parser.add_argument("--image_size",    type=int, default=448)
    parser.add_argument("--eps_px",        type=float, default=15.0)
    parser.add_argument("--output_dir",    type=str, default="out/eval_ar")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.model_variant == "v1":
        model, H, G = load_v1_model(args.checkpoint, device, args.history_len, args.grid_size)
        predict_fn = predict_trajectory_v1
    else:
        model, H, G = load_v2_model(args.checkpoint, device, args.history_len, args.grid_size)
        predict_fn = predict_trajectory_v2
    print(f"Loaded {args.model_variant} model from {args.checkpoint}")

    bench_root = Path(args.cache_root) / args.benchmark / f"task_{args.task_id}"
    demo_dirs = sorted(bench_root.glob("demo_*"))[: args.n_demos]
    print(f"Evaluating on {len(demo_dirs)} demos")
    if not demo_dirs:
        # A wrong --cache_root / --benchmark / --task_id otherwise ends silently.
        print(f"  warning: no demo_* directories found under {bench_root}")

    all_metrics = []
    for demo_dir in tqdm(demo_dirs, desc="Demos"):
        imgs, pix_uv, T = load_demo(demo_dir, args.image_size)
        # Need at least 3 predicted steps so the second-difference (jerk) exists.
        if T < H + 3:
            print(f"  skip {demo_dir.name}: too short")
            continue
        pred = predict_fn(model, imgs, pix_uv, H, G, args.image_size, device)
        m = trajectory_metrics(pred, pix_uv, H, eps_px=args.eps_px)
        m["demo"] = demo_dir.name
        all_metrics.append(m)
        render_trajectory_overlay(imgs, pix_uv, pred, H, args.image_size,
                                  out_dir / f"{demo_dir.name}_overlay.png")

    # Show the radius actually used; renders as "15px" for the default 15.0,
    # so the default output is unchanged.
    eps_label = f"{args.eps_px:g}px"

    print("\n=== Per-demo metrics ===")
    for m in all_metrics:
        print(f"  {m['demo']}: px_err={m['mean_pix_err']:5.1f}  jerk_pred={m['pred_jerk']:5.2f}  "
              f"jerk_gt={m['gt_jerk']:5.2f}  cov@{eps_label}={m['coverage_eps15']:.2f}  n={m['n_steps']}")

    if all_metrics:
        def mean_of(key):
            """Average one metric over all evaluated demos."""
            return float(np.mean([m[key] for m in all_metrics]))

        print("\n=== Aggregate ===")
        print(f"  mean pixel error  : {mean_of('mean_pix_err'):6.2f} px")
        # This is the mean of the per-demo medians, not a pooled median.
        print(f"  median pixel error: {mean_of('median_pix_err'):6.2f} px")
        print(f"  pred jerk (px)    : {mean_of('pred_jerk'):6.2f}  (gt: {mean_of('gt_jerk'):.2f})")
        print(f"  coverage @ {eps_label}   : {mean_of('coverage_eps15'):.2%}")
        print(f"  overlays saved to : {out_dir}")
    return all_metrics


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
