"""
Generate a sample image-to-video rollout from a saved Ctrl-World img2vid checkpoint.
Loads one clip from the DROID cache, runs a zero-action rollout, and saves the predicted
video (plus the ground-truth video when --save_gt is passed).

Usage (from Ctrl-World dir, or from vidgen with correct paths):
  python scripts/sample_video_from_checkpoint.py \\
    --ckpt_path ckpt_ctrl_world_img2vid/ctrl_world_img2vid_step10500.pt \\
    --cache_dir /data/cameron/vidgen/dino_vid_model/vid_cache \\
    --svd_model_path checkpoints/stable-video-diffusion-img2vid \\
    --clip_model_path checkpoints/clip-vit-base-patch32 \\
    --output out/sample_img2vid.mp4
"""
import argparse
import pickle
import sys
from pathlib import Path

import einops
import numpy as np
import torch

# Make the Ctrl-World root (the directory above scripts/) importable so the
# `config` and `models` imports below resolve when this file is run directly.
ctrl_world_root = Path(__file__).resolve().parents[1]
if ctrl_world_root.exists():
    if str(ctrl_world_root) not in sys.path:
        sys.path.insert(0, str(ctrl_world_root))

from config import wm_args
from models.ctrl_world import CrtlWorld
from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline

# Inline minimal dataset/conversion (same as train_img2vid) to avoid import issues.
# Number of latent frames per clip (6 + 5).
# NOTE(review): presumably history + future frames; confirm it agrees with
# wm_args().num_history + wm_args().num_frames.
CTRLWORLD_NUM_LATENT_FRAMES = 11
DEFAULT_DROID_CACHE_DIR = Path("/data/cameron/vidgen") / "dino_vid_model" / "vid_cache"


class CachedClipDataset:
    """Minimal map-style dataset over cached video clips stored as ``*.pt`` tensor files.

    Each item is one clip, truncated to at most ``num_frames`` entries along
    dim 1 and given a leading batch axis of size 1. Unreadable files are skipped
    by trying the following clips (wrapping around) up to ``max_load_retries``
    times in total.

    Raises:
        FileNotFoundError: (from ``__init__``) if ``cache_dir`` has no ``*.pt`` files.
    """

    def __init__(self, cache_dir: str, num_frames: int = 11, max_load_retries: int = 5):
        cache_dir = Path(cache_dir)
        self.clips = sorted(cache_dir.glob("*.pt"))
        self.num_frames = num_frames
        self.max_load_retries = max_load_retries
        if not self.clips:
            raise FileNotFoundError(f"No .pt clips in {cache_dir}")

    def __len__(self):
        return len(self.clips)

    def _load_one(self, path):
        """Load one clip file and return it with a leading batch axis of 1.

        NOTE(review): assumes stored tensors are (C, T, H, W) or (1, C, T, H, W);
        confirm against the cache writer.
        """
        try:
            out = torch.load(path, map_location="cpu", weights_only=True)
        except (TypeError, pickle.UnpicklingError):
            # TypeError: older torch without the ``weights_only`` keyword.
            # UnpicklingError: the restricted loader rejected objects in the file.
            # SECURITY: the fallback performs a full unpickle, which can execute
            # arbitrary code. Only point this at trusted, locally produced caches.
            try:
                out = torch.load(path, map_location="cpu", weights_only=False)
            except TypeError:
                out = torch.load(path, map_location="cpu")
        if out.dim() == 5:
            out = out[0]
        out = out[:, : self.num_frames]
        return out.unsqueeze(0)

    def __getitem__(self, idx):
        """Return clip ``idx``, falling back to the next clips if it cannot be read.

        Raises:
            IndexError: if ``idx`` is out of range for the first attempt.
            RuntimeError: if ``max_load_retries`` consecutive clips fail to load
                (chained to the last load error).
        """
        last_err = None
        for attempt in range(self.max_load_retries):
            i = (idx + attempt) % len(self.clips) if attempt > 0 else idx
            try:
                return self._load_one(self.clips[i])
            except (EOFError, OSError, RuntimeError, pickle.UnpicklingError) as err:
                last_err = err
        raise RuntimeError(f"Failed to load clip at idx={idx}") from last_err


