"""Train simple_uva on RTX Video dataset: /data/RTX/RTX_Video/*/*.gif, n=4 frames per gif sampled at --sample-fps (default 1)."""

import argparse
import sys
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from einops import rearrange
from tqdm import tqdm

REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only
from simple_uva.dataset import collate_batch, NUM_FRAMES

LATENT_SCALE = 0.2325

DEFAULT_DATA_ROOT = Path("/data/RTX/RTX_Video")


def find_gifs(root: Path):
    """Return the sorted list of GIF paths matching root/*/*.gif.

    Only files directly inside first-level subdirectories of *root* are
    considered; files at the top level of *root* itself are skipped.
    """
    base = Path(root).resolve()
    return [
        gif
        for subdir in sorted(base.iterdir())
        if subdir.is_dir()
        for gif in sorted(subdir.glob("*.gif"))
    ]


def _sample_frame_indices_at_fps(durations_ms: list, num_frames: int, fps: float):
    """Frame indices at t=0, 1/fps, 2/fps, ..., (num_frames-1)/fps seconds."""
    if not durations_ms or fps <= 0:
        return list(range(min(num_frames, len(durations_ms) or 1)))
    cumsum_ms = np.cumsum([0] + list(durations_ms))
    total_ms = cumsum_ms[-1]
    indices = []
    for i in range(num_frames):
        t_ms = (i / fps) * 1000
        if t_ms >= total_ms:
            indices.append(len(durations_ms) - 1)
            continue
        k = np.searchsorted(cumsum_ms, t_ms, side="right") - 1
        k = max(0, min(k, len(durations_ms) - 1))
        indices.append(k)
    return indices


class RTXVideoDataset(Dataset):
    """Load n_frames from each GIF sampled at sample_fps (t=0, 1/fps, 2/fps, ...). Returns (1, C, T, H, W) in [-1, 1]."""

    def __init__(
        self,
        root: str | Path,
        num_frames: int = NUM_FRAMES,
        size: int = 256,
        max_samples: int | None = None,
        sample_fps: float = 1.0,
    ):
        self.root = Path(root).resolve()
        self.num_frames = num_frames
        self.size = size
        self.sample_fps = sample_fps
        gifs = find_gifs(self.root)
        if not gifs:
            raise FileNotFoundError(
                f"No *.gif under {self.root}. Expected structure: root/*/*.gif"
            )
        # Optionally truncate for quick smoke tests.
        self.gifs = gifs if max_samples is None else gifs[:max_samples]

    def __len__(self):
        return len(self.gifs)

    def __getitem__(self, idx):
        # Local import so DataLoader workers resolve PIL lazily.
        from PIL import Image
        gif_path = self.gifs[idx]
        with Image.open(gif_path) as img:
            total = getattr(img, "n_frames", 1)
            # Collect per-frame display durations (ms) for fps-based sampling.
            durations = []
            for k in range(total):
                img.seek(k)
                dur = img.info.get("duration", 0)
                durations.append(0 if dur is None else int(dur))
            picks = _sample_frame_indices_at_fps(durations, self.num_frames, self.sample_fps)
            raw = []
            for k in picks:
                img.seek(int(k))
                raw.append(np.array(img.convert("RGB")))
        # Pad short clips by repeating the final frame.
        while len(raw) < self.num_frames:
            raw.append(raw[-1].copy())
        stacked = np.stack(raw[: self.num_frames], axis=0)  # (T, H, W, C) uint8
        vid = torch.from_numpy(stacked).float() / 255.0     # -> [0, 1]
        vid = vid.permute(0, 3, 1, 2)                       # (T, C, H, W)
        vid = F.interpolate(vid, size=(self.size, self.size), mode="bilinear", align_corners=False)
        vid = vid.permute(1, 0, 2, 3)                       # (C, T, H, W)
        vid = vid * 2.0 - 1.0                               # -> [-1, 1]
        return vid.unsqueeze(0)                             # (1, C, T, H, W)


