"""
Sample from diffusion_dino: condition on a single frame (DINO), denoise in VAE token space, decode.
Usage (from vidgen):
  python -m diffusion_dino.sample --keygrip_root /path/to/keygrip --vae_ckpt /path/to/vae.ckpt --ckpt /path/to/diffusion_dino.pt --cond_frame /path/to/image.png --out /path/to/out.png
"""

import argparse
import sys
from pathlib import Path

import torch

# Ensure we can import from unified_video_action and vidgen
# Make the sibling packages importable no matter where the script is run from:
# prepend the vidgen root and its unified_video_action subtree to sys.path.
vidgen_root = Path(__file__).resolve().parents[1]
uva_root = vidgen_root / "unified_video_action"
for candidate in (vidgen_root, uva_root):
    candidate_str = str(candidate)
    if candidate.exists() and candidate_str not in sys.path:
        sys.path.insert(0, candidate_str)

from types import SimpleNamespace

from simple_uva.vae import AutoencoderKL

from diffusion_dino import build_dino_diffusion
from diffusion_dino.model import tokens_to_frame, VAE_LATENT_SCALE


def load_vae(vae_ckpt: Path | None, device: torch.device):
    """Build the AutoencoderKL, optionally load checkpoint weights, move to *device*.

    Args:
        vae_ckpt: Path to a VAE checkpoint file, or None for random init.
        device: Target device for the returned model.

    Returns:
        The VAE in eval mode on *device*.
    """
    ddconfig = SimpleNamespace(vae_embed_dim=16, ch_mult=[1, 1, 2, 2, 4])
    # Evaluate the existence check once (the original stat'ed the path twice
    # and relied on Path truthiness for the None case).
    have_ckpt = vae_ckpt is not None and vae_ckpt.exists()
    vae = AutoencoderKL(
        autoencoder_path=str(vae_ckpt) if have_ckpt else None,
        ddconfig=ddconfig,
    )
    if have_ckpt:
        # NOTE(review): AutoencoderKL may already consume autoencoder_path in
        # its constructor; this explicit load preserves the original behavior
        # of also applying the raw state_dict (strict=False tolerates
        # missing/unexpected keys). Confirm whether one of the two is redundant.
        sd = torch.load(vae_ckpt, map_location="cpu")
        state = sd.get("state_dict", sd)
        vae.load_state_dict(state, strict=False)
    vae.eval()
    vae.to(device)
    return vae


def load_cond_frame(path: Path, device: torch.device):
    """Load a single image as a (1, 3, 256, 256) tensor in [-1, 1] on *device*.

    Fix: the torchvision branch previously kept the file's native channel
    count, so grayscale (1-channel) or RGBA (4-channel) inputs produced a
    tensor a 3-channel model cannot consume; forcing RGB matches the PIL
    fallback, which already calls .convert("RGB").

    Args:
        path: Image file readable by torchvision or PIL.
        device: Target device for the returned tensor.
    """
    try:
        import torchvision
        img = torchvision.io.read_image(
            str(path), mode=torchvision.io.ImageReadMode.RGB
        ).float() / 255.0
    except Exception:
        # Fallback when torchvision is unavailable (or its decoder fails).
        from PIL import Image
        import numpy as np
        img = torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1).float() / 255.0
    if img.dim() == 3:
        img = img.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    img = img.to(device)
    # Resize to the model's fixed 256x256 input resolution if needed.
    if img.shape[-1] != 256 or img.shape[-2] != 256:
        img = torch.nn.functional.interpolate(img, size=(256, 256), mode="bilinear", align_corners=False)
    img = img * 2 - 1  # [0, 1] -> [-1, 1]
    return img


def main():
    """CLI entry point: condition on one frame, sample VAE tokens, decode, save."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--keygrip_root", type=Path, required=True, help="Keygrip repo root (DINO weights)")
    ap.add_argument("--vae_ckpt", type=Path, default=None, help="VAE checkpoint path")
    ap.add_argument("--ckpt", type=Path, required=True, help="diffusion_dino checkpoint (net state_dict)")
    ap.add_argument("--cond_frame", type=Path, required=True, help="Conditioning image path")
    ap.add_argument("--out", type=Path, default=Path("out.png"), help="Output image path")
    ap.add_argument("--num_sampling_steps", type=str, default="100")
    ap.add_argument("--temperature", type=float, default=1.0)
    ap.add_argument("--device", type=str, default="cuda")
    args = ap.parse_args()

    # Fall back to CPU whenever CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    vae = load_vae(args.vae_ckpt, device)
    model = build_dino_diffusion(
        keygrip_root=args.keygrip_root,
        num_sampling_steps=args.num_sampling_steps,
    )

    # Checkpoint may be either {"net": state_dict} or a bare state_dict.
    ckpt = torch.load(args.ckpt, map_location=device)
    state = ckpt["net"] if "net" in ckpt else ckpt
    model.net.load_state_dict(state, strict=True)
    model.to(device)
    model.eval()

    cond = load_cond_frame(args.cond_frame, device)
    with torch.no_grad():
        dino_cond = model.get_dino_cond(cond)
        tokens = model.sample(dino_cond, temperature=args.temperature, device=device)

    # Reshape to (1, 256, 16): presumably 256 spatial positions x 16 latent
    # channels — NOTE(review): confirm layout against diffusion_dino.model.
    tokens = tokens.view(1, 256, 16)
    frame = tokens_to_frame(tokens, vae, scale=VAE_LATENT_SCALE)
    # NOTE(review): clamping to [0, 1] assumes the decoder emits [0, 1]-range
    # images — verify tokens_to_frame does not return [-1, 1].
    frame = frame[0].clamp(0, 1).cpu()

    try:
        import torchvision
        torchvision.utils.save_image(frame, args.out)
    except Exception:
        # PIL fallback when torchvision is unavailable.
        from PIL import Image
        import numpy as np
        Image.fromarray((frame.permute(1, 2, 0).numpy() * 255).astype("uint8")).save(args.out)
    print(f"Saved {args.out}")


# Standard script entry guard: run only when executed, not when imported.
if __name__ == "__main__":
    main()
