"""
Image-to-video finetuning on cached DROID: same interface as diffusion_dino/train.py.
Uses 5-step sampling for visualization and logs pred vs GT 8-frame videos to wandb.

Expects cached clips with at least 8 frames (build cache with num_frames=8 if using
simple_uva/dino_vid_model-style precache).

Usage (from vidgen):
  python -m diffusion_dino.train_img2vid --dataset droid_cache --cache_dir /path/to/vid_cache --log_wandb
"""

import argparse
import os
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Ensure we can import from unified_video_action and vidgen:
# this file lives one directory below the vidgen root (run as
# `python -m diffusion_dino.train_img2vid` from vidgen), so parents[1]
# is the vidgen repo root. Paths are prepended only if missing.
vidgen_root = Path(__file__).resolve().parents[1]
uva_root = vidgen_root / "unified_video_action"
for p in [vidgen_root, uva_root]:
    if p.exists() and str(p) not in sys.path:
        sys.path.insert(0, str(p))

# Machine-specific default locations; all of these are overridable via CLI flags.
DEFAULT_VAE_CKPT = Path("/data/cameron/vidgen/unified_video_action/pretrained_models/vae/kl16.ckpt")  # pretrained KL-16 VAE weights
DEFAULT_KEYGRIP_ROOT = Path("/data/cameron/keygrip")
DEFAULT_SCRATCH_ROOT = DEFAULT_KEYGRIP_ROOT / "scratch"
DEFAULT_DROID_CACHE_DIR = Path("/data/cameron/vidgen/dino_vid_model/vid_cache")  # cached DROID clips

from types import SimpleNamespace

from simple_uva.vae import AutoencoderKL
from simple_uva.dataset import CachedClipDataset, collate_batch

from diffusion_dino import build_dino_diffusion
from diffusion_dino.model import video_to_tokens, tokens_to_video, VAE_LATENT_SCALE, NUM_FRAMES


def load_vae(vae_ckpt: Path, device: torch.device):
    """Build a frozen autoencoder on ``device`` for latent encode/decode.

    If ``vae_ckpt`` exists, its weights are loaded non-strictly (the checkpoint
    may wrap them under a ``state_dict`` key). The model is returned in eval
    mode with all parameter gradients disabled.
    """
    config = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    have_ckpt = vae_ckpt is not None and vae_ckpt.exists()

    vae = AutoencoderKL(
        autoencoder_path=str(vae_ckpt) if have_ckpt else None,
        ddconfig=config,
    )

    if have_ckpt:
        checkpoint = torch.load(vae_ckpt, map_location="cpu")
        # Some checkpoints nest the weights under "state_dict"; fall back to
        # treating the whole object as the state dict otherwise.
        weights = checkpoint.get("state_dict", checkpoint)
        vae.load_state_dict(weights, strict=False)

    # Freeze: the VAE is used purely as a fixed encoder/decoder during training.
    vae.eval()
    vae.requires_grad_(False)
    vae.to(device)
    return vae


