"""
Minimal training script for diffusion_dino: single-frame VAE tokens + DINO conditioning.
Usage (from vidgen):
  python -m diffusion_dino.train --keygrip_root /path/to/keygrip --vae_ckpt /path/to/vae.ckpt --data_root /path/to/episodes
Frames should be in [0,1] or [-1,1]; will be normalized for DINO and passed to VAE as-is (VAE expects [0,1] or similar).
"""

import argparse
import sys
from pathlib import Path

import torch
import os
from torch.utils.data import DataLoader

# Ensure we can import from unified_video_action and vidgen.
# Assumes this file lives one directory below the vidgen repo root, so
# parents[1] is the repo root; that root and its bundled
# unified_video_action tree are prepended to sys.path (only when the
# directory exists and is not already on the path).
vidgen_root = Path(__file__).resolve().parents[1]
uva_root = vidgen_root / "unified_video_action"
for p in [vidgen_root, uva_root]:
    if p.exists() and str(p) not in sys.path:
        sys.path.insert(0, str(p))

# Same VAE as unified_video_action (uva config: pretrained_models/vae/kl16.ckpt)
DEFAULT_VAE_CKPT = Path("/data/cameron/vidgen/unified_video_action/pretrained_models/vae/kl16.ckpt")
# Keygrip root used by dino_vid_model (expects keygrip_root/dinov3/weights/...)
DEFAULT_KEYGRIP_ROOT = Path("/data/cameron/keygrip")
# Root of self-collected episodes (the "self_collected" dataset choice).
DEFAULT_SCRATCH_ROOT = DEFAULT_KEYGRIP_ROOT / "scratch"
# Directory of pre-cached DROID .pt clips (the "droid_cache" dataset choice).
DEFAULT_DROID_CACHE_DIR = Path("/data/cameron/vidgen/dino_vid_model/vid_cache")

from types import SimpleNamespace

from simple_uva.vae import AutoencoderKL
from simple_uva.dataset import CachedClipDataset, SelfCollectedDataset, collate_batch

from diffusion_dino import build_dino_diffusion
from diffusion_dino.model import video_to_tokens, tokens_to_video, VAE_LATENT_SCALE, NUM_FRAMES


def load_vae(vae_ckpt: Path, device: torch.device):
    """Build a frozen AutoencoderKL in eval mode, moved to *device*.

    When *vae_ckpt* exists, its weights are additionally loaded on top of the
    constructed model (non-strict, supporting both raw state dicts and
    checkpoints that wrap weights under a "state_dict" key). The returned VAE
    is used purely as a fixed tokenizer: eval mode, all gradients disabled.
    """
    have_ckpt = vae_ckpt is not None and vae_ckpt.exists()
    cfg = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    vae = AutoencoderKL(
        autoencoder_path=str(vae_ckpt) if have_ckpt else None,
        ddconfig=cfg,
    )
    if have_ckpt:
        checkpoint = torch.load(vae_ckpt, map_location="cpu")
        # Accept either a bare state dict or one nested under "state_dict".
        vae.load_state_dict(checkpoint.get("state_dict", checkpoint), strict=False)
    vae.eval()
    for param in vae.parameters():
        param.requires_grad = False
    vae.to(device)
    return vae


