# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F


def _ensure_vidgen_paths():
    """Prepend the vidgen, unified_video_action and cosmos roots to ``sys.path``.

    All three roots are derived from this file's resolved location: the cosmos
    root is the parent of the directory holding this script, the vidgen root
    is the parent of the cosmos root, and the unified_video_action root is the
    ``unified_video_action`` directory under the vidgen root. A root is only
    added when it exists on disk and is not already listed.

    Returns:
        Tuple ``(cosmos_root, vidgen_root, uva_root)`` of ``Path`` objects.
    """
    import sys

    here = Path(__file__).resolve()
    cosmos_root = here.parent.parent
    vidgen_root = cosmos_root.parent
    uva_root = vidgen_root / "unified_video_action"

    # Every hit is inserted at index 0, so roots later in this tuple end up
    # earlier on sys.path.
    for root in (vidgen_root, uva_root, cosmos_root):
        entry = str(root)
        if not root.exists() or entry in sys.path:
            continue
        sys.path.insert(0, entry)

    return cosmos_root, vidgen_root, uva_root


def get_cosmos_video_spec(model):
    """Return the ``(T, H, W)`` video shape this script feeds to ``model``.

    ``T`` is ``model.tokenizer.pixel_chunk_duration``. ``H`` and ``W`` come
    from ``VIDEO_RES_SIZE_INFO`` keyed by ``model.config.resolution`` (treated
    as "720" when the attribute is absent), using that entry's "1,1" aspect
    ratio. Fallbacks:

    * resolution is not a string key of the table -> 704 x 1280
    * table entry is not a dict, or has no "1,1" key -> 960 x 960

    NOTE(review): the table tuple is unpacked as (H, W); confirm the table's
    own ordering if a non-square entry is ever used here.
    """
    from cosmos_predict2._src.predict2.datasets.utils import VIDEO_RES_SIZE_INFO

    num_frames = model.tokenizer.pixel_chunk_duration
    resolution = getattr(model.config, "resolution", "720")

    known = isinstance(resolution, str) and resolution in VIDEO_RES_SIZE_INFO
    if not known:
        return num_frames, 704, 1280

    aspect_table = VIDEO_RES_SIZE_INFO[resolution]
    if isinstance(aspect_table, dict):
        height, width = aspect_table.get("1,1", (960, 960))
    else:
        height, width = 960, 960
    return num_frames, height, width


@torch.no_grad()
def make_data_batch_from_first_frame(
    clip_video: torch.Tensor,  # (C, T, H, W) in [-1, 1]
    device: torch.device,
    req_T: int,
    req_H: int,
    req_W: int,
    prompt: str,
    num_conditional_frames: int = 1,
    fps_val: float = 16.0,
) -> dict:
    """Build a single-sample batch dict conditioned on the clip's first frame.

    The first frame of ``clip_video`` is resized to ``(req_H, req_W)`` when its
    spatial size differs, copied into the first ``num_conditional_frames``
    slots of a zero-filled ``(3, req_T, req_H, req_W)`` video, and the result
    is quantized from [-1, 1] to uint8 [0, 255].

    Args:
        clip_video: Source clip, ``(C, T, H, W)`` with values in [-1, 1].
        device: Device the returned tensors are placed on.
        req_T: Number of frames in the output video.
        req_H: Output height.
        req_W: Output width.
        prompt: Caption stored under ``"ai_caption"``.
        num_conditional_frames: How many leading frames hold the first frame.
        fps_val: Frame rate stored under ``"fps"``.

    Returns:
        Dict with keys ``dataset_name``, ``video`` (uint8, ``(1, 3, T, H, W)``),
        ``ai_caption``, ``fps``, ``padding_mask`` (float32 zeros,
        ``(1, 1, H, W)``) and ``num_conditional_frames``.
    """
    target_hw = (req_H, req_W)

    # Keep the time axis (length 1) so the frame stays 4-D.
    anchor = clip_video[:, :1]
    if anchor.shape[-2:] != target_hw:
        # Trilinear interpolation needs a 5-D (N, C, T, H, W) input.
        batched = anchor.unsqueeze(0)
        resized = F.interpolate(batched, size=(1, req_H, req_W), mode="trilinear", align_corners=False)
        anchor = resized[0]

    # Zero-filled clip whose leading frames repeat the conditioning frame.
    canvas = torch.zeros((3, req_T, req_H, req_W), dtype=anchor.dtype, device=anchor.device)
    repeated = anchor[:, :1].expand(-1, num_conditional_frames, -1, -1)
    canvas[:, :num_conditional_frames] = repeated

    # Map [-1, 1] onto uint8 [0, 255], then add the batch axis.
    scaled = (canvas.clamp(-1.0, 1.0) + 1.0) * 127.5
    quantized = scaled.round().clamp(0, 255).to(torch.uint8)
    batch_video = quantized.unsqueeze(0).to(device)  # (1,3,T,H,W)

    frame_rate = torch.full((1,), float(fps_val), device=device, dtype=torch.float32)
    no_padding = torch.zeros((1, 1, req_H, req_W), device=device, dtype=torch.float32)

    batch = {
        "dataset_name": "video_data",
        "video": batch_video,
        "ai_caption": [prompt],
        "fps": frame_rate,
        "padding_mask": no_padding,
        "num_conditional_frames": int(num_conditional_frames),
    }
    return batch


