"""
Minimal UVA-style video prediction training script on the Droid dataset.

This script:
- Loads DroidVideoDataset from dino_vid_model.dataset.
- Uses a MAR-style VAE (via MARVAEWrapper) to tokenize 256x256 RGB frames.
- Trains a UVAVideoTransformer to predict future-frame latents conditioned
  only on the first-frame latents; all future frames are masked out at every
  training step.
- Decodes the predicted latents back to RGB and computes an L2 (MSE) loss in
  pixel space against the ground-truth future frames.

IMPORTANT:
- You must clone and install the MAR repo and download its VAE checkpoint.
- Pass the MAR VAE checkpoint path with --mar-ckpt.
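
Example (the script name and paths below are placeholders):
    python train_uva_droid.py \
        --data-root /path/to/droid_raw/1.0.1 \
        --mar-ckpt /path/to/kl16.ckpt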
"""

import argparse
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from dino_vid_model.dataset import DroidVideoDataset, collate_batch
from uva.mar_tokenizer import MARVAEWrapper
from uva.model import UVAConfig, UVAVideoTransformer


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--data-root", type=str, default="/data/weiduoyuan/droid_raw/1.0.1")
    p.add_argument("--num-frames", type=int, default=8)
    p.add_argument("--sample-fps", type=float, default=4.0)
    p.add_argument("--size", type=int, default=256)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--workers", type=int, default=8)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--steps", type=int, default=10000)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--mar-ckpt", type=str, required=True, help="Path to MAR VAE checkpoint (e.g. kl16.ckpt)")
    p.add_argument("--save-dir", type=str, default="uva_checkpoints")
    p.add_argument("--save-every", type=int, default=1000)
    return p.parse_args()


def main():
    args = parse_args()
    device = torch.device(args.device)

    # MAR VAE tokenizer
    mar_vae = MARVAEWrapper(args.mar_ckpt, device=device)
    c_l, h_l, w_l = mar_vae.latent_shape
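    # The VAE is used purely as a fixed tokenizer. If MARVAEWrapper is an
    # nn.Module (an assumption about the wrapper), freeze it explicitly so
    # loss.backward() does not accumulate unused gradients in its decoder.
    if isinstance(mar_vae, nn.Module):
        mar_vae.requires_grad_(False)
        mar_vae.eval()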

    # UVA video transformer
    cfg = UVAConfig(
        latent_channels=c_l,
        latent_height=h_l,
        latent_width=w_l,
        num_frames=args.num_frames,
    )
    model = UVAVideoTransformer(cfg).to(device)
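    # Only the UVA transformer's parameters are optimized; the VAE stays frozen.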
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    # Droid dataset and loader
    dataset = DroidVideoDataset(
        root=args.data_root,
        num_frames=args.num_frames,
        sample_fps=args.sample_fps,
        size=args.size,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=args.workers > 0,
    )

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    step = 0
    pbar = tqdm(total=args.steps, desc="steps", unit="step")

    mse = nn.MSELoss()

    while step < args.steps:
        for x in loader:
            if step >= args.steps:
                break
            x = x.to(device, non_blocking=True)  # (B, 3, T, H, W) in [-1,1]

            # Encode all frames into MAR latent space
            b, c, t, h, w = x.shape
            assert t == args.num_frames
            frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)  # (B*T, 3, H, W)
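            # Encoding runs under no_grad: the VAE encoder is never trained.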
            with torch.no_grad():
                z_all = mar_vae.encode(frames)  # (B*T, C_l, H_l, W_l)
            z_all = z_all.view(b, t, c_l, h_l, w_l)
            z_first = z_all[:, 0]          # (B, C_l, H_l, W_l)
            z_future_gt = z_all[:, 1:]     # (B, T-1, C_l, H_l, W_l)

            model.train()
            # The second output (target tokens) is unused here: the loss is
            # computed in pixel space below.
            pred_future, _target_tokens = model(z_first, z_future=z_future_gt)
            # Decode predicted latents to RGB and compute a pixel-space L2 loss
            # against the GT future frames. decode() is not wrapped in no_grad,
            # so the loss backpropagates through the VAE decoder into the
            # transformer (only the transformer's parameters are updated).
            pred_frames = mar_vae.decode(
                pred_future.reshape(b * (t - 1), c_l, h_l, w_l)
            ).view(b, t - 1, 3, args.size, args.size)
            gt_frames = x.permute(0, 2, 1, 3, 4)[:, 1:]  # (B, T-1, 3, H, W)
            loss = mse(pred_frames, gt_frames)

            opt.zero_grad()
            loss.backward()
            opt.step()

            step += 1
            pbar.update(1)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

            if step % args.save_every == 0:
                ckpt_path = save_dir / f"uva_step_{step}.pt"
                torch.save(
                    {
                        "step": step,
                        "model": model.state_dict(),
                        "optimizer": opt.state_dict(),
                        "config": cfg.__dict__,
                    },
                    ckpt_path,
                )

    pbar.close()
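
    # Save a final checkpoint as well, so the last state is kept even when
    # args.steps is not a multiple of args.save_every.
    torch.save(
        {
            "step": step,
            "model": model.state_dict(),
            "optimizer": opt.state_dict(),
            "config": cfg.__dict__,
        },
        save_dir / f"uva_step_{step}.pt",
    )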


if __name__ == "__main__":
    main()