def main():
    parser = argparse.ArgumentParser(
        description="Image-to-video finetuning on cached DROID (same interface as diffusion_dino/train.py, 5-step sampling for vis)."
    )
    parser.add_argument("--keygrip_root", type=Path, default=DEFAULT_KEYGRIP_ROOT)
    parser.add_argument("--vae_ckpt", type=Path, default=DEFAULT_VAE_CKPT)
    parser.add_argument("--data_root", type=Path, default=DEFAULT_SCRATCH_ROOT)
    parser.add_argument(
        "--dataset",
        type=str,
        default="droid_cache",
        choices=["self_collected", "droid_cache"],
        help="Dataset source; use droid_cache for cached .pt clips.",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        default=DEFAULT_DROID_CACHE_DIR,
        help="DROID cached clips dir (e.g. dino_vid_model/vid_cache).",
    )
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--steps", type=int, default=100000000000)
    parser.add_argument("--vis_every", type=int, default=100, help="Log pred/GT video to wandb every N steps")
    parser.add_argument(
        "--num_sampling_steps",
        type=str,
        default="5",
        help="Number of denoising steps for sampling (used for vis and training schedule). Default 5.",
    )
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--log_wandb", action="store_true")
    parser.add_argument("--name", type=str, default="diffusion_dino_img2vid")
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--ckpt_dir", type=Path, default=Path("ckpt_diffusion_dino_img2vid"))
    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    if args.log_wandb:
        import wandb

        if "WANDB_API_KEY" in os.environ:
            wandb.login(key=os.environ["WANDB_API_KEY"])
        wandb.init(project="diffusion_dino", config=vars(args), name=args.name, mode="online")

    vae = load_vae(args.vae_ckpt, device)
    model = build_dino_diffusion(
        keygrip_root=args.keygrip_root,
        num_sampling_steps=args.num_sampling_steps,
        denoiser="mlp",
    )
    model.to(device)
    opt = torch.optim.AdamW(model.net.parameters(), lr=args.lr)

    args.ckpt_dir.mkdir(parents=True, exist_ok=True)

    if args.dataset == "droid_cache":
        cache_path = args.cache_dir
        dataset = CachedClipDataset(str(cache_path), num_frames=NUM_FRAMES)
        print(f"Dataset: {len(dataset)} clips (cache from {cache_path}), {NUM_FRAMES} frames per clip")
    else:
        cache_path = getattr(args.cache_dir, "parent", Path(args.cache_dir).parent) / "vid_cache_keygrip"
        dataset = CachedClipDataset(str(cache_path), num_frames=NUM_FRAMES)
        print(f"Dataset: {len(dataset)} clips (cache from {cache_path})")

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )

    global_step = 0
    pbar = tqdm(total=args.steps, desc="steps", unit="step")

    while True:
        for video in loader:
            video = video.to(device, non_blocking=True)
            with torch.no_grad():
                target_tokens = video_to_tokens(video, vae, scale=VAE_LATENT_SCALE)
                first_frame = video[:, :, 0]
                dino_cond = model.get_dino_cond(first_frame)

            loss = model.compute_loss(target_tokens, dino_cond) * 1e2
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            if global_step % 50 == 0:
                print(f"step {global_step} loss {loss.item():.4f}")
            if args.log_wandb:
                import wandb

                wandb.log({"train/loss": float(loss.item())}, step=global_step)

            # Visualization: 5-step sampling, log pred vs GT 8-frame videos
            if args.log_wandb and args.vis_every and (global_step % args.vis_every == 0):
                import wandb
                import tempfile
                import torchvision

                print("Sampling video (5 steps)...")
                with torch.no_grad():
                    cond_video = video[:1].clone()
                    cond_frame = cond_video[:, :, 0].clone()
                    first_tokens = video_to_tokens(
                        cond_video[:, :, :1], vae, scale=VAE_LATENT_SCALE
                    )[:, :256]
                    dino_cond_1 = model.get_dino_cond(cond_frame)
                    sampled_tokens = model.sample(
                        first_tokens, dino_cond_1, temperature=1.0, device=device
                    )
                    pred_video = tokens_to_video(
                        sampled_tokens, vae, scale=VAE_LATENT_SCALE, t=NUM_FRAMES
                    )
                print("Done sampling video")

                cond_img = ((cond_frame[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
                pred_img0 = ((pred_video[0, :, 0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
                wandb.log(
                    {
                        "vis/cond_frame": wandb.Image(cond_img.permute(1, 2, 0).numpy()),
                        "vis/pred_frame0": wandb.Image(pred_img0.permute(1, 2, 0).numpy()),
                    },
                    step=global_step,
                )

                try:
                    # Predicted video (8 frames)
                    pred_np = ((pred_video[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
                    pred_np = pred_np.permute(1, 0, 2, 3)
                    pred_frames = (pred_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")

                    # Ground-truth video (8 frames)
                    gt_np = ((cond_video[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
                    gt_np = gt_np.permute(1, 0, 2, 3)
                    gt_frames = (gt_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")

                    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f_pred:
                        pred_path = f_pred.name
                    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f_gt:
                        gt_path = f_gt.name

                    torchvision.io.write_video(
                        pred_path, torch.from_numpy(pred_frames), fps=4
                    )
                    torchvision.io.write_video(
                        gt_path, torch.from_numpy(gt_frames), fps=4
                    )

                    wandb.log(
                        {
                            "vis/pred_video": wandb.Video(pred_path, format="mp4"),
                            "vis/gt_video": wandb.Video(gt_path, format="mp4"),
                        },
                        step=global_step,
                    )

                    Path(pred_path).unlink(missing_ok=True)
                    Path(gt_path).unlink(missing_ok=True)
                except Exception as e:
                    print(f"Warning: could not log videos to wandb: {e}")

            global_step += 1
            pbar.update(1)

            if args.save_every and (global_step % args.save_every == 0):
                ckpt = {
                    "step": global_step,
                    "net": model.net.state_dict(),
                    "optimizer": opt.state_dict(),
                }
                path = args.ckpt_dir / f"diffusion_dino_img2vid_step{global_step}.pt"
                torch.save(ckpt, path)
                print(f"Saved {path}")

            if global_step >= args.steps:
                pbar.close()
                if args.log_wandb:
                    import wandb
                    wandb.finish()
                return


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
