# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Image-to-video finetuning on cached DROID with Cosmos Predict2.5.
# Same CLI and dataset interface as diffusion_dino/train_img2vid.py:
#   --dataset droid_cache, --cache_dir, --batch_size, --steps, --vis_every,
#   --num_sampling_steps (default 5), --log_wandb, --name, --save_every, --ckpt_dir.
#
# Uses CachedClipDataset from vidgen (unified_video_action); converts 8-frame 256x256
# clips to Cosmos format (pad to model T, resize to model resolution, uint8, ai_caption).
#
# Usage (from cosmos-predict2.5 repo root):
#   uv run python scripts/train_droid_img2vid.py \
#     --dataset droid_cache --cache_dir /data/cameron/vidgen/dino_vid_model/vid_cache \
#     --log_wandb
#
# Checkpoint: default uses HF post-trained 2B (UUID 81edfebe-...). Resolves via HF_HOME
# (e.g. /data/cameron/vidgen/.cache/huggingface). Override with --ckpt_path /path/to/model.pt
# if needed.
#
# For single-GPU you can run with python; for multi-GPU use torchrun --nproc_per_node=N.
#
# If you see ImportError/undefined symbol for flash_attn or transformer_engine (.so), your
# PyTorch version doesn't match the prebuilt extensions. Reinstall from the env that has torch:
#   pip install --no-cache-dir --no-build-isolation flash-attn
#   pip install --no-cache-dir --no-build-isolation 'transformer-engine[pytorch]'
# (Use the same Python/pip that has torch, e.g. activate your venv/conda first.)

import argparse
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the vidgen checkout (and its unified_video_action package) importable so this
# script can reuse the same CachedClipDataset as diffusion_dino. Expected layout:
#   vidgen/cosmos-predict2.5/scripts/<this file>
_script_dir = Path(__file__).resolve().parent
COSMOS_ROOT = _script_dir.parent  # cosmos-predict2.5 checkout
vidgen_root = COSMOS_ROOT.parent  # vidgen, when cosmos-predict2.5 is nested inside it
uva_root = vidgen_root / "unified_video_action"

for p in (vidgen_root, uva_root):
    # Skip directories that are missing or already on the import path.
    if not p.exists() or str(p) in sys.path:
        continue
    sys.path.insert(0, str(p))

# Put the Cosmos repo root at the front (if absent) so it takes import precedence.
if str(COSMOS_ROOT) not in sys.path:
    sys.path.insert(0, str(COSMOS_ROOT))

_DISABLE_FLASH_ATTN = os.environ.get("VIDGEN_DISABLE_FLASH_ATTN", "0") == "1"
if _DISABLE_FLASH_ATTN:
    # Some environments have a broken flash_attn build (PyTorch ABI mismatch). Setting
    # VIDGEN_DISABLE_FLASH_ATTN=1 patches transformers' availability checks so that code
    # consulting them treats flash_attn as unavailable.
    try:
        import transformers.utils.import_utils as _tf_import_utils

        _orig_pkg = _tf_import_utils._is_package_available

        def _no_flash_attn(pkg_name, return_version=False):
            """Report flash_attn as missing; delegate every other package to transformers."""
            if pkg_name == "flash_attn":
                return (False, "N/A") if return_version else False
            return _orig_pkg(pkg_name, return_version)

        _tf_import_utils._is_package_available = _no_flash_attn
        # These checks may be cached (lru_cache); clear any stale result before replacing
        # them. AttributeError just means there is no cache (or no function) to clear.
        try:
            _tf_import_utils.is_flash_attn_2_available.cache_clear()
        except AttributeError:
            pass
        _tf_import_utils.is_flash_attn_2_available = lambda: False
        if hasattr(_tf_import_utils, "is_flash_attn_greater_or_equal"):
            try:
                _tf_import_utils.is_flash_attn_greater_or_equal.cache_clear()
            except AttributeError:
                pass
            _tf_import_utils.is_flash_attn_greater_or_equal = lambda _: False
        # transformers.utils may re-export the same helpers; patch those names as well.
        import transformers.utils as _tf_utils

        if hasattr(_tf_utils, "is_flash_attn_2_available"):
            _tf_utils.is_flash_attn_2_available = lambda: False
        if hasattr(_tf_utils, "is_flash_attn_greater_or_equal"):
            _tf_utils.is_flash_attn_greater_or_equal = lambda _: False
    except Exception as _patch_err:  # best-effort: never block startup, but do not hide it
        print(
            "Warning: VIDGEN_DISABLE_FLASH_ATTN=1 but patching transformers' flash_attn "
            f"checks failed: {_patch_err}"
        )

