"""
Test MAR (UVA KL-16) tokenizer on Droid: load video -> encode -> decode -> compare.
Mirrors VidTok/my_tokenize_and_reconstruct.py but uses the UVA VAE (kl16.ckpt).
"""

import argparse
import sys
from pathlib import Path

import numpy as np
import torch
from torchvision.utils import make_grid, save_image

# Droid dataset lives in vidgen
# parents[1] = the directory two levels above this file (the repo root).
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

# UVA VAE: add unified_video_action so we can import their vaekl
# Fail fast at import time if the sibling UVA checkout is missing, since
# every code path below depends on it.
UVA_ROOT = REPO_ROOT / "unified_video_action"
if not UVA_ROOT.is_dir():
    raise FileNotFoundError(f"UVA repo not found at {UVA_ROOT}")
sys.path.insert(0, str(UVA_ROOT))

from unified_video_action.vae.vaekl import AutoencoderKL
from omegaconf import OmegaConf

from dino_vid_model.dataset import DroidVideoDataset, collate_batch


def load_vae(ckpt_path: str, device: torch.device):
    """Instantiate the UVA KL-16 autoencoder from *ckpt_path*.

    Returns the model on *device*, in eval mode, with all parameters frozen
    (gradients disabled) — it is used purely for inference here.
    """
    cfg = OmegaConf.create({"vae_embed_dim": 16, "ch_mult": [1, 1, 2, 2, 4]})
    model = AutoencoderKL(autoencoder_path=ckpt_path, ddconfig=cfg)
    model.eval()
    model.to(device)
    # Freeze every weight: this script only encodes/decodes, never trains.
    model.requires_grad_(False)
    return model


def to_01_video(x):
    """(B, 3, T, H, W) in [-1, 1] -> (B, T, 3, H, W) in [0, 1].

    The docstring previously claimed a flattened (B*T, 3, H, W) output, but
    the function only permutes the time/channel axes and rescales; callers
    flatten for make_grid themselves.
    """
    x = x.permute(0, 2, 1, 3, 4)  # (B, T, 3, H, W)
    # Map [-1, 1] -> [0, 1]; clamp guards against decoder overshoot.
    x = ((x + 1.0) / 2.0).clamp(0, 1)
    return x


def main():
    """CLI entry point: round-trip one clip through the KL-16 VAE.

    Loads a (1, 3, T, 256, 256) clip in [-1, 1] from either --video or
    --data_root, encodes/decodes every frame independently, then writes the
    latents, MSE/PSNR, and original/recon/side-by-side PNG grids to --outdir.
    """
    p = argparse.ArgumentParser(description="MAR tokenizer test on Droid")
    p.add_argument("--data_root", type=str, default=None, help="Droid MP4 root (required unless --video)")
    p.add_argument("--video", type=str, default=None, help="Single video path (overrides data_root)")
    p.add_argument("--num_frames", type=int, default=8)
    p.add_argument("--ckpt", type=str, default=None, help="Path to kl16.ckpt (default: UVA pretrained_models/vae/kl16.ckpt)")
    p.add_argument("--outdir", type=str, default="out/tokenize_droid_mar")
    p.add_argument("--device", type=str, default="cuda")
    args = p.parse_args()

    # Default checkpoint lives inside the UVA checkout; fail early if absent.
    if args.ckpt is None:
        args.ckpt = str(UVA_ROOT / "pretrained_models" / "vae" / "kl16.ckpt")
    if not Path(args.ckpt).is_file():
        raise FileNotFoundError(f"VAE checkpoint not found: {args.ckpt}")

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    device = torch.device(args.device)

    # Load one clip: (B, 3, T, H, W) in [-1, 1]
    if args.video:
        video_path = Path(args.video)
        if not video_path.is_file():
            raise FileNotFoundError(f"Video not found: {args.video}")

        def load_video_frames(path, num_frames, size=256):
            """Read the first *num_frames* frames of *path*, resize to
            (size, size), and return a (1, 3, T, size, size) tensor in [-1, 1]."""
            import imageio.v3 as iio
            frames = [f for f in iio.imiter(str(path), plugin="FFMPEG")]
            if len(frames) < num_frames:
                raise ValueError(f"Video has {len(frames)} frames, need {num_frames}")
            # NOTE(review): takes the *leading* frames rather than sampling
            # uniformly across the clip — confirm that is the intent.
            frames = frames[:num_frames]
            x = torch.from_numpy(np.stack(frames)).float() / 255.0  # (T, H, W, C)
            x = x.permute(0, 3, 1, 2)  # (T, C, H, W)
            x = torch.nn.functional.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
            x = x.permute(1, 0, 2, 3).unsqueeze(0)  # (1, 3, T, H, W)
            # Rescale [0, 1] -> [-1, 1] to match the VAE's expected input range.
            return x * 2.0 - 1.0

        x = load_video_frames(video_path, args.num_frames)
    else:
        if not args.data_root:
            raise ValueError("Provide --data_root or --video")
        ds = DroidVideoDataset(root=args.data_root, num_frames=args.num_frames, sample_fps=4.0, size=256)
        x = collate_batch([ds[0]])  # (1, 3, T, H, W)

    x = x.to(device)

    vae = load_vae(args.ckpt, device)

    # Encode/decode per frame (same as UVA's extract_latent_autoregressive / decode)
    # Frames are flattened into the batch dim: the VAE is image-based, not video-based.
    B, C, T, H, W = x.shape
    frames_flat = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)  # (B*T, 3, 256, 256)

    with torch.no_grad():
        posterior = vae.encode(frames_flat)
        # NOTE(review): .sample() draws stochastically from the posterior, so
        # reconstructions vary run to run; .mode() would be deterministic —
        # presumably intentional to mirror UVA's usage, confirm.
        z = posterior.sample()  # (B*T, 16, 16, 16)
        x_recon = vae.decode(z)  # (B*T, 3, 256, 256)

    x_recon = x_recon.view(B, T, C, H, W).permute(0, 2, 1, 3, 4)  # (B, 3, T, H, W)

    # Save latents for inspection
    torch.save({"z": z.detach().cpu()}, outdir / "tokens.pt")
    print("Input shape:", tuple(x.shape))
    print("Latent z shape:", tuple(z.shape))
    print("Recon shape:", tuple(x_recon.shape))

    orig = to_01_video(x)      # (B, T, 3, H, W)
    recon = to_01_video(x_recon)

    # PSNR for [0, 1]-range data: 10*log10(1/MSE); the floor avoids log10(0).
    mse = torch.mean((orig - recon) ** 2).item()
    psnr = -10.0 * np.log10(max(mse, 1e-12))
    print(f"MSE:  {mse:.6f}")
    print(f"PSNR: {psnr:.3f} dB")

    # Grids: treat frames as batch dim for make_grid
    orig_flat = orig.reshape(-1, 3, H, W)
    recon_flat = recon.reshape(-1, 3, H, W)
    save_image(make_grid(orig_flat, nrow=4), outdir / "original_grid.png")
    save_image(make_grid(recon_flat, nrow=4), outdir / "recon_grid.png")
    # dim=3 concatenates along width: each tile is original | reconstruction.
    paired = torch.cat([orig_flat, recon_flat], dim=3)
    save_image(make_grid(paired, nrow=4), outdir / "comparison_grid.png")

    print(f"Saved outputs to {outdir}")


# Script entry point; no-op when imported as a module.
if __name__ == "__main__":
    main()