def build_vae(repo_root: Path, vae_ckpt: str, device: torch.device):
    """Construct a frozen KL autoencoder on *device*.

    Relative checkpoint paths are resolved against *repo_root*. If the
    checkpoint file does not exist, the VAE is built with random weights
    (``autoencoder_path=None``). The returned model is in eval mode with
    all parameters frozen.
    """
    given = Path(vae_ckpt)
    ckpt_path = given if given.is_absolute() else repo_root / vae_ckpt
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    weights = str(ckpt_path) if ckpt_path.exists() else None
    vae = AutoencoderKL(autoencoder_path=weights, ddconfig=ddconfig)
    vae.to(device)
    vae.eval()
    # Frozen: the VAE is used purely as a fixed encoder/decoder.
    for param in vae.parameters():
        param.requires_grad = False
    return vae


def main():
    """CLI entry point: train the video-only MAR on RTX GIF clips.

    Builds a frozen VAE + trainable MAR, optionally loads a pretrain
    checkpoint, then runs `--steps` optimizer steps over the dataset,
    with optional wandb loss/vis logging and periodic checkpointing.
    """
    p = argparse.ArgumentParser(description="Train simple_uva on RTX Video dataset (root/*/*.gif, 4 frames @ --sample-fps)")
    p.add_argument("--data-root", type=str, default=str(DEFAULT_DATA_ROOT), help="Root with subdirs of .gif files (default: /data/RTX/RTX_Video)")
    p.add_argument("--vae-ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--workers", type=int, default=8)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--steps", type=int, default=10000)
    p.add_argument("--vis-every", type=int, default=100)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--checkpoint-dir", type=str, default="checkpoints")
    p.add_argument("--checkpoint-every", type=int, default=1000)
    p.add_argument("--num-iter", type=int, default=64, help="Sampling steps for vis")
    p.add_argument("--name", type=str, default="simple_uva_rtx")
    p.add_argument("--pretrain-mar", type=str, default=None, help="Path to checkpoint to load MAR (and optionally VAE) from, e.g. checkpoints/libero10.ckpt")
    p.add_argument("--max-samples", type=int, default=None, help="Use only the first N GIFs (for quick tests). Default: use all.")
    p.add_argument("--sample-fps", type=float, default=1.0, help="Sample frames at this FPS (t=0, 1/fps, 2/fps, ...). Default: 1.")
    args = p.parse_args()

    device = torch.device(args.device)

    if args.log_wandb:
        import wandb
        wandb.init(project="simple_uva", config=vars(args), name=args.name, mode="online")

    # Frozen VAE provides the latent space; the MAR model is what we train.
    vae = build_vae(REPO_ROOT, args.vae_ckpt, device)
    model = mar_base_video_only(
        img_size=256,
        vae_stride=16,
        patch_size=1,
        vae_embed_dim=16,
        num_sampling_steps="100",
        diffloss_d=6,
        diffloss_w=1024,
    ).to(device)

    if args.pretrain_mar:
        import dill
        ckpt_path = Path(args.pretrain_mar)
        if not ckpt_path.is_file():
            raise FileNotFoundError(f"Pretrain checkpoint not found: {ckpt_path}")
        # Checkpoints may have been saved with dill; fall back to plain torch.load.
        try:
            payload = torch.load(ckpt_path, map_location=device, pickle_module=dill)
        except Exception:
            payload = torch.load(ckpt_path, map_location=device, weights_only=False)
        state_dicts = payload.get("state_dicts") or {}
        # Prefer EMA weights when present.
        sd = state_dicts.get("ema_model") or state_dicts.get("model")
        if sd is None:
            raise KeyError(f"No state_dict in {ckpt_path}")
        # Keys are prefixed "model." / "vae_model." in the combined checkpoint.
        model_sd = {k[6:]: v for k, v in sd.items() if k.startswith("model.")}
        model.load_state_dict(model_sd, strict=False)
        if any(k.startswith("vae_model.") for k in sd):
            vae_sd = {k[10:]: v for k, v in sd.items() if k.startswith("vae_model.")}
            vae.load_state_dict(vae_sd, strict=False)
        print(f"Loaded MAR (and VAE if present) from {ckpt_path}")
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    dataset = RTXVideoDataset(
        args.data_root,
        num_frames=NUM_FRAMES,
        size=256,
        max_samples=args.max_samples,
        sample_fps=args.sample_fps,
    )
    print(f"Dataset: {len(dataset)} GIFs ({NUM_FRAMES} frames @ {args.sample_fps} fps each from {args.data_root})")
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )

    global_step = 0
    pbar = tqdm(total=args.steps, desc="steps", unit="step")
    # BUG FIX: the loop was `while True:` with only an inner break, so after
    # reaching args.steps the outer loop re-entered the loader and broke again
    # forever -- pbar.close()/wandb.finish() were unreachable and the process
    # never exited. Checking the step count here terminates cleanly.
    while global_step < args.steps:
        for batch in loader:
            if global_step >= args.steps:
                break
            batch = batch.to(device, non_blocking=True)
            B, C, T, H, W = batch.shape
            # Encode every frame independently, then restore the time axis.
            frames = rearrange(batch, "b c t h w -> (b t) c h w")
            with torch.no_grad():
                posterior = vae.encode(frames.float())
                z = posterior.sample() * LATENT_SCALE
            z = rearrange(z, "(b t) c h w -> b t c h w", b=B)
            z_flat = rearrange(z, "b t c h w -> (b t) c h w")
            x_tokens = model.patchify(z_flat)
            x_tokens = rearrange(x_tokens, "(b t) s c -> b t s c", b=B)
            # Condition every target frame on the first frame's tokens.
            cond_tokens = x_tokens[:, :1].expand(-1, T, -1, -1)
            loss = model.compute_loss(x_tokens, cond_tokens)
            opt.zero_grad()
            loss.backward()
            opt.step()

            if args.log_wandb:
                import wandb
                wandb.log({"train/loss": loss.item()}, step=global_step)

            # Periodic qualitative check: sample a clip conditioned on a real
            # first frame and log it as a video.
            if global_step % args.vis_every == 0 and args.log_wandb:
                import wandb
                import tempfile
                import torchvision
                with torch.no_grad():
                    first_frame = batch[:1, :, 0].clone()
                    posterior0 = vae.encode(first_frame.float())
                    z0 = posterior0.sample() * LATENT_SCALE
                    cond = z0.unsqueeze(1).expand(1, NUM_FRAMES, -1, -1, -1)
                    tokens, _ = model.sample_tokens(
                        bsz=1,
                        cond=cond,
                        num_iter=args.num_iter,
                        cfg=1.0,
                        temperature=0.95,
                    )
                    pred = vae.decode(tokens / LATENT_SCALE)
                    pred = pred.view(1, NUM_FRAMES, 3, 256, 256)
                # [-1, 1] -> [0, 1] for display.
                pred_np = ((pred[0].cpu() + 1.0) / 2.0).clamp(0, 1)
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                    tmp_path = f.name
                frames_np = (pred_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
                torchvision.io.write_video(tmp_path, torch.from_numpy(frames_np), fps=4)
                wandb.log(
                    {"vis/denoising_video": wandb.Video(tmp_path, format="mp4")},
                    step=global_step,
                )
                Path(tmp_path).unlink(missing_ok=True)

            global_step += 1
            pbar.update(1)
            if global_step > 0 and global_step % args.checkpoint_every == 0:
                ckpt_dir = Path(args.checkpoint_dir)
                ckpt_dir.mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        "step": global_step,
                        "model": model.state_dict(),
                        "optimizer": opt.state_dict(),
                    },
                    ckpt_dir / f"{args.name}_latest.pt",
                )
                print(f"Saved {args.name}_latest.pt")
    pbar.close()
    if args.log_wandb:
        import wandb
        wandb.finish()


if __name__ == "__main__":
    main()