# Clip length requested from CachedClipDataset (identical to diffusion_dino).
NUM_FRAMES = 8
# Default location of the cached DROID clips.
DEFAULT_DROID_CACHE_DIR = Path("/data/cameron/vidgen") / "dino_vid_model" / "vid_cache"
# Relative directory (resolved against the working directory) for finetuned checkpoints.
DEFAULT_CKPT_DIR = Path("ckpt_cosmos_droid_img2vid")

# Hugging Face 2B post-trained (I2V) checkpoint, referenced by UUID so checkpoint_db
# resolves it into the HF cache; DEFAULT_EXPERIMENT is the experiment config matching it.
DEFAULT_CKPT_PATH = "81edfebe-bd6a-4039-8c1d-737df1a790bf"  # nvidia/Cosmos-Predict2.5-2B/base/post-trained
DEFAULT_EXPERIMENT = "Stage-c_pt_4-Index-2-Size-2B-Res-720-Fps-16-Note-rf_with_edm_ckpt"
# A local file works too once downloaded, e.g.
# /data/cameron/vidgen/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/15a82a2ec231bc318692aa0456a36537c806e7d4/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt


def get_cosmos_video_spec(model, aspect_ratio: str = "1,1"):
    """Return (num_pixel_frames, height, width) the model expects for video.

    The frame count is ``model.tokenizer.pixel_chunk_duration``. The spatial size is
    looked up in ``VIDEO_RES_SIZE_INFO`` by ``model.config.resolution`` (default ``"720"``)
    and *aspect_ratio*.

    Args:
        model: Loaded Cosmos model exposing ``tokenizer`` and ``config``.
        aspect_ratio: Aspect-ratio key inside the resolution's entry. The default
            ``"1,1"`` matches the square cache clips; (960, 960) is used if the key is missing.

    Returns:
        Tuple ``(req_T, req_H, req_W)``. Falls back to ``(req_T, 704, 1280)`` (with a
        printed warning) when the resolution is not in ``VIDEO_RES_SIZE_INFO``.
    """
    from cosmos_predict2._src.predict2.datasets.utils import VIDEO_RES_SIZE_INFO

    req_T = model.tokenizer.pixel_chunk_duration
    res_key = getattr(model.config, "resolution", "720")
    # Config overrides such as `model.config.resolution=256` may arrive as ints, while the
    # table is looked up by string; normalize so such values do not hit the fallback.
    if isinstance(res_key, int):
        res_key = str(res_key)
    if isinstance(res_key, str) and res_key in VIDEO_RES_SIZE_INFO:
        ar_info = VIDEO_RES_SIZE_INFO[res_key]
        if isinstance(ar_info, dict):
            req_H, req_W = ar_info.get(aspect_ratio, (960, 960))
        else:
            req_H, req_W = 960, 960
    else:
        print(f"Warning: resolution {res_key!r} not in VIDEO_RES_SIZE_INFO; using 704x1280.")
        req_H, req_W = 704, 1280
    return req_T, req_H, req_W