def _extract_clip_tensor(item):
    """Return the ``(C, T, H, W)`` video tensor held by a dataset item.

    ``item`` is either a tensor or a dict that stores the tensor under one of
    the keys ``"video"``, ``"frames"`` or ``"clip"`` (checked in that order).
    A 5-D tensor has its leading axis dropped by taking index 0.

    Raises:
        TypeError: If no tensor can be found in ``item``.
        ValueError: If the leading axis of the result is not 3 channels.
    """
    if isinstance(item, dict):
        vid = None
        # Test each candidate with isinstance instead of chaining with `or`:
        # `tensor or ...` calls bool() on the tensor, which raises for any
        # tensor holding more than one element.
        for key in ("video", "frames", "clip"):
            candidate = item.get(key)
            if isinstance(candidate, torch.Tensor):
                vid = candidate
                break
    else:
        vid = item

    if not isinstance(vid, torch.Tensor):
        keys = list(item.keys()) if isinstance(item, dict) else None
        raise TypeError(f"Unexpected dataset item type: {type(vid)} keys={keys}")

    # ensure (C,T,H,W)
    if vid.ndim == 5:
        vid = vid[0]
    if vid.shape[0] != 3:
        raise ValueError(f"Expected 3 channels, got {vid.shape}")
    return vid


def _write_video_outputs(frames_thwc, out_mp4, fps):
    """Write ``frames_thwc`` as an mp4 plus a PNG strip of its first frames.

    ``frames_thwc`` is a uint8 numpy array laid out ``(T, H, W, C)``. mediapy
    is tried first for both files; on any failure the reason is printed and
    the write falls back to torchvision (video) or PIL (image).

    Returns:
        Path of the PNG strip, i.e. ``out_mp4`` with a ``.png`` suffix.
    """
    try:
        import mediapy as media

        media.write_video(str(out_mp4), frames_thwc, fps=fps)
    except Exception as err:  # best-effort: any mediapy failure triggers the fallback
        print(f"mediapy video write unavailable ({err!r}); falling back to torchvision.")
        import torchvision

        # torchvision expects (T,H,W,C) uint8 tensor
        torchvision.io.write_video(str(out_mp4), torch.from_numpy(frames_thwc), fps=fps)

    # Quick visual check: up to the first 8 frames laid side by side.
    strip_img = np.concatenate(list(frames_thwc[:8]), axis=1)  # (H, n*W, C), n <= 8
    out_png = out_mp4.with_suffix(".png")
    try:
        import mediapy as media

        media.write_image(str(out_png), strip_img)
    except Exception as err:  # best-effort: any mediapy failure triggers the fallback
        print(f"mediapy image write unavailable ({err!r}); falling back to PIL.")
        from PIL import Image

        Image.fromarray(strip_img).save(out_png)
    return out_png


