"""
Minimal: start-frame -> video sample using a pretrained UVA checkpoint (Libero10, UMI, or PushT).

Writes an mp4 to <repo root>/sample_videos/sample_from_start_frame_<checkpoint file stem>.mp4

Usage:
  python scripts/download_ckpts.py   # download libero10.ckpt + umi_multitask.ckpt
  PYTHONNOUSERSITE=1 python scripts/sample_video_from_start_frame.py --checkpoint libero10
  PYTHONNOUSERSITE=1 python scripts/sample_video_from_start_frame.py --checkpoint umi
  (PYTHONNOUSERSITE=1 skips the user site-packages, so a user-installed accelerate that expects a newer huggingface_hub is not loaded.)
"""

import argparse
import os
import sys
from pathlib import Path

# Compat shim: transformers 4.28 calls hf_hub_download(..., use_auth_token=...),
# while newer huggingface_hub releases only accept token=. This must run before
# anything imports transformers, so that transformers binds the patched function.
# (Run with PYTHONNOUSERSITE=1 to avoid a user-site accelerate that needs
# split_torch_state_dict_into_shards.)
def _patch_hf_hub():
    """Best-effort: make ``huggingface_hub.hf_hub_download`` accept ``use_auth_token``.

    A non-None ``use_auth_token`` is forwarded as ``token=`` unless the caller
    already passed ``token``. Does nothing if huggingface_hub (or its
    ``hf_hub_download``) is unavailable. Calling it again is a no-op, so
    wrappers never stack.
    """
    try:
        import huggingface_hub
        orig = huggingface_hub.hf_hub_download
    except (ImportError, AttributeError):
        # Nothing to patch; the shim is optional.
        return
    if getattr(orig, "_uva_use_auth_token_shim", False):
        # Already patched by a previous call.
        return

    import functools

    @functools.wraps(orig)
    def _patched(*args, use_auth_token=None, **kwargs):
        if use_auth_token is not None:
            # An explicit token= from the caller takes precedence.
            kwargs.setdefault("token", use_auth_token)
        return orig(*args, **kwargs)

    _patched._uva_use_auth_token_shim = True
    huggingface_hub.hf_hub_download = _patched

_patch_hf_hub()

import torch
import matplotlib.pyplot as plt

# Repository root: the parent of the directory containing this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Make both <repo>/src and <repo> importable; src/ is searched first.
sys.path[0:0] = [str(REPO_ROOT / "src"), str(REPO_ROOT)]

import dill
import hydra
from omegaconf import open_dict

from unified_video_action.workspace.train_unified_video_action_workspace import (
    TrainUnifiedVideoActionWorkspace,
)
from unified_video_action.utils.data_utils import decode_from_sample_autoregressive

# Short names accepted by --checkpoint, mapped to files under <repo>/checkpoints/.
# libero10 and umi are fetched on demand via scripts/download_ckpts.py;
# pusht.ckpt must already be present there.
CHECKPOINT_ALIASES = {
    "pusht": "pusht.ckpt",
    "libero10": "libero10.ckpt",
    "umi": "umi_multitask.ckpt",
}


def _ensure_checkpoint_downloaded(ckpt_path: Path) -> None:
    """Fetch a missing checkpoint with scripts/download_ckpts.py when supported.

    Only ``libero10.ckpt`` and ``umi_multitask.ckpt`` are fetched. An existing
    file, or any other file name, is left untouched. A failing download
    script raises ``subprocess.CalledProcessError`` (``check=True``).
    """
    auto_fetchable = {"libero10.ckpt", "umi_multitask.ckpt"}
    if ckpt_path.is_file() or ckpt_path.name not in auto_fetchable:
        return

    print(f"Checkpoint not found: {ckpt_path}. Running download script...")
    import subprocess

    downloader = REPO_ROOT / "scripts" / "download_ckpts.py"
    command = [sys.executable, str(downloader), "-c", ckpt_path.name]
    subprocess.run(command, check=True)