def cache_batch_to_cosmos_data_batch(
    video_batch: torch.Tensor,
    device: torch.device,
    req_T: int,
    req_H: int,
    req_W: int,
    caption: str = "Robot manipulation.",
    fps: float = 16.0,
) -> dict:
    """
    Convert cache batch (B, C, T, H, W) in [-1, 1] to a Cosmos data_batch.

    The conversion runs in this order:

    1. Keep the first ``req_T`` frames.
    2. Map values to [0, 255] in float32.
    3. Resize each frame to ``(req_H, req_W)``.
    4. Round once to uint8 and move to ``device``.
    5. Pad the clip to ``req_T`` frames by repeating its last frame.

    Args:
        video_batch: Clips shaped (B, C, T, H, W); values outside [-1, 1] are clamped.
        device: Device for every returned tensor.
        req_T, req_H, req_W: Frame count and spatial size the model expects.
        caption: Caption stored (once per sample) under ``ai_caption``.
        fps: Frame rate given to the conditioner (the default experiment is 16fps).

    Returns:
        dict with ``video`` (uint8, (B, C, req_T, req_H, req_W)), ``ai_caption``
        (B copies of *caption*), ``fps`` (float32, (B,)) and ``padding_mask``
        (float32 zeros, (B, 1, req_H, req_W)).

    Raises:
        ValueError: if ``video_batch`` is not 5-D or has no frames.
    """
    B, _, T, H, W = video_batch.shape
    if T == 0:
        raise ValueError("video_batch has no frames (T == 0); cannot build a Cosmos clip.")
    # Trim first so no work is spent on frames that would be discarded.
    video = video_batch[:, :, :req_T]
    # [-1, 1] -> [0, 255] in float32. Rounding to uint8 happens exactly once, after the
    # resize, rather than quantize -> interpolate -> quantize again.
    video = (video.float().clamp(-1.0, 1.0) + 1.0) * 127.5
    if (H, W) != (req_H, req_W):
        # Resize only the real frames, before temporal padding: padded frames are copies
        # of the last one, so interpolating them too would only waste memory. The depth
        # size is unchanged, so trilinear mode reduces to per-frame bilinear resizing.
        video = F.interpolate(
            video,
            size=(video.shape[2], req_H, req_W),
            mode="trilinear",
            align_corners=False,
        )
    video = video.round().clamp(0, 255).to(torch.uint8).to(device)
    if T < req_T:
        # Repeat the last frame until the clip reaches the model's required length.
        last = video[:, :, -1:].expand(-1, -1, req_T - T, -1, -1)
        video = torch.cat([video, last], dim=2)

    # Cosmos conditioner expects fps in the batch (float tensor of shape (B,)).
    fps_tensor = torch.full((B,), float(fps), device=device, dtype=torch.float32)
    # Cosmos conditioner also expects padding_mask (B, 1, H, W). Nothing is padded, so zeros.
    padding_mask = torch.zeros((B, 1, req_H, req_W), device=device, dtype=torch.float32)
    return {
        "video": video,
        "ai_caption": [caption] * B,
        "fps": fps_tensor,
        "padding_mask": padding_mask,
    }


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser (diffusion_dino/train_img2vid flags plus Cosmos options)."""
    parser = argparse.ArgumentParser(
        description="Cosmos Predict2.5 image-to-video finetuning on cached DROID (same interface as diffusion_dino/train_img2vid)."
    )
    parser.add_argument("--data_root", type=Path, default=Path("/data/cameron/keygrip/scratch"))
    parser.add_argument(
        "--dataset",
        type=str,
        default="droid_cache",
        choices=["self_collected", "droid_cache"],
        help="Dataset source; use droid_cache for cached .pt clips.",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        default=DEFAULT_DROID_CACHE_DIR,
        help="DROID cached clips dir (e.g. dino_vid_model/vid_cache).",
    )
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--steps", type=int, default=100000000000)
    parser.add_argument("--vis_every", type=int, default=100, help="Log pred/GT video to wandb every N steps")
    parser.add_argument(
        "--num_sampling_steps",
        type=str,
        default="5",
        help="Number of denoising steps for sampling (used for vis). Default 5.",
    )
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--log_wandb", action="store_true")
    parser.add_argument("--name", type=str, default="cosmos_droid_img2vid")
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--ckpt_dir", type=Path, default=DEFAULT_CKPT_DIR)
    parser.add_argument(
        "--hf_token",
        type=str,
        default=os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"),
        help="Optional Hugging Face token (needed for gated Cosmos repos). If set, exported to HF_TOKEN/HUGGINGFACE_HUB_TOKEN for checkpoint downloads.",
    )
    parser.add_argument(
        "--resolution",
        type=str,
        default="",
        help="Optional Cosmos resolution override (e.g. 256, 480, 512, 720). Lowering this can drastically reduce VRAM.",
    )
    # Cosmos-specific
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default=DEFAULT_CKPT_PATH,
        help="Cosmos checkpoint: UUID (e.g. 81edfebe-... for HF post-trained), hf://org/repo/path, or local path.",
    )
    parser.add_argument(
        "--experiment",
        type=str,
        default=DEFAULT_EXPERIMENT,
        help="Cosmos experiment name (must match config for ckpt_path).",
    )
    return parser


def _setup_device(requested: str) -> tuple[torch.device, int, int]:
    """Pick the compute device and join the process group when launched by torchrun.

    Returns ``(device, rank, world_size)``; single-process runs get rank 0 / world size 1.
    """
    device = torch.device(requested if torch.cuda.is_available() else "cpu")
    launched_distributed = "RANK" in os.environ
    if launched_distributed and device.type == "cuda" and "LOCAL_RANK" in os.environ:
        # Each torchrun worker needs its own GPU; a bare "cuda" would put every rank on cuda:0.
        device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
        torch.cuda.set_device(device)
    if launched_distributed and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl" if device.type == "cuda" else "gloo")
    if torch.distributed.is_initialized():
        return device, torch.distributed.get_rank(), torch.distributed.get_world_size()
    return device, 0, 1


def _build_loader(args, distributed: bool):
    """Build the CachedClipDataset DataLoader; returns ``(loader, sampler_or_None)``."""
    import importlib.util

    from torch.utils.data.distributed import DistributedSampler

    # Load simple_uva/dataset.py directly (same dataset as diffusion_dino) instead of the
    # simple_uva package, to avoid importing its loader/dill dependencies.
    dataset_path = uva_root / "simple_uva" / "dataset.py"
    if not dataset_path.is_file():
        raise SystemExit(
            f"CachedClipDataset source not found at {dataset_path} "
            "(expected unified_video_action next to cosmos-predict2.5)."
        )
    spec = importlib.util.spec_from_file_location("_uva_dataset", dataset_path)
    dataset_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(dataset_module)
    CachedClipDataset = dataset_module.CachedClipDataset
    collate_batch = dataset_module.collate_batch

    if args.dataset == "droid_cache":
        cache_path = Path(args.cache_dir)
    else:
        # Self-collected clips live in a sibling directory of the DROID cache.
        cache_path = Path(args.cache_dir).parent / "vid_cache_keygrip"
    dataset = CachedClipDataset(str(cache_path), num_frames=NUM_FRAMES)
    print(f"Train dataset ({args.dataset}): {len(dataset)} clips from {cache_path}")
    if len(dataset) == 0:
        # Otherwise the epoch loop in main() would spin forever without taking a step.
        raise SystemExit(f"No clips found in {cache_path}; nothing to train on.")

    sampler = DistributedSampler(dataset, shuffle=True) if distributed else None
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=args.workers,
        collate_fn=collate_batch,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=args.workers > 0,
        prefetch_factor=4 if args.workers > 0 else None,
    )
    return loader, sampler


def _load_cosmos_model(args, device: torch.device):
    """Load the Cosmos model for ``--experiment``/``--ckpt_path`` onto *device*, in train mode."""
    from cosmos_predict2._src.predict2.utils.model_loader import load_model_from_checkpoint

    # Register HF/post-trained checkpoints (cosmos_oss) so UUID checkpoint paths resolve.
    try:
        from cosmos_oss.checkpoints_predict2 import register_checkpoints

        register_checkpoints()
    except Exception as e:  # optional: explicit paths still work without registration
        print(f"Warning: could not register cosmos_oss checkpoints ({e}); a UUID --ckpt_path may not resolve.")

    config_file = "cosmos_predict2/_src/predict2/configs/video2world/config.py"
    # Many upstream Cosmos experiments reference internal data configs via Hydra. We don't
    # use the Hydra dataloaders (CachedClipDataset is used instead), so select lightweight
    # public data configs to satisfy composition.
    experiment_opts = ["data_train=mock_video", "data_val=mock"]
    if args.resolution:
        # Matches configs that use model.config.resolution to select VIDEO_RES_SIZE_INFO.
        experiment_opts.append(f"model.config.resolution={args.resolution}")
    model, _config = load_model_from_checkpoint(
        experiment_name=args.experiment,
        s3_checkpoint_dir=args.ckpt_path,
        config_file=config_file,
        experiment_opts=experiment_opts,
        enable_fsdp=False,
        load_ema_to_reg=False,
        instantiate_ema=True,
        seed=0,
        local_cache_dir=None,
    )
    model.train()
    model.to(device)
    return model


def _autocast(device: torch.device):
    """Return a fresh bf16 autocast context (CUDA autocast on CUDA, CPU autocast otherwise)."""
    return torch.autocast(device_type="cuda" if device.type == "cuda" else "cpu", dtype=torch.bfloat16)


def _average_gradients(params, world_size: int) -> None:
    """Average gradients of *params* across all ranks (manual data parallelism)."""
    for p in params:
        if p.grad is None:
            # Every rank must issue the same sequence of collectives or they deadlock.
            p.grad = torch.zeros_like(p)
        torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM)
        p.grad.div_(world_size)


def _save_checkpoint(model, opt, step: int, ckpt_dir: Path) -> Path:
    """Save model/optimizer state for *step*; write-then-rename so no truncated file remains."""
    path = ckpt_dir / f"cosmos_droid_img2vid_step{step}.pt"
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save({"step": step, "model": model.state_dict(), "optimizer": opt.state_dict()}, tmp_path)
    os.replace(tmp_path, path)
    print(f"Saved {path}")
    return path


def _log_visualization(model, video, device, req_T, req_H, req_W, num_sampling_steps, step) -> None:
    """Sample from the first clip of *video* and log cond/pred frames and pred/GT videos to wandb.

    Best-effort: failures are printed and never interrupt training; the model is always
    returned to train mode.
    """
    import tempfile

    import torchvision
    import wandb

    print("Sampling video for visualization...")
    cond_video = video[:1].clone()
    model.eval()
    try:
        with torch.no_grad():
            vis_batch = cache_batch_to_cosmos_data_batch(cond_video, device, req_T, req_H, req_W)
            # generate_samples_from_batch expects text embeddings to already be present.
            if getattr(model, "text_encoder", None) is not None:
                text_embeddings = model.text_encoder.compute_text_embeddings_online(vis_batch, "ai_caption")
                vis_batch["t5_text_embeddings"] = text_embeddings
                vis_batch["t5_text_mask"] = torch.ones(
                    text_embeddings.shape[0],
                    text_embeddings.shape[1],
                    device=text_embeddings.device,
                )
            with _autocast(device):
                pred_latent = model.generate_samples_from_batch(
                    vis_batch,
                    guidance=1.0,
                    seed=42,
                    n_sample=1,
                    num_steps=num_sampling_steps,
                )
                pred_video = model.decode(pred_latent)
    except Exception as e:
        print(f"Sampling failed: {e}")
        return
    finally:
        model.train()

    try:
        # Map [-1, 1] -> [0, 1] for display; .float() also makes bf16 tensors numpy-convertible.
        cond_img = ((cond_video[0, :, 0].float().cpu() + 1.0) / 2.0).clamp(0, 1)
        pred_clip = ((pred_video[0].float().cpu() + 1.0) / 2.0).clamp(0, 1)  # (C, T, H, W)
        wandb.log(
            {
                "vis/cond_frame": wandb.Image(cond_img.permute(1, 2, 0).numpy()),
                "vis/pred_frame0": wandb.Image(pred_clip[:, 0].permute(1, 2, 0).numpy()),
            },
            step=step,
        )

        # Predicted video: compare only the first NUM_FRAMES frames with the GT clip.
        n_show = min(NUM_FRAMES, pred_clip.shape[1])
        pred_frames = (pred_clip[:, :n_show].permute(1, 2, 3, 0).numpy() * 255).astype("uint8")
        gt_clip = ((cond_video[0].float().cpu() + 1.0) / 2.0).clamp(0, 1)
        gt_frames = (gt_clip.permute(1, 2, 3, 0).numpy() * 255).astype("uint8")

        # The temporary directory (and both videos) is removed even if writing/logging fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            pred_path = str(Path(tmp_dir) / "pred.mp4")
            gt_path = str(Path(tmp_dir) / "gt.mp4")
            torchvision.io.write_video(pred_path, torch.from_numpy(pred_frames), fps=4)
            torchvision.io.write_video(gt_path, torch.from_numpy(gt_frames), fps=4)
            wandb.log(
                {
                    "vis/pred_video": wandb.Video(pred_path, format="mp4"),
                    "vis/gt_video": wandb.Video(gt_path, format="mp4"),
                },
                step=step,
            )
    except Exception as e:
        print(f"Warning: could not log visualizations to wandb: {e}")


def main():
    """Finetune Cosmos Predict2.5 image-to-video on cached DROID clips.

    Trains for ``--steps`` optimizer steps, cycling through the dataset for as many epochs
    as needed. Under torchrun, each rank reads its own shard and gradients are averaged
    across ranks. Only rank 0 prints losses, logs to wandb and writes checkpoints.
    """
    args = _build_arg_parser().parse_args()

    if not args.ckpt_path:
        raise SystemExit("--ckpt_path is required (or set default in script).")

    num_sampling_steps = int(args.num_sampling_steps)

    # Ensure downstream `uvx hf download ... --token ...` can authenticate.
    token = (args.hf_token or "").strip() or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if token:
        os.environ["HF_TOKEN"] = token
        os.environ["HUGGINGFACE_HUB_TOKEN"] = token
        print(f"HF token detected (len={len(token)}).")
    else:
        # Not fatal if everything is cached, but gated repos will 403 without it.
        print("Warning: no HF token found (set HF_TOKEN/HUGGINGFACE_HUB_TOKEN or pass --hf_token).")

    device, rank, world_size = _setup_device(args.device)
    is_main = rank == 0

    loader, sampler = _build_loader(args, distributed=world_size > 1)
    model = _load_cosmos_model(args, device)

    req_T, req_H, req_W = get_cosmos_video_spec(model)
    print(f"Cosmos video spec: T={req_T}, H={req_H}, W={req_W}")

    # Optimizer over trainable params
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr)
    args.ckpt_dir = Path(args.ckpt_dir)
    if is_main:
        args.ckpt_dir.mkdir(parents=True, exist_ok=True)

    use_wandb = args.log_wandb and is_main
    if use_wandb:
        import wandb

        if os.environ.get("WANDB_API_KEY"):
            wandb.login(key=os.environ["WANDB_API_KEY"])
        # Never upload the Hugging Face token as part of the run config.
        run_config = {k: v for k, v in vars(args).items() if k != "hf_token"}
        wandb.init(project="cosmos_droid_img2vid", config=run_config, name=args.name, mode="online")

    global_step = 0
    epoch = 0
    pbar = tqdm(total=args.steps, desc="steps", unit="step", disable=not is_main)
    try:
        # One pass over the loader is one epoch; --steps counts optimizer steps across epochs.
        while global_step < args.steps:
            if sampler is not None:
                sampler.set_epoch(epoch)  # reshuffle each epoch, consistently across ranks
            for video in loader:
                video = video.to(device, non_blocking=True)
                # (B, C, T, H, W) from collate; T=NUM_FRAMES, H=W=256
                data_batch = cache_batch_to_cosmos_data_batch(video, device, req_T, req_H, req_W)

                with _autocast(device):
                    _output_batch, loss = model.training_step(data_batch, global_step)
                opt.zero_grad(set_to_none=True)
                loss.backward()
                if world_size > 1:
                    _average_gradients(trainable, world_size)
                opt.step()

                if is_main and global_step % 50 == 0:
                    print(f"step {global_step} loss {loss.item():.4f}")
                if use_wandb:
                    wandb.log({"train/loss": float(loss.item())}, step=global_step)
                    # Vis: few-step sampling, log pred vs GT 8-frame videos (same as diffusion_dino)
                    if args.vis_every and global_step % args.vis_every == 0:
                        _log_visualization(
                            model, video, device, req_T, req_H, req_W, num_sampling_steps, global_step
                        )

                global_step += 1
                pbar.update(1)

                if is_main and args.save_every and global_step % args.save_every == 0:
                    _save_checkpoint(model, opt, global_step, args.ckpt_dir)
                if global_step >= args.steps:
                    break
            epoch += 1
    finally:
        pbar.close()
        if use_wandb:
            wandb.finish()
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


if __name__ == "__main__":
    # Script entry point: training starts only when the file is executed directly
    # (python or torchrun), not when it is imported.
    main()