def collate_video_clips(batch):
    """Join per-sample clip tensors into one batch tensor along dim 0.

    Each sample (e.g. from ``CachedClipDataset``) already carries a leading
    batch axis, so samples are concatenated rather than stacked.
    """
    clips = batch
    merged = torch.cat(tensors=clips, dim=0)
    return merged


def video_cache_batch_to_latent(model, video, device, args):
    """Encode a batch of cached pixel clips into the 3-view latent layout used for sampling.

    Clips are padded (by repeating the last frame) or truncated to
    ``CTRLWORLD_NUM_LATENT_FRAMES`` frames. They are then resized to the model
    resolution and encoded one time step at a time with ``model.pipeline.vae``.
    The single camera view is tiled 3x along the latent height axis.

    Args:
        model: model exposing ``pipeline.vae``.
        video: (B, C, T, H, W) pixel tensor. NOTE(review): assumed already in
            the VAE's expected input range and dtype; confirm against the cache writer.
        device: device the encoding runs on.
        args: config providing ``action_dim``. Also reads ``height``/``width``
            if present (defaults 192 / 320).

    Returns:
        dict with ``"latent"`` (B, T, C_lat, 3*h, w) scaled latents,
        ``"action"`` (B, T, action_dim) zeros and ``"text"`` (B empty strings).
    """
    B, C, T, H, W = video.shape
    need = CTRLWORLD_NUM_LATENT_FRAMES
    if T < need:
        # Pad by repeating the last frame.
        pad = video[:, :, -1:].expand(-1, -1, need - T, -1, -1)
        video = torch.cat([video, pad], dim=2)
    elif T > need:
        video = video[:, :, :need]
    T = need

    # Encode at the resolution generate_and_save requests from the pipeline
    # (width=args.width, height=3 * args.height for the stacked views), so the
    # latents match it.
    h_svd = int(getattr(args, "height", 192))
    w_svd = int(getattr(args, "width", 320))

    video = video.to(device)
    if H != h_svd or W != w_svd:
        frames = video.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        frames = torch.nn.functional.interpolate(
            frames, size=(h_svd, w_svd), mode="bilinear", align_corners=False
        )
        video = frames.reshape(B, T, C, h_svd, w_svd).permute(0, 2, 1, 3, 4)

    vae = model.pipeline.vae
    scaling = vae.config.scaling_factor
    latents_one_view = []
    with torch.no_grad():
        for t in range(T):
            # Use the latent distribution's mode rather than a random sample.
            enc = vae.encode(video[:, :, t]).latent_dist.mode()
            latents_one_view.append(enc * scaling)
    lat = torch.stack(latents_one_view, dim=1)
    # Tile the single view 3x along height to fill the 3-view-stacked layout.
    lat = lat.repeat(1, 1, 1, 3, 1)
    return {
        "latent": lat,
        "action": torch.zeros(B, need, args.action_dim, device=device, dtype=lat.dtype),
        "text": [""] * B,
    }


def _decode_latent_frames(vae, latents):
    """Decode scaled VAE latents of shape (F, C, h, w) one frame at a time.

    Each frame is decoded with ``num_frames=1`` so a temporal decoder cannot
    blend content across frames. Returns a (1, F, C_out, H, W) tensor.
    """
    decoded = []
    # Sampling only: do not build an autograd graph through the VAE decoder.
    with torch.no_grad():
        for f in range(latents.shape[0]):
            frame_latent = latents[f : f + 1] / vae.config.scaling_factor
            dec = vae.decode(frame_latent, num_frames=1).sample
            if dec.dim() == 5:
                # Drop the singleton axis 2 if the decoder returns a 5-D tensor.
                dec = dec.squeeze(2)
            decoded.append(dec)
    return torch.cat(decoded, dim=0).unsqueeze(0)


def _video_to_uint8_frames(video):
    """Convert a (1, F, C, H, W) video in [-1, 1] to an (F, H, W, C) uint8 array."""
    # .float() so half/bfloat16 decoder outputs can be converted to numpy.
    frames = ((video[0].detach().float().cpu() + 1.0) / 2.0).clamp(0, 1)
    frames = frames.permute(0, 2, 3, 1).numpy()
    return (frames * 255).astype(np.uint8)


