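"""
Start-frame -> video using simple_uva (same output as sample_video_from_start_frame.py).

The start frame is either an image (--image, resolved relative to the repo root) or the
middle frame of a video (--video). The sampled video is written to
sample_videos/sample_from_start_frame_simple_<checkpoint name>.mp4 under the repo root.

Run from repo root with: PYTHONNOUSERSITE=1 python scripts/sample_video_from_start_frame_simple.py --checkpoint libero10
"""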

import argparse
import os
import sys
from pathlib import Path

# HF compat before any transformers import
def _patch_hf_hub():
    try:
        import huggingface_hub
        _orig = huggingface_hub.hf_hub_download
        def _patched(*args, use_auth_token=None, **kwargs):
            if use_auth_token is not None:
                kwargs.setdefault("token", use_auth_token)
            return _orig(*args, **kwargs)
        huggingface_hub.hf_hub_download = _patched
    except Exception:
        pass
_patch_hf_hub()

REPO_ROOT = Path(__file__).resolve().parents[1]
# Prepend both import roots at once: "src" lands at index 0 and the repo root right
# after it, the same order two successive sys.path.insert(0, ...) calls would produce.
sys.path[0:0] = [str(REPO_ROOT / "src"), str(REPO_ROOT)]

import torch
import matplotlib.pyplot as plt

from simple_uva import load_policy_from_checkpoint, sample_video_from_start_frame

# Short --checkpoint names -> checkpoint file names under <repo>/checkpoints.
CHECKPOINT_ALIASES = dict(
    pusht="pusht.ckpt",
    libero10="libero10.ckpt",
    umi="umi_multitask.ckpt",
)


def _ensure_checkpoint_downloaded(ckpt_path: Path) -> None:
    if ckpt_path.is_file():
        return
    name = ckpt_path.name
    if name not in ("libero10.ckpt", "umi_multitask.ckpt"):
        return
    print(f"Checkpoint not found: {ckpt_path}. Running download script...")
    import subprocess
    subprocess.run(
        [sys.executable, str(REPO_ROOT / "scripts" / "download_ckpts.py"), "-c", name],
        check=True,
    )


def _resolve_checkpoint(checkpoint: str) -> Path:
    """Map a ``--checkpoint`` value to an existing checkpoint file.

    Names in ``CHECKPOINT_ALIASES`` (case-insensitive) resolve to
    ``<repo>/checkpoints/<file>``, and ``_ensure_checkpoint_downloaded`` is given a
    chance to fetch them; any other value is treated as a filesystem path.

    Raises:
        FileNotFoundError: if the resolved file does not exist.
    """
    alias = checkpoint.lower()
    if alias in CHECKPOINT_ALIASES:
        ckpt_path = REPO_ROOT / "checkpoints" / CHECKPOINT_ALIASES[alias]
        _ensure_checkpoint_downloaded(ckpt_path)
    else:
        ckpt_path = Path(checkpoint)
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    return ckpt_path


def _to_rgb_chw(image: "np.ndarray") -> torch.Tensor:
    """Convert an HxW or HxWxC image array to a float (1, 3, H, W) tensor in [0, 1].

    Inputs with fewer than 3 channels (grayscale, gray+alpha) have their first
    channel replicated to RGB; channels beyond the third are dropped. Integer arrays
    are divided by their dtype's maximum (255 for uint8). Float arrays are assumed to
    already be in [0, 1], which is what ``matplotlib.pyplot.imread`` returns for PNG.
    """
    import numpy as np

    if image.ndim == 2:
        image = image[:, :, None]
    if image.shape[2] < 3:
        image = np.repeat(image[:, :, :1], 3, axis=2)
    image = image[:, :, :3]
    # Copy: decoded frames may be read-only, which torch.from_numpy warns about.
    tensor = torch.from_numpy(image.copy()).float()
    if np.issubdtype(image.dtype, np.integer):
        # e.g. plt.imread on JPEG returns uint8 in [0, 255], not floats in [0, 1].
        tensor = tensor / float(np.iinfo(image.dtype).max)
    return tensor.permute(2, 0, 1).unsqueeze(0)


def _load_start_frame(image_arg: str, video_arg: "str | None") -> torch.Tensor:
    """Load the start frame as a float (1, 3, H, W) tensor in [0, 1].

    If ``video_arg`` is set, its middle frame is used; otherwise ``image_arg`` is
    read, resolved relative to the repo root.

    Raises:
        FileNotFoundError: if the image/video file does not exist.
        ValueError: if the video decodes to zero frames.
    """
    if video_arg:
        video_path = Path(video_arg)
        if not video_path.is_file():
            raise FileNotFoundError(f"Video not found: {video_path}")
        import imageio.v3 as iio

        # Holds every decoded frame in memory just to pick the middle one.
        frames = list(iio.imiter(str(video_path), plugin="FFMPEG"))
        if not frames:
            raise ValueError(f"Video contains no frames: {video_path}")
        return _to_rgb_chw(frames[len(frames) // 2])

    img_path = REPO_ROOT / image_arg
    if not img_path.is_file():
        raise FileNotFoundError(f"Image not found: {img_path}")
    return _to_rgb_chw(plt.imread(str(img_path)))


def _write_video(pred: torch.Tensor, out_path: Path, fps: int = 10) -> None:
    """Write ``pred[0]`` (values in [-1, 1]) to ``out_path`` as a video file.

    Assumes ``pred[0]`` is (time, channels, H, W) -- TODO confirm against
    ``sample_video_from_start_frame``.
    """
    import torchvision

    video = ((pred[0] + 1.0) / 2.0).clamp(0, 1)
    # (T, C, H, W) -> (T, H, W, C) uint8; astype truncates rather than rounds.
    frames_np = (video.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype("uint8")
    torchvision.io.write_video(str(out_path), torch.from_numpy(frames_np), fps=fps)


def main():
    """Parse CLI args, sample a video from a start frame, and save it as mp4.

    The working directory is switched to the repo root first, so relative
    ``--checkpoint`` / ``--video`` paths are resolved against the repo root.
    """
    os.chdir(REPO_ROOT)

    p = argparse.ArgumentParser(description="Simple UVA: start-frame -> video")
    p.add_argument("--checkpoint", "-c", type=str, default="libero10",
                   help="libero10, umi, pusht, or path to .ckpt")
    p.add_argument("--video", "-v", type=str, default=None,
                   help="Video path; its middle frame is the start frame (overrides --image)")
    p.add_argument("--image", "-i", type=str, default="fake_libero_img2.png",
                   help="Image path for start frame, relative to the repo root (default: %(default)s)")
    args = p.parse_args()

    ckpt_path = _resolve_checkpoint(args.checkpoint)
    # Read the start frame before the (slow) model load so a bad path fails fast.
    frame = _load_start_frame(args.image, args.video)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    policy, cfg = load_policy_from_checkpoint(ckpt_path, device, repo_root=REPO_ROOT)

    frame = torch.nn.functional.interpolate(frame, size=(256, 256), mode="bilinear", align_corners=False)
    x0 = (frame * 2.0 - 1.0).to(device)  # map [0, 1] -> [-1, 1]

    pred = sample_video_from_start_frame(policy, cfg, x0=x0)

    out_dir = REPO_ROOT / "sample_videos"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"sample_from_start_frame_simple_{ckpt_path.stem}.mp4"
    _write_video(pred, out_path)

    print(f"Saved video to {out_path}")

# Script entry point; importing this module does not run main().
if __name__ == "__main__":
    main()
