"""Train DinoVideoModel (3-step diffusion) on LIBERO parsed data.

Predicts future VAE latent tokens from a single DINO-encoded frame via diffusion.
Training: random noise level per sample, predict noise residual (MSE loss).
Visualization: 3-step reverse diffusion to generate predicted video frames.
"""

import argparse
import sys
from pathlib import Path
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm

# Path setup for UVA imports
UVA_ROOT = Path(__file__).resolve().parent.parent / "unified_video_action"
sys.path.insert(0, str(UVA_ROOT))

from simple_uva.vae import AutoencoderKL
from simple_uva.dataset import LiberoVideoDataset, collate_batch

from model import DinoVideoModel

# Scale factor applied to VAE latents before diffusion (the standard value used
# with the kl16 checkpoint)
LATENT_SCALE = 0.2325
# Number of video frames per training clip
N_FRAMES = 6


def build_vae(vae_ckpt: str, device: torch.device):
    """Build and freeze the VAE for target encoding."""
    ckpt_path = Path(vae_ckpt) if Path(vae_ckpt).is_absolute() else UVA_ROOT / vae_ckpt
    if not ckpt_path.exists():
        print(f"WARNING: VAE checkpoint not found at {ckpt_path}; using random weights")
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    vae = AutoencoderKL(autoencoder_path=str(ckpt_path) if ckpt_path.exists() else None, ddconfig=ddconfig)
    vae.to(device)
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False
    return vae


def to_imagenet_normalized(frame, device):
    """Convert frame from [-1, 1] to ImageNet-normalized [0, 1] range."""
    frame_01 = (frame + 1.0) / 2.0
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    return (frame_01 - mean) / std
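

def to_uint8_frames(frames: torch.Tensor):
    """Convert (T, 3, H, W) frames in [-1, 1] to a (T, H, W, 3) uint8 numpy array.

    Helper introduced here (not part of the original UVA code) to deduplicate the
    frame conversion in the visualization block below.
    """
    frames_01 = ((frames.detach().cpu() + 1.0) / 2.0).clamp(0, 1)
    return (frames_01.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")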


def main():
    p = argparse.ArgumentParser(description="Train DinoVideoModel (diffusion) on LIBERO")
    p.add_argument("--data-root", type=str, default="/data/libero/parsed_libero/libero_spatial")
    p.add_argument("--vae-ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=500)
    p.add_argument("--workers", type=int, default=8)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--run-name", type=str, default="dino_video_diffusion_libero")
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--vis-every", type=int, default=200)
    p.add_argument("--checkpoint-every", type=int, default=1000)
    p.add_argument("--checkpoint-dir", type=str, default="video_training/custom_dino_video/checkpoints")
    p.add_argument("--frame-stride", type=int, default=3)
    p.add_argument("--resume", type=str, default="", help="Path to checkpoint to resume from")
    args = p.parse_args()

    device = torch.device(args.device)

    if args.log_wandb:
        import wandb
        wandb.init(project="dino_video", config=vars(args), name=args.run_name, mode="online")

    # Build models
    vae = build_vae(args.vae_ckpt, device)
    model = DinoVideoModel(n_frames=N_FRAMES, vae_embed_dim=16).to(device)

    # Only optimize non-frozen parameters (DINO is frozen)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, lr=args.lr)
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

    global_step = 0
    if args.resume and Path(args.resume).exists():
        ckpt = torch.load(args.resume, map_location=device)
        model.load_state_dict(ckpt["model"], strict=False)
        if "optimizer" in ckpt:
            try:
                opt.load_state_dict(ckpt["optimizer"])
            except Exception as e:
                print(f"Could not load optimizer state: {e}")
        global_step = ckpt.get("step", 0)
        print(f"Resumed from {args.resume} at step {global_step}")

    # Dataset
    dataset = LiberoVideoDataset(
        root=args.data_root,
        num_frames=N_FRAMES,
        size=256,
        frame_stride=args.frame_stride,
    )
    print(f"Dataset: {len(dataset)} episodes")
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )

    for epoch in range(args.epochs):
        pbar = tqdm(loader, desc=f"epoch {epoch}", unit="batch")
        for batch in pbar:
            # batch: (B, 3, T=6, 256, 256) in [-1, 1]
            batch = batch.to(device, non_blocking=True)
            B, C, T, H, W = batch.shape

            # --- First frame through DINO (conditioning) ---
            first_frame = batch[:, :, 0]  # (B, 3, 256, 256) in [-1, 1]
            first_frame_inet = to_imagenet_normalized(first_frame, device)

            # --- Target: all T frames through frozen VAE ---
            with torch.no_grad():
                frames_flat = rearrange(batch, "b c t h w -> (b t) c h w")
                posterior = vae.encode(frames_flat.float())
                z_0 = posterior.sample() * LATENT_SCALE  # (B*T, 16, 16, 16)
                z_0 = rearrange(z_0, "(b t) c h w -> b c t h w", b=B)  # (B, 16, T, 16, 16)

            # --- Diffusion training: predict clean x₀ ---
            x0_pred, x0_target, t = model(first_frame_inet, z_0)
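            # forward() presumably draws a random timestep t per sample, noises z_0 to
            # z_t, and returns the model's x0 estimate alongside the clean target
            # (see model.py for the real code)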
            loss = torch.nn.functional.mse_loss(x0_pred, x0_target)

            opt.zero_grad()
            loss.backward()
            opt.step()

            pbar.set_postfix(loss=f"{loss.item():.4f}", step=global_step)

            if args.log_wandb:
                wandb.log({
                    "train/loss": loss.item(),
                    "train/t_mean": t.float().mean().item(),
                }, step=global_step)

            # --- Visualization: run 3-step reverse diffusion ---
            if global_step % args.vis_every == 0 and args.log_wandb:
                import tempfile
                import numpy as np
                import torchvision

                with torch.no_grad():
                    # Sample via 3-step reverse diffusion
                    pred_tokens = model.sample(first_frame_inet[:1])  # (1, 16, T, 16, 16)
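                    # sample() presumably starts from Gaussian noise and runs the
                    # 3-step reverse chain conditioned on the DINO features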

                    # Decode predicted tokens through VAE
                    pred_flat = rearrange(pred_tokens, "b c t h w -> (b t) c h w")
                    decoded = vae.decode(pred_flat / LATENT_SCALE)  # (T, 3, 256, 256) in [-1, 1]
                    decoded = decoded.view(1, T, 3, 256, 256)

                    # Also decode the GT for comparison
                    gt_flat = rearrange(z_0[:1], "b c t h w -> (b t) c h w")
                    gt_decoded = vae.decode(gt_flat / LATENT_SCALE)
                    gt_decoded = gt_decoded.view(1, T, 3, 256, 256)

                # Convert predicted and GT videos to (T, H, W, 3) uint8 arrays
                pred_frames = to_uint8_frames(decoded[0])
                gt_frames = to_uint8_frames(gt_decoded[0])

                vis_log = {}
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                    tmp_pred = f.name
                torchvision.io.write_video(tmp_pred, torch.from_numpy(pred_frames), fps=4)
                vis_log["vis/predicted_video"] = wandb.Video(tmp_pred, format="mp4")

                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                    tmp_gt = f.name
                torchvision.io.write_video(tmp_gt, torch.from_numpy(gt_frames), fps=4)
                vis_log["vis/gt_video"] = wandb.Video(tmp_gt, format="mp4")

                # Strip of all T frames laid out horizontally, predicted stacked above GT
                strip_pred = np.concatenate([pred_frames[ti] for ti in range(T)], axis=1)
                strip_gt = np.concatenate([gt_frames[ti] for ti in range(T)], axis=1)
                strip = np.concatenate([strip_pred, strip_gt], axis=0)
                vis_log["vis/pred_vs_gt_strip"] = wandb.Image(strip, caption="Top: predicted, Bottom: GT")

                wandb.log(vis_log, step=global_step)
                Path(tmp_pred).unlink(missing_ok=True)
                Path(tmp_gt).unlink(missing_ok=True)

            # --- Checkpoint ---
            if global_step > 0 and global_step % args.checkpoint_every == 0:
                ckpt_dir = Path(args.checkpoint_dir)
                ckpt_dir.mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        "step": global_step,
                        "model": model.state_dict(),
                        "optimizer": opt.state_dict(),
                    },
                    ckpt_dir / "latest.pt",
                )
                print(f"Saved checkpoint at step {global_step}")

            global_step += 1
        print(f"Epoch {epoch} done, global_step={global_step}")

    if args.log_wandb:
        wandb.finish()


if __name__ == "__main__":
    main()
