"""Train simple_uva (MAR + VAE) on DROID clips. Optional wandb; vis = denoising video."""

import argparse
import sys
import time
from pathlib import Path
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm

# Make the repo root importable so `simple_uva` resolves when this file is run
# as a standalone script (python scripts/train.py) rather than as a module.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only
from simple_uva.dataset import CachedClipDataset, DroidVideoDataset, collate_batch, NUM_FRAMES

LATENT_SCALE = 0.2325


def build_vae(repo_root: Path, vae_ckpt: str, device: torch.device):
    """Build the frozen KL-16 VAE used to encode/decode frames.

    Args:
        repo_root: Repository root; relative ``vae_ckpt`` paths resolve against it.
        vae_ckpt: Checkpoint path, absolute or relative to ``repo_root``.
        device: Device to place the VAE on.

    Returns:
        The VAE in eval mode with every parameter frozen (requires_grad=False).
    """
    ckpt_path = Path(vae_ckpt) if Path(vae_ckpt).is_absolute() else repo_root / vae_ckpt
    if not ckpt_path.exists():
        # Previously a missing checkpoint fell through silently and training ran
        # on a randomly initialized VAE -- warn loudly so a bad path is noticed.
        print(f"WARNING: VAE checkpoint not found at {ckpt_path}; using random init")
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    vae = AutoencoderKL(
        autoencoder_path=str(ckpt_path) if ckpt_path.exists() else None,
        ddconfig=ddconfig,
    )
    vae.to(device)
    vae.eval()
    # Freeze: the VAE is a fixed tokenizer, only the MAR model is trained.
    for p in vae.parameters():
        p.requires_grad = False
    return vae


def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the training run."""
    p = argparse.ArgumentParser(description="Train simple_uva on DROID")
    p.add_argument("--data-root", type=str, default="/data/weiduoyuan/droid_raw/1.0.1", help="DROID MP4 root")
    p.add_argument("--cache-dir", type=str, default=None, help="Pre-extracted .pt clips (e.g. ../dino_vid_model/vid_cache)")
    p.add_argument("--vae-ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    p.add_argument("--our_checkpoint", type=str, default=None)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--workers", type=int, default=8)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--steps", type=int, default=10000)
    p.add_argument("--vis-every", type=int, default=100)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--checkpoint-dir", type=str, default="checkpoints")
    p.add_argument("--checkpoint-every", type=int, default=1000)
    p.add_argument("--num-iter", type=int, default=64, help="Sampling steps for vis")
    p.add_argument("--name", type=str, default="simple_uva")
    p.add_argument("--pretrain-mar", type=str, default=None, help="Path to checkpoint to load MAR (and optionally VAE) from, e.g. checkpoints/libero10.ckpt")
    return p.parse_args()


def _make_loader(dataset, args: argparse.Namespace) -> DataLoader:
    """Build a shuffled DataLoader with the script's standard settings."""
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )


def _load_checkpoint_payload(path: Path, device: torch.device):
    """torch.load a checkpoint, trying dill first (older runs pickled with it),
    then falling back to the plain pickle path.

    Importing dill here fixes a NameError: previously `dill` was imported only
    in the --pretrain-mar branch but also used in the --our_checkpoint branch.
    """
    import dill
    try:
        return torch.load(path, map_location=device, pickle_module=dill)
    except Exception:
        return torch.load(path, map_location=device, weights_only=False)


def _load_pretrain_mar(model, vae, ckpt_path: Path, device: torch.device) -> None:
    """Load MAR weights (and VAE weights, if present) from a pretrain checkpoint.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
        KeyError: if the payload holds no usable state_dict.
    """
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Pretrain checkpoint not found: {ckpt_path}")
    payload = _load_checkpoint_payload(ckpt_path, device)
    state_dicts = payload.get("state_dicts") or {}
    # Prefer the EMA weights when available.
    sd = state_dicts.get("ema_model") or state_dicts.get("model")
    if sd is None:
        raise KeyError(f"No state_dict in {ckpt_path}")
    # Keys are prefixed "model." / "vae_model." in the pretrain payload.
    model_sd = {k[len("model."):]: v for k, v in sd.items() if k.startswith("model.")}
    model.load_state_dict(model_sd, strict=False)
    if any(k.startswith("vae_model.") for k in sd):
        vae_sd = {k[len("vae_model."):]: v for k, v in sd.items() if k.startswith("vae_model.")}
        vae.load_state_dict(vae_sd, strict=False)
    print(f"Loaded MAR (and VAE if present) from {ckpt_path}")


