"""Train simple_uva on self-collected episodes (keygrip/scratch). Same as train_simple_uva but data from root/task/episode/*.png."""

import argparse
import sys
import time
from pathlib import Path
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm

REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

from simple_uva.vae import AutoencoderKL
from simple_uva.model import mar_base_video_only
from simple_uva.dataset import SelfCollectedDataset, collate_batch, NUM_FRAMES

# Scale factor applied to sampled VAE latents; divided out again before decoding.
LATENT_SCALE = 0.2325

# Default data root: ../../keygrip/scratch relative to the repo root (unified_video_action);
# REPO_ROOT.parent is vidgen, so parents[1] is the directory above vidgen (REPO_ROOT is already resolved).
DEFAULT_SCRATCH = REPO_ROOT.parents[1] / "keygrip" / "scratch"


def build_vae(repo_root: Path, vae_ckpt: str, device: torch.device):
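    """Load the frozen KL VAE used to encode frames into latents; its parameters are never trained."""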
    ckpt_path = Path(vae_ckpt) if Path(vae_ckpt).is_absolute() else repo_root / vae_ckpt
    if not ckpt_path.exists():
        print(f"Warning: VAE checkpoint not found at {ckpt_path}; using randomly initialized VAE weights")
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    vae = AutoencoderKL(autoencoder_path=str(ckpt_path) if ckpt_path.exists() else None, ddconfig=ddconfig)
    vae.to(device)
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False
    return vae


def main():
    p = argparse.ArgumentParser(description="Train simple_uva on self-collected episodes (keygrip/scratch)")
    p.add_argument("--data-root", type=str, default=str(DEFAULT_SCRATCH), help="Root with task/episode/*.png (default: ../../keygrip/scratch)")
    p.add_argument("--vae-ckpt", type=str, default="pretrained_models/vae/kl16.ckpt")
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--workers", type=int, default=8)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--steps", type=int, default=10000)
    p.add_argument("--vis-every", type=int, default=100)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--log_wandb", action="store_true")
    p.add_argument("--checkpoint-dir", type=str, default="checkpoints")
    p.add_argument("--checkpoint-every", type=int, default=1000)
    p.add_argument("--num-iter", type=int, default=64, help="Sampling steps for vis")
    p.add_argument("--name", type=str, default="simple_uva_self_collected")
    p.add_argument("--pretrain-mar", type=str, default=None, help="Path to checkpoint to load MAR (and optionally VAE) from, e.g. checkpoints/libero10.ckpt")
    args = p.parse_args()

    device = torch.device(args.device)
    if args.log_wandb:
        import wandb
        wandb.init(project="simple_uva", config=vars(args), name=args.name, mode="online")

    vae = build_vae(REPO_ROOT, args.vae_ckpt, device)
    model = mar_base_video_only(
        img_size=256,
        vae_stride=16,
        patch_size=1,
        vae_embed_dim=16,
        num_sampling_steps="100",
        diffloss_d=6,
        diffloss_w=1024,
    ).to(device)
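    # Optionally warm-start the MAR (and VAE, if present) from a wrapper checkpoint whose keys are prefixed "model." / "vae_model.".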
    if args.pretrain_mar:
        import dill
        ckpt_path = Path(args.pretrain_mar)
        if not ckpt_path.is_file():
            raise FileNotFoundError(f"Pretrain checkpoint not found: {ckpt_path}")
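        # Some checkpoints are pickled with dill; fall back to a plain torch.load if that fails.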
        try:
            payload = torch.load(ckpt_path, map_location=device, pickle_module=dill)
        except Exception:
            payload = torch.load(ckpt_path, map_location=device, weights_only=False)
        state_dicts = payload.get("state_dicts") or {}
        sd = state_dicts.get("ema_model") or state_dicts.get("model")
        if sd is None:
            raise KeyError(f"No 'ema_model' or 'model' state_dict found in {ckpt_path}")
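        # Strip the "model." / "vae_model." wrapper prefixes before loading into the bare modules.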
        model_sd = {k[6:]: v for k, v in sd.items() if k.startswith("model.")}
        model.load_state_dict(model_sd, strict=False)
        if any(k.startswith("vae_model.") for k in sd):
            vae_sd = {k[10:]: v for k, v in sd.items() if k.startswith("vae_model.")}
            vae.load_state_dict(vae_sd, strict=False)
        print(f"Loaded MAR (and VAE if present) from {ckpt_path}")
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    dataset = SelfCollectedDataset(
        args.data_root,
        num_frames=NUM_FRAMES,
        size=256,
        sequence_length=8,
    )
    print(f"Dataset: {len(dataset)} episodes (self-collected from {args.data_root})")
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )

    global_step = 0
    pbar = tqdm(total=args.steps, desc="steps", unit="step")
    while global_step < args.steps:
        for batch in loader:
            if global_step >= args.steps:
                break
            batch = batch.to(device, non_blocking=True)
            B, C, T, H, W = batch.shape
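            # Encode every frame with the frozen VAE; sampled latents are scaled before tokenization.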
            frames = rearrange(batch, "b c t h w -> (b t) c h w")
            with torch.no_grad():
                posterior = vae.encode(frames.float())
                z = posterior.sample() * LATENT_SCALE
            z = rearrange(z, "(b t) c h w -> b t c h w", b=B)
            z_flat = rearrange(z, "b t c h w -> (b t) c h w")
            x_tokens = model.patchify(z_flat)
            x_tokens = rearrange(x_tokens, "(b t) s c -> b t s c", b=B)
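            # Condition each target frame on the tokens of the first frame, broadcast along the time axis.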
            cond_tokens = x_tokens[:, :1].expand(-1, T, -1, -1)
            loss = model.compute_loss(x_tokens, cond_tokens)
            opt.zero_grad()
            loss.backward()
            opt.step()

            if args.log_wandb:
                import wandb
                wandb.log({"train/loss": loss.item()}, step=global_step)

            if global_step % args.vis_every == 0 and args.log_wandb:
                import wandb
                import tempfile
                import torchvision
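                # Sample a rollout conditioned on the first frame of the first sample in the batch and log it as a video.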
                with torch.no_grad():
                    first_frame = batch[:1, :, 0].clone()
                    posterior0 = vae.encode(first_frame.float())
                    z0 = posterior0.sample() * LATENT_SCALE
                    cond = z0.unsqueeze(1).expand(1, NUM_FRAMES, -1, -1, -1)
                    tokens, _ = model.sample_tokens(
                        bsz=1,
                        cond=cond,
                        num_iter=args.num_iter,
                        cfg=1.0,
                        temperature=0.95,
                    )
                    pred = vae.decode(tokens / LATENT_SCALE)
                    pred = pred.view(1, NUM_FRAMES, 3, 256, 256)
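                # Map decoded frames from [-1, 1] back to [0, 1] for export.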
                pred_np = ((pred[0].cpu() + 1.0) / 2.0).clamp(0, 1)
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                    tmp_path = f.name
                frames_u8 = (pred_np.permute(0, 2, 3, 1) * 255).to(torch.uint8)
                torchvision.io.write_video(tmp_path, frames_u8, fps=4)
                wandb.log(
                    {"vis/denoising_video": wandb.Video(tmp_path, format="mp4")},
                    step=global_step,
                )
                Path(tmp_path).unlink(missing_ok=True)

            global_step += 1
            pbar.update(1)
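            # Overwrite a single rolling checkpoint every --checkpoint-every steps.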
            if global_step > 0 and global_step % args.checkpoint_every == 0:
                ckpt_dir = Path(args.checkpoint_dir)
                ckpt_dir.mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        "step": global_step,
                        "model": model.state_dict(),
                        "optimizer": opt.state_dict(),
                    },
                    ckpt_dir / f"{args.name}_latest.pt",
                )
                print(f"Saved {args.name}_latest.pt")
    pbar.close()
    if args.log_wandb:
        import wandb
        wandb.finish()


if __name__ == "__main__":
    main()
