"""Distributed training (DistributedDataParallel) for RTX Video dataset.

HOW TO RUN (from repo root: unified_video_action):

  One machine, 6 GPUs (effective batch size = batch_size * 6):
    torchrun --nproc_per_node=6 scripts/train_rtx_dist.py --batch-size 16 --steps 10000

  One machine, 4 GPUs:
    torchrun --nproc_per_node=4 scripts/train_rtx_dist.py --batch-size 16

  With wandb and custom name:
    torchrun --nproc_per_node=6 scripts/train_rtx_dist.py --batch-size 16 --log_wandb --name my_rtx_run

  Restrict to specific GPUs (e.g. 0,1,2,3,4,5):
    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --nproc_per_node=6 scripts/train_rtx_dist.py --batch-size 16

Note: --batch-size is PER GPU. So with 6 GPUs and --batch-size 16, each step uses 96 samples total.

DDP + DataLoader workers can deadlock (forked workers after CUDA has been initialized), so we force
num_workers=0 when world_size > 1. Startup is slow because every rank builds the GIF list from disk,
and the first step is slow because data is loaded in the main process; this trades speed for avoiding hangs.
"""

import argparse
import os
import sys
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from einops import rearrange
from tqdm import tqdm

REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only
from simple_uva.dataset import collate_batch, NUM_FRAMES

# Scale factor applied to sampled VAE latents before the model (and inverted before decoding).
LATENT_SCALE = 0.2325

DEFAULT_DATA_ROOT = Path("/data/RTX/RTX_Video")


def find_gifs(root: Path):
    """Return sorted list of root/*/*.gif."""
    root = Path(root).resolve()
    out = []
    for subdir in sorted(root.iterdir()):
        if not subdir.is_dir():
            continue
        for f in sorted(subdir.glob("*.gif")):
            out.append(f)
    return out


class RTXVideoDataset(Dataset):
    """Load first n_frames from each GIF under root/*/*.gif. Returns (1, C, T, H, W) in [-1, 1]."""

    def __init__(self, root: str | Path, num_frames: int = NUM_FRAMES, size: int = 256, max_samples: int | None = None):
        self.root = Path(root).resolve()
        self.num_frames = num_frames
        self.size = size
        self.gifs = find_gifs(self.root)
        if not self.gifs:
            raise FileNotFoundError(
                f"No *.gif under {self.root}. Expected structure: root/*/*.gif"
            )
        if max_samples is not None:
            self.gifs = self.gifs[:max_samples]

    def __len__(self):
        return len(self.gifs)

    def __getitem__(self, idx):
        from PIL import Image
        path = self.gifs[idx]
        frames = []
        with Image.open(path) as im:
            n_frames = getattr(im, "n_frames", 1)
            for i in range(min(self.num_frames, n_frames)):
                im.seek(i)
                fr = im.convert("RGB")
                fr = np.array(fr)
                frames.append(fr)
        # Pad short GIFs by repeating the last frame until num_frames is reached.
        while len(frames) < self.num_frames:
            frames.append(frames[-1].copy())
        frames = np.stack(frames[: self.num_frames], axis=0)
        x = torch.from_numpy(frames).float() / 255.0  # (T, H, W, C) in [0, 1]
        x = x.permute(0, 3, 1, 2)  # (T, C, H, W)
        x = F.interpolate(x, size=(self.size, self.size), mode="bilinear", align_corners=False)
        x = x.permute(1, 0, 2, 3)  # (C, T, H, W)
        x = x * 2.0 - 1.0  # [-1, 1]
        return x.unsqueeze(0)


def build_vae(repo_root: Path, vae_ckpt: str, device: torch.device):
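    """Load the frozen KL-16 autoencoder used for frame <-> latent conversion (random init if the ckpt is missing)."""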
    ckpt_path = Path(vae_ckpt)
    if not ckpt_path.is_absolute():
        ckpt_path = repo_root / ckpt_path
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    vae = AutoencoderKL(autoencoder_path=str(ckpt_path) if ckpt_path.exists() else None, ddconfig=ddconfig)
    vae.to(device)
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False
    return vae


def main():
    p = argparse.ArgumentParser(description="Distributed training: simple_uva on RTX Video (root/*/*.gif)")
    p.add_argument("--data-root", type=str, default=str(DEFAULT_DATA_ROOT))
    p.add_argument("--vae-ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    p.add_argument("--batch-size", type=int, default=4, help="Per-GPU batch size")
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--steps", type=int, default=10000)
    p.add_argument("--vis-every", type=int, default=100)
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--checkpoint-dir", type=str, default="checkpoints")
    p.add_argument("--checkpoint-every", type=int, default=1000)
    p.add_argument("--num-iter", type=int, default=64, help="Sampling steps for vis")
    p.add_argument("--name", type=str, default="simple_uva_rtx_dist")
    p.add_argument("--pretrain-mar", type=str, default=None)
    p.add_argument("--max-samples", type=int, default=None, help="Use only the first N GIFs (for quick tests)")
    args = p.parse_args()

    # Distributed setup: use env set by torchrun
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    is_main = rank == 0
    if is_main:
        print("Distributed init done, loading model and data...")

    if is_main and args.log_wandb:
        import wandb
        wandb.init(project="simple_uva", config=vars(args), name=args.name, mode="online")

    vae = build_vae(REPO_ROOT, args.vae_ckpt, device)
    model = mar_base_video_only(
        img_size=256,
        vae_stride=16,
        patch_size=1,
        vae_embed_dim=16,
        num_sampling_steps="100",
        diffloss_d=6,
        diffloss_w=1024,
    ).to(device)

    if world_size > 1:
        model = DDP(model, device_ids=[local_rank])

    if args.pretrain_mar:
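        # Warm-start from a UVA/MAR checkpoint; state-dict keys are prefixed with "model." and optionally "vae_model.".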
        import dill
        ckpt_path = Path(args.pretrain_mar)
        if not ckpt_path.is_file():
            raise FileNotFoundError(f"Pretrain checkpoint not found: {ckpt_path}")
        try:
            payload = torch.load(ckpt_path, map_location=device, pickle_module=dill)
        except Exception:
            payload = torch.load(ckpt_path, map_location=device, weights_only=False)
        state_dicts = payload.get("state_dicts") or {}
        sd = state_dicts.get("ema_model") or state_dicts.get("model")
        if sd is None:
            raise KeyError(f"No state_dict in {ckpt_path}")
        model_sd = {k[6:]: v for k, v in sd.items() if k.startswith("model.")}
        target = model.module if hasattr(model, "module") else model
        target.load_state_dict(model_sd, strict=False)
        if any(k.startswith("vae_model.") for k in sd):
            vae_sd = {k[10:]: v for k, v in sd.items() if k.startswith("vae_model.")}
            vae.load_state_dict(vae_sd, strict=False)
        if is_main:
            print(f"Loaded MAR (and VAE if present) from {ckpt_path}")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    dataset = RTXVideoDataset(args.data_root, num_frames=NUM_FRAMES, size=256, max_samples=args.max_samples)
    if world_size > 1:
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
        shuffle = False
        # Avoid fork + CUDA deadlock: use 0 workers when using DDP.
        n_workers = 0
    else:
        sampler = None
        shuffle = True
        n_workers = args.workers

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=n_workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=n_workers > 0,
        prefetch_factor=4 if n_workers > 0 else None,
    )

    if world_size > 1:
        dist.barrier()
    if is_main:
        print(f"Dataset: {len(dataset)} GIFs, {world_size} GPU(s), batch_size {args.batch_size} per GPU (effective {args.batch_size * world_size})")
        print("Starting training...")

    global_step = 0
    if is_main:
        pbar = tqdm(total=args.steps, desc="steps", unit="step")

    first_batch_logged = False
    while global_step < args.steps:
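        # Reshuffle on every pass over the dataset; global_step doubles as the epoch seed for the sampler.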
        if world_size > 1:
            sampler.set_epoch(global_step)
        for batch in loader:
            if global_step >= args.steps:
                break
            if is_main and not first_batch_logged:
                print("First batch loaded, stepping...")
                first_batch_logged = True
            batch = batch.to(device, non_blocking=True)
            B, C, T, H, W = batch.shape
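            # Fold time into the batch dimension so the frozen VAE encodes all frames in one pass.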
            frames = rearrange(batch, "b c t h w -> (b t) c h w")
            with torch.no_grad():
                posterior = vae.encode(frames.float())
                z = posterior.sample() * LATENT_SCALE
            z = rearrange(z, "(b t) c h w -> b t c h w", b=B)
            z_flat = rearrange(z, "b t c h w -> (b t) c h w")
            x_tokens = model.module.patchify(z_flat) if hasattr(model, "module") else model.patchify(z_flat)
            x_tokens = rearrange(x_tokens, "(b t) s c -> b t s c", b=B)
            cond_tokens = x_tokens[:, :1].expand(-1, T, -1, -1)
            loss = model(x_tokens, cond_tokens)
            opt.zero_grad()
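            # loss.backward() also runs DDP's gradient all-reduce when world_size > 1.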
            loss.backward()
            opt.step()

            if is_main and args.log_wandb:
                import wandb
                wandb.log({"train/loss": loss.item()}, step=global_step)

            if is_main and global_step % args.vis_every == 0 and args.log_wandb:
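                # Rank 0 only: sample a rollout conditioned on the batch's first frame and log it as an mp4.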
                import wandb
                import tempfile
                import torchvision
                run_model = model.module if hasattr(model, "module") else model
                with torch.no_grad():
                    first_frame = batch[:1, :, 0].clone()
                    posterior0 = vae.encode(first_frame.float())
                    z0 = posterior0.sample() * LATENT_SCALE
                    cond = z0.unsqueeze(1).expand(1, NUM_FRAMES, -1, -1, -1)
                    tokens, _ = run_model.sample_tokens(
                        bsz=1,
                        cond=cond,
                        num_iter=args.num_iter,
                        cfg=1.0,
                        temperature=0.95,
                    )
                    pred = vae.decode(tokens / LATENT_SCALE)
                    pred = pred.view(1, NUM_FRAMES, 3, 256, 256)
                pred_np = ((pred[0].cpu() + 1.0) / 2.0).clamp(0, 1)
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                    tmp_path = f.name
                frames_np = (pred_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
                torchvision.io.write_video(tmp_path, torch.from_numpy(frames_np), fps=4)
                wandb.log(
                    {"vis/denoising_video": wandb.Video(tmp_path, format="mp4")},
                    step=global_step,
                )
                Path(tmp_path).unlink(missing_ok=True)

            global_step += 1
            if is_main:
                pbar.update(1)

            if global_step > 0 and global_step % args.checkpoint_every == 0:
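                # Sync all ranks, then let rank 0 write a single rolling checkpoint.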
                if world_size > 1:
                    dist.barrier()
                if is_main:
                    ckpt_dir = Path(args.checkpoint_dir)
                    ckpt_dir.mkdir(parents=True, exist_ok=True)
                    state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
                    torch.save(
                        {
                            "step": global_step,
                            "model": state,
                            "optimizer": opt.state_dict(),
                        },
                        ckpt_dir / f"{args.name}_latest.pt",
                    )
                    print(f"Saved {args.name}_latest.pt")

    if is_main:
        pbar.close()
    if args.log_wandb and is_main:
        import wandb
        wandb.finish()

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