def _write_video(path, frames, fps=4):
    """Write (F, H, W, 3) uint8 frames to ``path``.

    Uses torchvision; falls back to mediapy if torchvision (or its video
    backend) raises ImportError.
    """
    path = str(path)
    try:
        import torchvision
        torchvision.io.write_video(path, torch.from_numpy(frames), fps=fps)
    except ImportError:
        import mediapy
        mediapy.write_video(path, frames, fps=fps)


def generate_and_save(model, batch, args, device, pred_path, gt_path=None, output_dir=None):
    """Run a zero-action rollout for one clip and save the predicted (and optionally GT) video.

    The first ``args.num_history`` latent frames are passed as history. The next
    latent frame is passed as the conditioning image. ``args.num_frames`` future
    latent frames are then sampled with all-zero actions. Only the first of the
    three height-stacked views is decoded and written.

    Args:
        model: model exposing ``pipeline``, ``action_encoder``, ``tokenizer``
            and ``text_encoder``.
        batch: dict with ``"latent"`` (B, T, C, 3*h, w) scaled latents and
            ``"text"`` (list of prompts); only the first sample is used.
        args: config providing ``num_history``, ``num_frames``, ``action_dim``,
            ``width``, ``height``, sampler settings and conditioning flags.
        device: device to run sampling on.
        pred_path: output path for the predicted video.
        gt_path: if given, output path for the decoded ground-truth video
            (history + future frames).
        output_dir: if given, directory that receives per-frame grid PNGs.

    Raises:
        RuntimeError: if the pipeline or decoder yields an unexpected frame count.
    """
    pipeline = model.pipeline
    vae = pipeline.vae
    num_history = args.num_history
    num_frames = args.num_frames

    video_gt = batch["latent"][:1].to(device, non_blocking=True)
    texts = [batch["text"][0]]
    actions = torch.zeros(
        1, int(num_frames + num_history), args.action_dim,
        device=device, dtype=video_gt.dtype,
    )

    his_latent_gt = video_gt[:, :num_history]
    future_latent_ft = video_gt[:, num_history:]
    current_latent = future_latent_ft[:, 0]

    with torch.no_grad():
        action_latent = model.action_encoder(
            actions, texts, model.tokenizer, model.text_encoder,
            frame_level_cond=args.frame_level_cond,
        )
        _, pred_latents = CtrlWorldDiffusionPipeline.__call__(
            pipeline,
            image=current_latent,
            text=action_latent,
            width=args.width,
            height=int(3 * args.height),
            num_frames=num_frames,
            history=his_latent_gt,
            num_inference_steps=args.num_inference_steps,
            decode_chunk_size=args.decode_chunk_size,
            max_guidance_scale=args.guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            mask=None,
            output_type="latent",
            return_dict=False,
            frame_level_cond=args.frame_level_cond,
            his_cond_zero=args.his_cond_zero,
        )

    # pred_latents: (1, num_frames, C, 3*h, w) with three views stacked along height.
    if pred_latents.shape[1] != num_frames:
        raise RuntimeError(
            f"Pipeline returned {pred_latents.shape[1]} frames, expected num_frames={num_frames}"
        )
    print(f"Model predicts {num_frames} future frames (shape {pred_latents.shape}).")
    # Diagnostic: do the predicted latent frames actually differ? If they are
    # all the same, the model predicts "no motion" for zero actions.
    with torch.no_grad():
        p = pred_latents[0]
        for f in range(num_frames):
            diff = (p[f] - p[0]).abs().mean().item()
            print(f"  Pred latent frame {f}: mean={p[f].mean().item():.4f}, std={p[f].std().item():.4f}, mean |diff from frame 0|={diff:.4f}")

    # Split the stacked views: (b, f, c, 3*h, w) -> (b*3, f, c, h, w); index 0 = first view.
    pred_views = einops.rearrange(
        pred_latents, "b f c (m h) (n w) -> (b m n) f c h w", m=3, n=1
    )
    pred_video = _decode_latent_frames(vae, pred_views[0])

    n_pred = pred_video.shape[1]
    if n_pred != num_frames:
        raise RuntimeError(f"Decoded {n_pred} pred frames, expected {num_frames}")

    pred_frames = _video_to_uint8_frames(pred_video)
    _write_video(pred_path, pred_frames, fps=4)
    print(f"Saved pred video: {pred_path} ({n_pred} frames)")

    if output_dir is not None:
        _save_frame_grid(
            pred_frames,
            output_dir,
            Path(pred_path).stem + "_pred_grid.png",
            title_prefix="Pred",
        )

    if gt_path:
        # GT = history + future latents (all of video_gt). Decode it frame by
        # frame like the prediction so the two videos are directly comparable.
        gt_views = einops.rearrange(
            video_gt, "b f c (m h) (n w) -> (b m n) f c h w", m=3, n=1
        )
        gt_video = _decode_latent_frames(vae, gt_views[0])
        gt_frames = _video_to_uint8_frames(gt_video)
        _write_video(gt_path, gt_frames, fps=4)
        print(f"Saved GT video: {gt_path} ({gt_frames.shape[0]} frames)")
        if output_dir is not None:
            _save_frame_grid(
                gt_frames,
                output_dir,
                Path(gt_path).stem.replace("_gt", "_gt_grid") + ".png",
                title_prefix="GT",
            )