@torch.no_grad()
def _log_denoising_video(model, vae, batch, args: argparse.Namespace, global_step: int) -> None:
    """Sample a clip conditioned on the batch's first frame and log it to wandb."""
    import tempfile
    import torchvision
    import wandb

    first_frame = batch[:1, :, 0].clone()  # (1, 3, 256, 256)
    posterior0 = vae.encode(first_frame.float())
    z0 = posterior0.sample() * LATENT_SCALE
    cond = z0.unsqueeze(1).expand(1, NUM_FRAMES, -1, -1, -1)
    tokens, _ = model.sample_tokens(
        bsz=1,
        cond=cond,
        num_iter=args.num_iter,
        cfg=1.0,
        temperature=0.95,
    )
    pred = vae.decode(tokens / LATENT_SCALE).view(1, NUM_FRAMES, 3, 256, 256)
    pred_np = ((pred[0].cpu() + 1.0) / 2.0).clamp(0, 1)
    # Write to a temp file so wandb.Video(path) works without moviepy.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        tmp_path = f.name
    try:
        frames_np = (pred_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
        torchvision.io.write_video(tmp_path, torch.from_numpy(frames_np), fps=4)
        wandb.log(
            {"vis/denoising_video": wandb.Video(tmp_path, format="mp4")},
            step=global_step,
        )
    finally:
        # Clean up even if encoding/logging fails.
        Path(tmp_path).unlink(missing_ok=True)


def _save_checkpoint(model, opt, global_step: int, args: argparse.Namespace) -> None:
    """Save model + optimizer state to <checkpoint-dir>/<name>_latest.pt."""
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "step": global_step,
            "model": model.state_dict(),
            "optimizer": opt.state_dict(),
        },
        ckpt_dir / f"{args.name}_latest.pt",
    )
    print(f"Saved {args.name}_latest.pt")


def main():
    """Entry point: build the frozen VAE + MAR model, stream DROID clips, train."""
    args = _parse_args()
    device = torch.device(args.device)

    if args.log_wandb:
        import wandb
        wandb.init(project="simple_uva", config=vars(args), name=args.name, mode="online")

    vae = build_vae(REPO_ROOT, args.vae_ckpt, device)
    model = mar_base_video_only(
        img_size=256,
        vae_stride=16,
        patch_size=1,
        vae_embed_dim=16,
        num_sampling_steps="100",
        diffloss_d=6,
        diffloss_w=1024,
    ).to(device)

    if args.pretrain_mar:
        _load_pretrain_mar(model, vae, Path(args.pretrain_mar), device)
    if args.our_checkpoint:
        our_checkpoint = Path(args.our_checkpoint)
        if not our_checkpoint.is_file():
            raise FileNotFoundError(f"Our checkpoint not found: {our_checkpoint}")
        payload = _load_checkpoint_payload(our_checkpoint, device)
        model.load_state_dict(payload["model"], strict=False)
        print(f"Loaded our checkpoint from {our_checkpoint}")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    if args.cache_dir:
        dataset = CachedClipDataset(args.cache_dir, num_frames=NUM_FRAMES)
        print(f"Dataset: {len(dataset)} clips (cache)")
    else:
        dataset = DroidVideoDataset(args.data_root, num_frames=NUM_FRAMES, sample_fps=4.0, size=256)
        print(f"Dataset: {len(dataset)} videos")
    loader = _make_loader(dataset, args)

    global_step = 0
    pbar = tqdm(total=args.steps, desc="steps", unit="step")
    while global_step < args.steps:
        if args.cache_dir:
            # Re-scan the cache each epoch so clips written by a concurrent
            # extractor become visible; wait if the cache is still empty.
            dataset = CachedClipDataset(args.cache_dir, num_frames=NUM_FRAMES)
            if len(dataset) == 0:
                time.sleep(10)
                continue
            loader = _make_loader(dataset, args)
        for batch in loader:
            if global_step >= args.steps:
                break
            # batch: (B, 3, T, 256, 256) in [-1, 1]
            batch = batch.to(device, non_blocking=True)
            B, _, T, _, _ = batch.shape
            # Encode with the frozen VAE: (B*T, 3, 256, 256) -> (B*T, 16, 16, 16).
            frames = rearrange(batch, "b c t h w -> (b t) c h w")
            with torch.no_grad():
                posterior = vae.encode(frames.float())
                z = posterior.sample() * LATENT_SCALE
            # Patchify latents into token sequences: (B, T, 256, 16).
            x_tokens = model.patchify(z)
            x_tokens = rearrange(x_tokens, "(b t) s c -> b t s c", b=B)
            # Condition every frame's prediction on the clip's first frame.
            cond_tokens = x_tokens[:, :1].expand(-1, T, -1, -1)
            loss = model.compute_loss(x_tokens, cond_tokens)

            opt.zero_grad()
            loss.backward()
            opt.step()

            if args.log_wandb:
                import wandb
                wandb.log({"train/loss": loss.item()}, step=global_step)
                if global_step % args.vis_every == 0:
                    _log_denoising_video(model, vae, batch, args, global_step)

            global_step += 1
            pbar.update(1)
            # global_step was just incremented, so it is always >= 1 here.
            if global_step % args.checkpoint_every == 0:
                _save_checkpoint(model, opt, global_step, args)
    pbar.close()
    if args.log_wandb:
        import wandb
        wandb.finish()


# Script entry point: parse args and run the training loop.
if __name__ == "__main__":
    main()