def load_policy_from_checkpoint(ckpt_path: Path, device: torch.device):
    """Load the inference policy and its config from a full checkpoint (like eval_sim.py).

    Rebuilds the workspace class named by ``cfg.model._target_`` and restores
    the whole payload into it. It then returns the workspace's EMA model when
    one exists (otherwise the plain model), moved to ``device`` and in eval mode.

    Args:
        ckpt_path: ``.ckpt`` file written with ``torch.save`` using dill.
        device: device the returned policy is moved to.

    Returns:
        ``(policy, cfg)``. ``cfg.output_dir`` is overridden to
        ``<repo>/tmp_video_sample``.
    """
    # Deserialize onto the CPU. The payload is a full checkpoint, not just
    # policy weights, so mapping it straight to `device` would put every stored
    # tensor there. load_payload copies values into the workspace's own modules,
    # and only the selected policy is moved to `device` below.
    # NOTE(security): dill unpickles arbitrary objects; only load trusted checkpoints.
    with open(ckpt_path, "rb") as f:
        payload = torch.load(f, map_location="cpu", pickle_module=dill)
    cfg = payload["cfg"]
    with open_dict(cfg):
        cfg.output_dir = str(REPO_ROOT / "tmp_video_sample")
    cls = hydra.utils.get_class(cfg.model._target_)
    workspace = cls(cfg, output_dir=cfg.output_dir)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)
    del payload  # drop the deserialized copy now that it has been loaded
    policy = workspace.ema_model if workspace.ema_model is not None else workspace.model
    policy.to(device)
    policy.eval()
    return policy, cfg