def _save_frame_grid(frames_uint8, output_dir, filename, title_prefix="Frame"):
    """Save frames as a single row image grid PNG. frames_uint8: (N, H, W, 3) uint8."""
    try:
        from PIL import Image
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("Skipping grid PNG (PIL/matplotlib not available)")
        return
    n = frames_uint8.shape[0]
    h, w = frames_uint8.shape[1], frames_uint8.shape[2]
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2))
    if n == 1:
        axes = [axes]
    for i in range(n):
        axes[i].imshow(frames_uint8[i])
        axes[i].set_title(f"{title_prefix} {i}")
        axes[i].axis("off")
    plt.tight_layout()
    grid_path = Path(output_dir) / filename
    plt.savefig(grid_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved frame grid: {grid_path}")


def main():
    """Command-line entry point: load a checkpoint, encode one cached clip, and sample a video.

    Writes the predicted video to ``--output`` (plus a frame-grid PNG in the
    same folder). With ``--save_gt``, it also writes the decoded ground-truth
    video as ``<output stem>_gt.mp4``.
    """
    parser = argparse.ArgumentParser(description="Sample image-to-video from Ctrl-World img2vid checkpoint")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to ctrl_world_img2vid_step*.pt")
    parser.add_argument("--cache_dir", type=Path, default=DEFAULT_DROID_CACHE_DIR, help="DROID cache with .pt clips")
    parser.add_argument("--svd_model_path", type=str, default=None)
    parser.add_argument("--clip_model_path", type=str, default=None)
    parser.add_argument("--output", type=Path, default=Path("out/sample_img2vid.mp4"), help="Output pred video path")
    parser.add_argument("--save_gt", action="store_true", help="Also save ground-truth video next to output")
    parser.add_argument("--clip_idx", type=int, default=0, help="Index of clip in cache to use")
    parser.add_argument("--device", type=str, default="cuda")
    args_parse = parser.parse_args()

    args = wm_args()
    if args_parse.svd_model_path is not None:
        args.svd_model_path = args_parse.svd_model_path
    if args_parse.clip_model_path is not None:
        args.clip_model_path = args_parse.clip_model_path

    # Falls back to CPU whenever CUDA is unavailable, regardless of --device.
    device = torch.device(args_parse.device if torch.cuda.is_available() else "cpu")
    args.ckpt_dir = Path(args_parse.ckpt_path).parent

    print(f"Loading checkpoint: {args_parse.ckpt_path}")
    model = CrtlWorld(args)
    state = torch.load(args_parse.ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state, strict=True)
    model.to(device)
    model.eval()

    dataset = CachedClipDataset(str(args_parse.cache_dir), num_frames=CTRLWORLD_NUM_LATENT_FRAMES)
    clip = dataset[args_parse.clip_idx]
    batch = collate_video_clips([clip])

    # Inference only: no autograd graph is needed for encoding, sampling or decoding.
    with torch.no_grad():
        batch = video_cache_batch_to_latent(model, batch.to(device), device, args)

        out_path = Path(args_parse.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        gt_path = None
        if args_parse.save_gt:
            # Pass a str to the video writers, like the predicted-video path below.
            gt_path = str(out_path.parent / (out_path.stem + "_gt.mp4"))

        generate_and_save(
            model, batch, args, device, str(out_path), gt_path,
            output_dir=out_path.parent,
        )


if __name__ == "__main__":
    main()