def main():
    """Train the DINO-conditioned diffusion model on cached video clips.

    Loads a frozen VAE tokenizer and the diffusion model, then runs a training
    loop over cached clips with periodic wandb logging/visualization and
    checkpointing. Training now stops after ``--steps`` optimizer steps (the
    default is effectively unbounded, preserving the old behavior).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--keygrip_root",
        type=Path,
        default=DEFAULT_KEYGRIP_ROOT,
        help="Keygrip repo root (default: /data/cameron/keygrip)",
    )
    parser.add_argument("--vae_ckpt", type=Path, default=DEFAULT_VAE_CKPT, help="VAE checkpoint path (default: UVA kl16.ckpt)")
    parser.add_argument(
        "--data_root",
        type=Path,
        default=DEFAULT_SCRATCH_ROOT,
        help="Self-collected episodes root (default: /data/cameron/keygrip/scratch)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="self_collected",
        choices=["self_collected", "droid_cache"],
        help="Dataset source: self_collected (keygrip/scratch) or droid_cache (.pt clips)",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        default=DEFAULT_DROID_CACHE_DIR,
        help="DROID cached clips dir (default: /data/cameron/vidgen/dino_vid_model/vid_cache)",
    )
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--steps", type=int, default=100000000000)
    parser.add_argument("--vis_every", type=int, default=100, help="Visualization/logging sampling cadence")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--num_sampling_steps", type=str, default="1", help="Diffusion respacing (e.g. '1', '3', 'ddim3')")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--log_wandb", action="store_true")
    parser.add_argument("--name", type=str, default="diffusion_dino")
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--ckpt_dir", type=Path, default=Path("ckpt_diffusion_dino"))
    args = parser.parse_args()

    # Fall back to CPU when CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    if args.log_wandb:
        import wandb

        if "WANDB_API_KEY" in os.environ:
            wandb.login(key=os.environ["WANDB_API_KEY"])
        wandb.init(project="diffusion_dino", config=vars(args), name=args.name, mode="online")

    vae = load_vae(args.vae_ckpt, device)
    model = build_dino_diffusion(
        keygrip_root=args.keygrip_root,
        num_sampling_steps=args.num_sampling_steps,
        denoiser="mlp",
    )
    model.to(device)
    # Only the denoiser net is optimized; the VAE (and DINO encoder) stay frozen.
    opt = torch.optim.AdamW(model.net.parameters(), lr=args.lr)

    args.ckpt_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): --data_root is parsed but never used below; both dataset
    # choices read pre-cached clips. Confirm whether "self_collected" should
    # instead build a SelfCollectedDataset from args.data_root.
    if args.dataset == "droid_cache":
        cache_path = args.cache_dir
    else:
        # self_collected clips live in a sibling cache next to --cache_dir.
        # argparse type=Path guarantees a Path, so .parent always exists.
        cache_path = args.cache_dir.parent / "vid_cache_keygrip"
    dataset = CachedClipDataset(str(cache_path), num_frames=NUM_FRAMES)
    print(f"Dataset: {len(dataset)} clips (cache from {cache_path})")
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=True,
        persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )

    from tqdm import tqdm

    global_step = 0
    pbar = tqdm(total=args.steps, desc="steps", unit="step")
    overfit = False  # dont remove, testing block
    y = None
    while True:
        for video in loader:
            if overfit:
                # Re-use the first batch forever to sanity-check optimization.
                if y is None:
                    y = video
                video = y

            video = video.to(device, non_blocking=True)
            with torch.no_grad():
                target_tokens = video_to_tokens(video, vae, scale=VAE_LATENT_SCALE)  # (B, T*256, 16)
                first_frame = video[:, :, 0]  # (B, 3, 256, 256)
                dino_cond = model.get_dino_cond(first_frame)  # (B, 256, D)

            # 1e2 scaling keeps the reported diffusion loss in a readable range.
            loss = model.compute_loss(target_tokens, dino_cond) * 1e2
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            if global_step % 50 == 0:
                print(f"step {global_step} loss {loss.item():.4f}")
            if args.log_wandb:
                import wandb
                wandb.log({"train/loss": float(loss.item())}, step=global_step)

            if args.log_wandb and args.vis_every and (global_step % args.vis_every == 0):
                import wandb

                print("sampling video")
                with torch.no_grad():
                    cond_video = video[:1].clone()
                    cond_frame = cond_video[:, :, 0].clone()  # (1,3,256,256)
                    start_frame = cond_video[:, :, 0].clone()  # frame 0 of the target video
                    dino_cond_1 = model.get_dino_cond(cond_frame)  # (1,256,D)
                    first_tokens = video_to_tokens(cond_video[:, :, :1], vae, scale=VAE_LATENT_SCALE)[:, :256]
                    sampled_tokens = model.sample(first_tokens, dino_cond_1, temperature=1.0, device=device)  # (1,T*256,16)
                    pred_video = tokens_to_video(sampled_tokens, vae, scale=VAE_LATENT_SCALE, t=NUM_FRAMES)  # (1,3,T,256,256)
                print("done sampling video")

                # Frames are treated as [-1, 1] here; map to [0, 1] for images.
                cond_img = ((cond_frame[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
                start_img = ((start_frame[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
                pred_img0 = ((pred_video[0, :, 0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)
                wandb.log(
                    {
                        "vis/cond_frame": wandb.Image(cond_img.permute(1, 2, 0).numpy()),
                        "vis/start_frame": wandb.Image(start_img.permute(1, 2, 0).numpy()),
                        "vis/pred_frame0": wandb.Image(pred_img0.permute(1, 2, 0).numpy()),
                    },
                    step=global_step,
                )
                # Best-effort video logging: mp4 writing may fail (codec /
                # torchvision backend), so failures are swallowed deliberately.
                try:
                    import tempfile
                    import torchvision
                    # Predicted video
                    pred_np = ((pred_video[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)  # (3,T,H,W)
                    pred_np = pred_np.permute(1, 0, 2, 3)  # (T,3,H,W)
                    pred_frames = (pred_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")  # (T,H,W,3)
                    # Ground-truth video from cond_video
                    gt_np = ((cond_video[0].detach().cpu() + 1.0) / 2.0).clamp(0, 1)  # (3,T,H,W)
                    gt_np = gt_np.permute(1, 0, 2, 3)  # (T,3,H,W)
                    gt_frames = (gt_np.permute(0, 2, 3, 1).numpy() * 255).astype("uint8")  # (T,H,W,3)

                    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f_pred:
                        pred_path = f_pred.name
                    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f_gt:
                        gt_path = f_gt.name

                    torchvision.io.write_video(pred_path, torch.from_numpy(pred_frames), fps=4)
                    torchvision.io.write_video(gt_path, torch.from_numpy(gt_frames), fps=4)

                    wandb.log(
                        {
                            "vis/pred_video": wandb.Video(pred_path, format="mp4"),
                            "vis/gt_video": wandb.Video(gt_path, format="mp4"),
                        },
                        step=global_step,
                    )

                    Path(pred_path).unlink(missing_ok=True)
                    Path(gt_path).unlink(missing_ok=True)
                except Exception:
                    pass

            global_step += 1
            pbar.update(1)

            if args.save_every and (global_step % args.save_every == 0):
                ckpt = {
                    "step": global_step,
                    "net": model.net.state_dict(),
                    "optimizer": opt.state_dict(),
                }
                torch.save(ckpt, args.ckpt_dir / f"diffusion_dino_step{global_step}.pt")
                print(f"Saved {args.ckpt_dir / f'diffusion_dino_step{global_step}.pt'}")

            # Fix: honor --steps. Previously the `while True` loop never
            # terminated, leaving pbar.close()/wandb.finish() unreachable.
            if global_step >= args.steps:
                break
        if global_step >= args.steps:
            break

    pbar.close()
    if args.log_wandb:
        import wandb

        wandb.finish()


# Script entry point: run training when executed directly
# (e.g. `python -m diffusion_dino.train`).
if __name__ == "__main__":
    main()