def main():
    """Sample one Cosmos Predict2.5 video from a cached clip's first frame.

    Steps: parse CLI arguments, export the Hugging Face token (if any), load
    one clip from ``CachedClipDataset``, load the Cosmos model, build a
    conditioning batch from the clip's first frame, generate and decode a
    sample, then write an mp4 and a PNG frame strip into ``--out_dir``.

    Raises:
        FileNotFoundError: If the dataset module file is missing.
        ImportError: If the dataset module cannot be loaded from its file.
        TypeError / ValueError: If the dataset item is not a 3-channel tensor,
            or ``--num_conditional_frames`` exceeds the model's frame count.
    """
    import importlib.util
    from contextlib import nullcontext

    cosmos_root, _, uva_root = _ensure_vidgen_paths()

    parser = argparse.ArgumentParser(description="Sample Cosmos Predict2.5 I2V from a DROID cached clip's first frame.")
    parser.add_argument("--cache_dir", type=Path, default=Path("/data/cameron/vidgen/dino_vid_model/vid_cache"))
    parser.add_argument("--index", type=int, default=0, help="Clip index in CachedClipDataset")
    parser.add_argument("--out_dir", type=Path, default=cosmos_root / "out")
    parser.add_argument("--prompt", type=str, default="Robot manipulation.")
    parser.add_argument("--num_steps", type=int, default=5)
    parser.add_argument("--guidance", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_conditional_frames", type=int, default=1)
    parser.add_argument("--fps", type=float, default=16.0)
    parser.add_argument("--resolution", type=str, default="256", help="Override model resolution (e.g. 256, 480, 512, 720)")
    parser.add_argument(
        "--hf_token",
        type=str,
        default=os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"),
        help="Optional Hugging Face token for gated checkpoints.",
    )
    # checkpoint/model selection
    parser.add_argument("--ckpt_path", type=str, default="81edfebe-bd6a-4039-8c1d-737df1a790bf")
    parser.add_argument(
        "--experiment",
        type=str,
        default="Stage-c_pt_4-Index-2-Size-2B-Res-720-Fps-16-Note-rf_with_edm_ckpt",
    )
    args = parser.parse_args()

    # Reject a negative count up front: it would otherwise be used as a
    # negative slice bound when the conditioning video is filled in.
    if args.num_conditional_frames < 0:
        parser.error("--num_conditional_frames must be >= 0")

    # Export the token under both variable names so either lookup finds it.
    token = (args.hf_token or "").strip() or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if token:
        os.environ["HF_TOKEN"] = token
        os.environ["HUGGINGFACE_HUB_TOKEN"] = token
        print(f"HF token detected (len={len(token)}).")
    else:
        print("Warning: no HF token found; gated downloads may 403.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.out_dir.mkdir(parents=True, exist_ok=True)

    # Load dataset module only (avoid importing package __init__)
    dataset_path = uva_root / "simple_uva" / "dataset.py"
    if not dataset_path.is_file():
        raise FileNotFoundError(f"Dataset module not found: {dataset_path}")
    spec = importlib.util.spec_from_file_location("_uva_dataset", dataset_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load dataset module from {dataset_path}")
    dataset_mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(dataset_mod)
    CachedClipDataset = dataset_mod.CachedClipDataset

    ds = CachedClipDataset(str(args.cache_dir), num_frames=8)
    vid = _extract_clip_tensor(ds[args.index])

    # Load Cosmos model
    from cosmos_predict2._src.predict2.utils.model_loader import load_model_from_checkpoint

    # Checkpoint registration is optional; report why it was skipped instead
    # of hiding the failure.
    try:
        from cosmos_oss.checkpoints_predict2 import register_checkpoints

        register_checkpoints()
    except Exception as err:
        print(f"Warning: checkpoint registration skipped ({err!r}).")

    config_file = "cosmos_predict2/_src/predict2/configs/video2world/config.py"
    experiment_opts = ["data_train=mock_video", "data_val=mock"]
    if args.resolution:
        experiment_opts.append(f"model.config.resolution={args.resolution}")

    model, _config = load_model_from_checkpoint(
        experiment_name=args.experiment,
        s3_checkpoint_dir=args.ckpt_path,
        config_file=config_file,
        experiment_opts=experiment_opts,
        enable_fsdp=False,
        load_ema_to_reg=False,
        instantiate_ema=True,
        seed=0,
        local_cache_dir=None,
    )
    model.eval().to(device)

    req_T, req_H, req_W = get_cosmos_video_spec(model)
    print(f"Cosmos video spec: T={req_T}, H={req_H}, W={req_W}")

    # More conditioning frames than the clip holds cannot be filled in.
    if args.num_conditional_frames > req_T:
        raise ValueError(
            f"--num_conditional_frames={args.num_conditional_frames} exceeds the model's frame count T={req_T}"
        )

    data_batch = make_data_batch_from_first_frame(
        clip_video=vid.to(device),
        device=device,
        req_T=req_T,
        req_H=req_H,
        req_W=req_W,
        prompt=args.prompt,
        num_conditional_frames=args.num_conditional_frames,
        fps_val=args.fps,
    )

    # generate_samples_from_batch expects text embeddings to already be present in the batch.
    # (During training, model.forward computes these online, but sampling uses conditioner directly.)
    if getattr(model, "text_encoder", None) is not None:
        text_embeddings = model.text_encoder.compute_text_embeddings_online(data_batch, "ai_caption")
        data_batch["t5_text_embeddings"] = text_embeddings
        data_batch["t5_text_mask"] = torch.ones(
            text_embeddings.shape[0], text_embeddings.shape[1], device=text_embeddings.device
        )

    # bfloat16 autocast on CUDA only; a no-op context everywhere else.
    if device.type == "cuda":
        autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    else:
        autocast_ctx = nullcontext()
    with torch.no_grad(), autocast_ctx:
        latent = model.generate_samples_from_batch(
            data_batch,
            guidance=float(args.guidance),
            seed=int(args.seed),
            n_sample=1,
            num_steps=int(args.num_steps),
        )
        video_pred = model.decode(latent)  # (1, C, T, H, W) in [-1,1]

    # [-1,1] -> uint8 [0,255], then channels-last for the writers.
    pred = video_pred[0].float().clamp(-1, 1)
    pred_u8 = ((pred + 1.0) * 127.5).round().clamp(0, 255).to(torch.uint8)  # (C,T,H,W)
    pred_u8_thwc = pred_u8.permute(1, 2, 3, 0).cpu().numpy()  # (T,H,W,C)

    out_mp4 = args.out_dir / f"sample_droid_idx{args.index}_res{req_H}x{req_W}_steps{args.num_steps}.mp4"
    out_png = _write_video_outputs(pred_u8_thwc, out_mp4, fps=int(args.fps))

    print(f"Wrote {out_mp4}")
    print(f"Wrote {out_png}")


# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

