"""pretrained_eval.py — Evaluate a PARA model that was pretrained on point tracks and fine-tuned on LIBERO.

This script reuses the evaluation loop from eval.py, adds an optional side-by-side
comparison against a from-scratch baseline, and is tailored to checkpoints produced
by the two-stage pretrained pipeline:
  1. Point-track pretraining on RTX video data (trains the DINO backbone on 2D heatmap tracking)
  2. PARA fine-tuning on LIBERO demos (trains the full model with volume + gripper + rotation heads)

The resulting checkpoint is a standard TrajectoryHeatmapPredictor; the only difference from
a from-scratch model is that its DINO backbone was initialized from point-track pretraining
rather than from ImageNet-pretrained weights.
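
A quick way to confirm the two checkpoints are interchangeable is to compare their key
sets (a sketch; it assumes best.pth deserializes to a plain dict, so adjust the lookup
if the trainer wraps the state_dict):

    import torch
    ckpt = torch.load("libero/checkpoints/pretrained_para_all_tasks/best.pth",
                      map_location="cpu")
    print(sorted(ckpt.keys())[:10])  # should match a from-scratch PARA checkpoint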

Usage:
    python libero/pretrained_eval.py \
        --checkpoint libero/checkpoints/pretrained_para_all_tasks/best.pth \
        --benchmark libero_spatial \
        --task_id 0 \
        --n_episodes 20 \
        --run_name pretrained_para_alltask \
        --save_video --save_vis

    # With baseline comparison:
    python libero/pretrained_eval.py \
        --checkpoint libero/checkpoints/pretrained_para_all_tasks/best.pth \
        --baseline_checkpoint libero/checkpoints/all_task_training/best.pth \
        --baseline_run_name para_alltask \
        --benchmark libero_spatial \
        --task_id 0 \
        --n_episodes 20 \
        --run_name pretrained_para_alltask \
        --save_video --save_vis
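
Programmatic use (a minimal sketch): run_eval takes an argparse.Namespace with the same
fields that main() below constructs and returns a dict containing "success_rate" and
"avg_steps".

    import argparse
    from eval import run_eval

    results = run_eval(argparse.Namespace(
        model_type="para",
        checkpoint="libero/checkpoints/pretrained_para_all_tasks/best.pth",
        benchmark="libero_spatial", task_id=0, camera="agentview",
        n_episodes=20, max_steps=300, seed=0,
        out_dir="libero/out/pretrained_para_alltask",
        save_video=False, save_vis=False, video_fps=10,
        clip_embeddings_dir="/data/libero/parsed_libero",
    ))
    print(f"success: {results['success_rate']:.2f}  avg steps: {results['avg_steps']:.1f}")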
"""

import argparse
import json
import os
import sys
from pathlib import Path

sys.path.insert(0, os.path.dirname(__file__))  # make sibling modules (eval.py) importable

# Re-use eval infrastructure from eval.py
from eval import run_eval


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate pretrained (point-track -> PARA) model in LIBERO simulation"
    )
    parser.add_argument("--model_type", type=str, default="para",
                        choices=["para"],
                        help="Model architecture (pretrained pipeline uses PARA)")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to fine-tuned PARA checkpoint")
    parser.add_argument("--baseline_checkpoint", type=str, default="",
                        help="Optional: path to from-scratch PARA checkpoint for comparison")
    parser.add_argument("--run_name", type=str, default="pretrained_para_alltask",
                        help="Descriptive name for this eval run (used in output directory)")
    parser.add_argument("--baseline_run_name", type=str, default="para_alltask",
                        help="Descriptive name for the baseline run (used in output directory)")
    parser.add_argument("--benchmark", type=str, default="libero_spatial",
                        help="LIBERO benchmark name")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    parser.add_argument("--n_episodes", type=int, default=20,
                        help="Number of rollout episodes")
    parser.add_argument("--max_steps", type=int, default=300,
                        help="Max env steps per episode")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--out_root", type=str, default="libero/out",
                        help="Root output directory (run_name is appended)")
    parser.add_argument("--save_video", action="store_true",
                        help="Save per-episode MP4 with heatmap overlay")
    parser.add_argument("--save_vis", action="store_true",
                        help="Save per-replan-step visualization strips (all N_WINDOW heatmaps) as PNGs")
    parser.add_argument("--video_fps", type=int, default=10)
    parser.add_argument("--clip_embeddings_dir", type=str, default="/data/libero/parsed_libero")
    args = parser.parse_args()

    print("=" * 60)
    print("  PRETRAINED PARA EVALUATION")
    print("  Pipeline: Point-track pretraining -> PARA fine-tuning")
    print(f"  Run name: {args.run_name}")
    print("=" * 60)

    # run_eval expects the same argparse.Namespace that eval.py builds; construct it
    # in one place so the pretrained and baseline runs use identical settings.
    def make_eval_args(checkpoint: str, run_name: str) -> argparse.Namespace:
        return argparse.Namespace(
            model_type=args.model_type,
            checkpoint=checkpoint,
            benchmark=args.benchmark,
            task_id=args.task_id,
            camera=args.camera,
            n_episodes=args.n_episodes,
            max_steps=args.max_steps,
            seed=args.seed,
            out_dir=str(Path(args.out_root) / run_name),
            save_video=args.save_video,
            save_vis=args.save_vis,
            video_fps=args.video_fps,
            clip_embeddings_dir=args.clip_embeddings_dir,
        )

    # Build eval args for the pretrained model
    eval_args = make_eval_args(args.checkpoint, args.run_name)

    print(f"\n[Pretrained] Evaluating: {args.checkpoint}")
    print(f"  Output: {eval_args.out_dir}")
    pretrained_results = run_eval(eval_args)
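
    # Standalone summary for the pretrained run; run_eval returns a dict with
    # "success_rate" and "avg_steps" (the same keys read in the comparison below).
    print(f"  [Pretrained] success rate: {pretrained_results['success_rate'] * 100:.1f}%  "
          f"(avg {pretrained_results['avg_steps']:.1f} steps)")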

    # --- Optionally evaluate baseline for comparison ---
    if args.baseline_checkpoint and os.path.exists(args.baseline_checkpoint):
        print(f"\n{'=' * 60}")
        print(f"[Baseline] Evaluating from-scratch model: {args.baseline_checkpoint}")
        print(f"  Run name: {args.baseline_run_name}")
        print(f"{'=' * 60}")

        baseline_eval_args = make_eval_args(args.baseline_checkpoint, args.baseline_run_name)

        print(f"  Output: {baseline_eval_args.out_dir}")
        baseline_results = run_eval(baseline_eval_args)

        # --- Print comparison ---
        print(f"\n{'=' * 60}")
        print("  COMPARISON: Pretrained vs From-Scratch")
        print(f"{'=' * 60}")
        print(f"  Benchmark:       {args.benchmark}")
        print(f"  Task:            {args.task_id}")
        print(f"  Episodes:        {args.n_episodes}")
        print(f"")
        pt_sr = pretrained_results["success_rate"]
        bl_sr = baseline_results["success_rate"]
        delta = pt_sr - bl_sr
        print(f"  Pretrained ({args.run_name}):   {pt_sr * 100:.1f}% success  (avg {pretrained_results['avg_steps']:.1f} steps)")
        print(f"  Baseline ({args.baseline_run_name}):     {bl_sr * 100:.1f}% success  (avg {baseline_results['avg_steps']:.1f} steps)")
        print(f"  Delta:           {delta * 100:+.1f}% {'(pretrained better)' if delta > 0 else '(from-scratch better)' if delta < 0 else '(tied)'}")
        print(f"{'=' * 60}")

        # Save comparison JSON
        comparison = {
            "benchmark": args.benchmark,
            "task_id": args.task_id,
            "n_episodes": args.n_episodes,
            "pretrained": {
                "run_name": args.run_name,
                "checkpoint": args.checkpoint,
                "success_rate": pt_sr,
                "avg_steps": pretrained_results["avg_steps"],
            },
            "baseline": {
                "run_name": args.baseline_run_name,
                "checkpoint": args.baseline_checkpoint,
                "success_rate": bl_sr,
                "avg_steps": baseline_results["avg_steps"],
            },
            "delta_success_rate": delta,
        }
        comp_dir = Path(args.out_root) / "comparisons"
        comp_dir.mkdir(parents=True, exist_ok=True)
        comp_path = comp_dir / f"{args.run_name}_vs_{args.baseline_run_name}_{args.benchmark}_task{args.task_id}.json"
        with open(comp_path, "w") as f:
            json.dump(comparison, f, indent=2)
        print(f"Comparison saved -> {comp_path}")
    elif args.baseline_checkpoint:
        print(f"\nWARNING: baseline checkpoint not found: {args.baseline_checkpoint}")


if __name__ == "__main__":
    main()