@torch.no_grad()
def sample_video_from_start_frame(policy, cfg, x0: torch.Tensor, *, latent_scale: float = 0.2325):
    """Sample a video conditioned on a single start frame (video-only mode).

    The start frame is encoded with ``policy.vae_model``. Its scaled latent is
    repeated over all ``n_frames`` as conditioning, and
    ``policy.model.sample_tokens`` runs with ``task_mode="video_model"`` and no
    action or proprioception inputs. The sampled tokens are then decoded back
    to pixels.

    Args:
        policy: loaded UVA policy exposing ``vae_model`` and ``model``.
        cfg: checkpoint config. Sampling hyper-parameters are read from
            ``cfg.model.policy.autoregressive_model_params``.
        x0: (B, 3, 256, 256) start frames in [-1, 1].
        latent_scale: factor applied to VAE latents before sampling and
            divided out before decoding (default: 0.2325).

    Returns:
        (B, T, 3, 256, 256) tensor with T = ``policy.model.n_frames``. Values
        are nominally in [-1, 1]; they are not clamped here.

    Raises:
        ValueError: if ``x0`` is not shaped (B, 3, 256, 256).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x0.ndim != 4 or tuple(x0.shape[1:]) != (3, 256, 256):
        raise ValueError(f"expected (B,3,256,256), got {tuple(x0.shape)}")
    b = x0.shape[0]

    posterior = policy.vae_model.encode(x0.float())
    z0 = posterior.sample() * latent_scale

    n_frames = int(policy.model.n_frames)
    # Every output frame is conditioned on the same start-frame latent: (B, T, ...).
    cond = z0.unsqueeze(1).repeat(1, n_frames, 1, 1, 1)

    ar_params = cfg.model.policy.autoregressive_model_params
    tokens, _ = policy.model.sample_tokens(
        bsz=b,
        cond=cond,
        text_latents=None,
        num_iter=ar_params.num_iter,
        cfg=ar_params.cfg,
        cfg_schedule=ar_params.cfg_schedule,
        temperature=ar_params.temperature,
        history_nactions=None,
        nactions=None,
        proprioception_input={},
        task_mode="video_model",
        vae_model=policy.vae_model,
    )

    pred = decode_from_sample_autoregressive(policy.vae_model, tokens / latent_scale)
    # reshape rather than view: the decoder output is not guaranteed to be contiguous.
    return pred.reshape(b, n_frames, 3, 256, 256)


def _resolve_checkpoint(spec: str, repo_root: Path) -> Path:
    """Map a checkpoint alias (see CHECKPOINT_ALIASES) or a path to an existing file.

    Aliases resolve to ``repo_root/checkpoints/<file>`` and may trigger an
    automatic download. Anything else is treated as a path, relative to the
    current working directory.

    Raises:
        FileNotFoundError: if the resolved path is not an existing file.
    """
    alias = spec.lower()
    if alias in CHECKPOINT_ALIASES:
        ckpt_path = repo_root / "checkpoints" / CHECKPOINT_ALIASES[alias]
        _ensure_checkpoint_downloaded(ckpt_path)
    else:
        ckpt_path = Path(spec)
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    return ckpt_path


def _load_start_frame(image_path: Path, video_path=None) -> torch.Tensor:
    """Return the start frame as a (1, 3, H, W) float tensor in [0, 1].

    Uses the middle frame of ``video_path`` when it is given, otherwise the
    image at ``image_path``. An alpha channel is dropped and grayscale is
    replicated to RGB.

    Raises:
        FileNotFoundError: if the selected file does not exist.
        ValueError: if the video contains no frames.
    """
    import numpy as np

    if video_path is not None:
        if not video_path.is_file():
            raise FileNotFoundError(f"Video not found: {video_path}")
        import imageio.v3 as iio

        frames = list(iio.imiter(str(video_path), plugin="FFMPEG"))
        if not frames:
            raise ValueError(f"Video has no frames: {video_path}")
        img = frames[len(frames) // 2]
    else:
        if not image_path.is_file():
            raise FileNotFoundError(f"Image not found: {image_path}")
        img = plt.imread(str(image_path))

    if img.ndim == 2:
        # Grayscale: replicate the single channel to RGB.
        img = np.repeat(img[:, :, None], 3, axis=2)
    img = np.ascontiguousarray(img[:, :, :3])  # drop alpha if present
    frame = torch.from_numpy(img).float()
    if np.issubdtype(img.dtype, np.integer):
        # plt.imread returns PNGs as floats in [0, 1], but other formats (and
        # decoded video frames) as integers. Rescale those to [0, 1].
        frame = frame / float(np.iinfo(img.dtype).max)
    return frame.permute(2, 0, 1).unsqueeze(0)


def _write_video(video: torch.Tensor, out_path: Path, fps: int = 10) -> None:
    """Write a (T, 3, H, W) tensor with values in [-1, 1] to ``out_path`` as mp4."""
    import torchvision

    frames = ((video + 1.0) / 2.0).clamp(0, 1)
    # write_video expects uint8 frames laid out as (T, H, W, C).
    frames_u8 = (frames.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu()
    torchvision.io.write_video(str(out_path), frames_u8, fps=fps)


def main():
    """Sample a video from a start frame with a UVA checkpoint and save it as mp4.

    The working directory is switched to the repository root first, so relative
    paths (checkpoint, --image, --video) resolve against it.
    """
    repo_root = REPO_ROOT
    os.chdir(repo_root)

    p = argparse.ArgumentParser(description="UVA start-frame -> video sample")
    p.add_argument("--checkpoint", "-c", type=str, default="libero10",
                   help="Checkpoint: libero10, umi, pusht, or path to .ckpt (default: libero10)")
    p.add_argument("--video", "-v", type=str, default=None,
                   help="Video whose middle frame is used as the start frame (overrides --image)")
    p.add_argument("--image", "-i", type=str, default="fake_libero_img.png",
                   help="Start-frame image, used when --video is not given "
                        "(default: fake_libero_img.png)")
    args = p.parse_args()

    # Load the start frame first so a bad path fails before any slow
    # checkpoint download or load.
    video_path = Path(args.video) if args.video else None
    frame = _load_start_frame(Path(args.image), video_path)
    frame = torch.nn.functional.interpolate(frame, size=(256, 256), mode="bilinear", align_corners=False)

    ckpt_path = _resolve_checkpoint(args.checkpoint, repo_root)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    policy, cfg = load_policy_from_checkpoint(ckpt_path, device)

    # The sampler takes frames in [-1, 1].
    x0 = (frame * 2.0 - 1.0).to(device)
    pred = sample_video_from_start_frame(policy, cfg, x0=x0)

    out_dir = repo_root / "sample_videos"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"sample_from_start_frame_{ckpt_path.stem}.mp4"
    _write_video(pred[0], out_path, fps=10)

    print(f"Saved video to {out_path}")


if __name__ == "__main__":
    main()
