"""Evaluate video generation model on LIBERO episodes.

For each episode, stride through frames (every 4 steps by default; see
--stride).  At each start frame, run the video model (conditioned on that
single frame) to predict NUM_FRAMES frames, then save matching GT and PRED
frame-strip PNGs for comparison.

Output layout:
  libero/vid_evals/<task>/<demo>/GT_start_frame_XXXX.png
  libero/vid_evals/<task>/<demo>/PRED_start_frame_XXXX.png
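
Example invocation (assuming this file is saved as libero/eval_video.py; all
other flags fall back to the argparse defaults below):
  python libero/eval_video.py --max-episodes 2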
"""

import argparse
import sys
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn.functional as F
import cv2
from torchvision.io import read_image
from tqdm import tqdm

# Add video_training codebase to path
VIDEO_TRAINING_ROOT = Path(__file__).resolve().parent.parent / "video_training" / "unified_video_action"
sys.path.insert(0, str(VIDEO_TRAINING_ROOT))

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only
from simple_uva.dataset import NUM_FRAMES

LATENT_SCALE = 0.2325  # KL-16 VAE latent scaling factor
MODEL_IMG_SIZE = 256  # video model operates at 256x256


def build_vae(vae_ckpt: str, device: torch.device):
    ckpt_path = Path(vae_ckpt)
    if not ckpt_path.is_absolute():
        ckpt_path = VIDEO_TRAINING_ROOT / vae_ckpt
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    if not ckpt_path.exists():
        print(f"WARNING: VAE checkpoint not found at {ckpt_path}; using randomly initialized VAE weights")
    vae = AutoencoderKL(autoencoder_path=str(ckpt_path) if ckpt_path.exists() else None, ddconfig=ddconfig)
    vae.to(device).eval()
    for p in vae.parameters():
        p.requires_grad = False
    return vae


def load_model(checkpoint: str, device: torch.device):
    model = mar_base_video_only(
        img_size=MODEL_IMG_SIZE,
        vae_stride=16,
        patch_size=1,
        vae_embed_dim=16,
        num_sampling_steps="100",
        diffloss_d=6,
        diffloss_w=1024,
    ).to(device)

    ckpt_path = Path(checkpoint)
    if not ckpt_path.is_absolute():
        ckpt_path = VIDEO_TRAINING_ROOT / checkpoint

    # Some training runs pickle checkpoints with dill; try that first, then
    # fall back to plain torch.load.
    try:
        import dill
        payload = torch.load(ckpt_path, map_location=device, pickle_module=dill)
    except Exception:
        payload = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Handle both checkpoint formats
    if "state_dicts" in payload:
        # libero10.ckpt format: state_dicts.ema_model with "model." prefix
        sd = payload["state_dicts"].get("ema_model") or payload["state_dicts"].get("model")
        if sd is None:
            raise KeyError(f"state_dicts has neither 'ema_model' nor 'model': {list(payload['state_dicts'].keys())}")
        model_sd = {k[len("model."):]: v for k, v in sd.items() if k.startswith("model.")}
    elif "model" in payload:
        # simple_uva training format: direct model state dict
        model_sd = payload["model"]
    else:
        raise KeyError(f"Unrecognized checkpoint format: {list(payload.keys())}")

    missing, unexpected = model.load_state_dict(model_sd, strict=False)
    if missing or unexpected:
        print(f"load_state_dict: {len(missing)} missing / {len(unexpected)} unexpected keys")
    model.eval()
    print(f"Loaded model from {ckpt_path}")
    return model


def load_frame(path: Path, size: int = MODEL_IMG_SIZE) -> torch.Tensor:
    """Load a single PNG frame, resize to (size, size), return (C, H, W) in [-1, 1]."""
    img = read_image(str(path)).float() / 255.0  # (C, H, W) in [0, 1]
    if img.shape[0] == 4:
        img = img[:3]
    img = F.interpolate(img.unsqueeze(0), size=(size, size), mode="bilinear", align_corners=False).squeeze(0)
    return img * 2.0 - 1.0


def tensor_to_uint8(t: torch.Tensor) -> np.ndarray:
    """(C, H, W) in [-1, 1] → (H, W, C) uint8."""
    return ((t.clamp(-1, 1) + 1.0) / 2.0 * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)


@torch.no_grad()
def generate_video(model, vae, start_frame: torch.Tensor, device: torch.device, num_iter: int = 64) -> torch.Tensor:
    """Given a start frame (C, H, W) in [-1,1], generate NUM_FRAMES predicted frames.

    Returns: (NUM_FRAMES, C, H, W) in [-1, 1].
    """
    frame = start_frame.unsqueeze(0).to(device)  # (1, C, H, W)
    posterior = vae.encode(frame.float())
    z0 = posterior.sample() * LATENT_SCALE
    # Condition every generated frame on the same start-frame latent.
    cond = z0.unsqueeze(1).expand(1, NUM_FRAMES, -1, -1, -1)  # (1, T, C_lat, H_lat, W_lat)
    tokens, _ = model.sample_tokens(
        bsz=1,
        cond=cond,
        num_iter=num_iter,
        cfg=1.0,
        temperature=0.95,
    )
    pred = vae.decode(tokens / LATENT_SCALE)  # (T, C, H, W)
    return pred.clamp(-1, 1)


def find_episodes(data_root: Path, max_episodes: int | None = None):
    """Return list of (task_name, demo_name, sorted_frame_paths)."""
    episodes = []
    for task_dir in sorted(data_root.iterdir()):
        if not task_dir.is_dir():
            continue
        for demo_dir in sorted(task_dir.iterdir()):
            if not demo_dir.is_dir():
                continue
            frames_dir = demo_dir / "frames"
            if not frames_dir.is_dir():
                continue
            pngs = sorted(frames_dir.glob("*.png"))
            if len(pngs) >= NUM_FRAMES:
                episodes.append((task_dir.name, demo_dir.name, pngs))
    if max_episodes is not None:
        episodes = episodes[:max_episodes]
    return episodes


def make_grid(frames: list[np.ndarray], pad: int = 2) -> np.ndarray:
    """Horizontally concatenate frames with padding."""
    h = frames[0].shape[0]
    strips = []
    for i, f in enumerate(frames):
        if i > 0:
            strips.append(np.full((h, pad, 3), 255, dtype=np.uint8))
        strips.append(f)
    return np.concatenate(strips, axis=1)


def main():
    p = argparse.ArgumentParser(description="Evaluate video generation on LIBERO episodes")
    p.add_argument("--checkpoint", type=str, default="checkpoints/simple_uva_libero_latest.pt",
                    help="Path to video model checkpoint (relative to video_training/unified_video_action/ or absolute)")
    p.add_argument("--vae-ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    p.add_argument("--data-root", type=str, default="/data/libero/parsed_libero/libero_spatial")
    p.add_argument("--out-dir", type=str, default=str(Path(__file__).resolve().parent / "vid_evals"))
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--num-iter", type=int, default=64, help="Diffusion sampling steps")
    p.add_argument("--max-episodes", type=int, default=None, help="Limit number of episodes to evaluate")
    p.add_argument("--stride", type=int, default=4, help="Frame stride between start frames (default: 4 = video pred length)")
    args = p.parse_args()

    device = torch.device(args.device)
    out_root = Path(args.out_dir)

    vae = build_vae(args.vae_ckpt, device)
    model = load_model(args.checkpoint, device)

    episodes = find_episodes(Path(args.data_root), args.max_episodes)
    print(f"Found {len(episodes)} episodes")

    for task_name, demo_name, frame_paths in tqdm(episodes, desc="episodes"):
        ep_out = out_root / task_name / demo_name
        ep_out.mkdir(parents=True, exist_ok=True)

        n_frames = len(frame_paths)
        start_indices = list(range(0, n_frames - NUM_FRAMES + 1, args.stride))

        for start_idx in start_indices:
            # Load GT frames covering the prediction window
            gt_frames = [load_frame(frame_paths[start_idx + i]) for i in range(NUM_FRAMES)]

            # Generate predicted frames from start frame
            pred_frames = generate_video(model, vae, gt_frames[0], device, num_iter=args.num_iter)

            # Convert to uint8 images
            gt_imgs = [tensor_to_uint8(f) for f in gt_frames]
            pred_imgs = [tensor_to_uint8(pred_frames[i]) for i in range(NUM_FRAMES)]

            # Save GT grid and PRED grid
            gt_grid = make_grid(gt_imgs)
            pred_grid = make_grid(pred_imgs)

            gt_path = ep_out / f"GT_start_frame_{start_idx:04d}.png"
            pred_path = ep_out / f"PRED_start_frame_{start_idx:04d}.png"
            cv2.imwrite(str(gt_path), cv2.cvtColor(gt_grid, cv2.COLOR_RGB2BGR))
            cv2.imwrite(str(pred_path), cv2.cvtColor(pred_grid, cv2.COLOR_RGB2BGR))

    print(f"Done. Results saved to {out_root}")


if __name__ == "__main__":
    main()
