"""
Test native video generation using the base Cosmos-Predict2-2B-Video2World model.
Generates a multi-frame video from a single LIBERO start frame.
"""

import os
import sys
import types
import pickle
import numpy as np

# Default to GPU 9, but respect a CUDA_VISIBLE_DEVICES the caller already
# exported instead of clobbering it. This must run before torch is imported
# below so that CUDA initialization sees the setting.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "9")

# === Mocks (same as test_libero_rollout.py) ===
# Empty stand-in for the compiled `transformer_engine_torch` extension: any
# later import of that name resolves to this module via sys.modules.
te_torch_mock = types.ModuleType("transformer_engine_torch")
sys.modules["transformer_engine_torch"] = te_torch_mock

def _apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=False, cu_seqlens=None):
    """Pure-PyTorch stand-in for transformer_engine's ``apply_rotary_pos_emb``.

    Rotates the first ``freqs.shape[-1]`` channels of ``t`` with the
    rotate-half formulation (``t * cos(freqs) + rotate_half(t) * sin(freqs)``).
    Any remaining channels of ``t`` are passed through unchanged.

    Args:
        t: Input tensor. For ``"sbhd"`` the sequence axis is dim 0; for
            ``"bshd"`` it is dim 1.
        freqs: Rotary angles, presumably laid out as ``(seq, 1, 1, rot_dim)``
            (NOTE(review): confirm against the caller). Only the leading
            ``seq_len(t)`` positions are used, so a longer precomputed table
            is accepted.
        tensor_format: ``"sbhd"`` or ``"bshd"``. Other values skip the
            sequence trimming and the bshd re-layout.
        fused: Ignored; accepted only for signature compatibility.
        cu_seqlens: Ignored; accepted only for signature compatibility.

    Returns:
        Tensor with ``t``'s dtype and (when ``freqs`` broadcasts against it)
        ``t``'s shape.
    """
    import torch

    if tensor_format in ("sbhd", "bshd"):
        # Use only the angles for positions actually present in `t`; without
        # this, a longer table broadcasts to the wrong shape or fails.
        seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
        freqs = freqs[:seq_len]
    rot_dim = freqs.shape[-1]
    dtype = t.dtype
    if tensor_format == "bshd" and freqs.ndim == 4 and freqs.shape[1] == 1:
        # (seq, 1, 1, d) -> (1, seq, 1, d) so the sequence axis lines up with t.
        freqs = freqs.permute(1, 0, 2, 3)
    # Evaluate cos/sin in freqs' own precision, then cast to t's dtype.
    cos_ = torch.cos(freqs).to(dtype)
    sin_ = torch.sin(freqs).to(dtype)
    t_rot = t[..., :rot_dim]
    t_pass = t[..., rot_dim:]
    x1, x2 = t_rot.chunk(2, dim=-1)
    rotated_half = torch.cat([-x2, x1], dim=-1)
    output = t_rot * cos_ + rotated_half * sin_
    if t_pass.shape[-1] > 0:
        return torch.cat([output, t_pass], dim=-1)
    return output

# Register an empty placeholder module under each of these transformer_engine
# names. Any module already present in sys.modules is left untouched.
for _te_name in (
    "transformer_engine.pytorch",
    "transformer_engine.pytorch.attention",
    "transformer_engine.pytorch.attention.rope",
    "transformer_engine.pytorch.distributed",
    "transformer_engine.pytorch.float8_tensor",
):
    sys.modules.setdefault(_te_name, types.ModuleType(_te_name))

import torch
class _RMSNorm(torch.nn.Module):
    """Pure-PyTorch stand-in for ``transformer_engine.pytorch.RMSNorm``.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``. The statistic
    is accumulated in at least float32 and the normalized value is cast back
    to ``x``'s dtype before scaling. The output dtype is therefore the same
    as ``x * weight`` would give.

    Args:
        hidden_size: Size of the last dimension of the input; also the length
            of the learnable scale ``weight``.
        eps: Added to the mean of squares before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # bf16/fp16 sums of squares lose precision, so upcast to float32.
        # promote_types leaves float64 inputs at float64 (no downcast).
        acc_dtype = torch.promote_types(x.dtype, torch.float32)
        x_acc = x.to(acc_dtype)
        inv_rms = torch.rsqrt(x_acc.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_acc * inv_rms).to(x.dtype) * self.weight

    def reset_parameters(self):
        """Reset the learnable scale to all ones."""
        torch.nn.init.ones_(self.weight)

# Attach the pure-PyTorch fallbacks and placeholder classes to the stub
# transformer_engine modules.
setattr(sys.modules["transformer_engine.pytorch"], "RMSNorm", _RMSNorm)
for _rope_module_name in (
    "transformer_engine.pytorch.attention.rope",
    "transformer_engine.pytorch.attention",
):
    setattr(sys.modules[_rope_module_name], "apply_rotary_pos_emb", _apply_rotary_pos_emb)
del _rope_module_name
# An empty class published under the name DotProductAttention; it has no behavior.
setattr(
    sys.modules["transformer_engine.pytorch.attention"],
    "DotProductAttention",
    type("DotProductAttention", (), {}),
)


# Behavior-less placeholder exported as Float8Tensor.
class _FT:
    pass


setattr(sys.modules["transformer_engine.pytorch.float8_tensor"], "Float8Tensor", _FT)

def _dummy(*a, **k):
    """No-op stand-in for every stubbed flash_attn callable; always returns None."""
    return None


# Fake top-level flash_attn package that advertises version 2.7.0.
flash_attn_mock = types.ModuleType("flash_attn")
flash_attn_mock.__version__ = "2.7.0"
flash_attn_mock.flash_attn_varlen_func = flash_attn_mock.flash_attn_func = _dummy
sys.modules["flash_attn"] = flash_attn_mock

# Each fake submodule exposes the same no-op functions plus a None
# apply_rotary_emb.
for _submod_name in (
    "flash_attn.bert_padding",
    "flash_attn.flash_attn_interface",
    "flash_attn.layers.rotary",
    "flash_attn.ops.rms_norm",
    "flash_attn.ops",
    "flash_attn.layers",
):
    _stub = types.ModuleType(_submod_name)
    vars(_stub).update(
        dict.fromkeys(
            (
                "pad_input",
                "unpad_input",
                "flash_attn_varlen_func",
                "index_first_axis",
                "flash_attn_func",
                "rms_norm",
            ),
            _dummy,
        )
    )
    _stub.apply_rotary_emb = None
    sys.modules[_submod_name] = _stub
del _stub, _submod_name

import importlib
import importlib.machinery
import importlib.metadata
import importlib.util

# Real implementations. The patches below delegate to them for every name
# other than "flash_attn".
_orig_find_spec = importlib.util.find_spec
_orig_version = importlib.metadata.version


def _patched_find_spec(name, package=None):
    """Report the flash_attn stub as importable; defer to the real finder otherwise.

    Mirrors the stdlib ``find_spec(name, package=None)`` signature so that
    keyword calls work. For flash_attn it returns a genuine ``ModuleSpec``
    rather than a bare ``True``, so callers that read spec attributes
    (``.name``, ``.origin``, ...) do not crash.
    """
    if name == "flash_attn":
        return importlib.machinery.ModuleSpec("flash_attn", None)
    return _orig_find_spec(name, package)


def _patched_version(distribution_name):
    """Report version "2.7.0" for flash_attn; defer to the real lookup otherwise."""
    if distribution_name == "flash_attn":
        return "2.7.0"
    return _orig_version(distribution_name)


importlib.util.find_spec = _patched_find_spec
importlib.metadata.version = _patched_version

import megatron.core.parallel_state as _mps
# If parallel_state only offers `is_unitialized` (sic) and lacks
# `is_initialized`, provide the latter as its negation.
if not hasattr(_mps, "is_initialized") and hasattr(_mps, "is_unitialized"):

    def _mps_is_initialized():
        return not _mps.is_unitialized()

    _mps.is_initialized = _mps_is_initialized

# === Main ===
import glob
import imageio
from PIL import Image
from pathlib import Path
from cosmos_policy._src.predict2.utils.model_loader import load_model_from_checkpoint
from cosmos_policy._src.predict2.inference.get_t5_emb import get_text_embedding
from cosmos_policy.experiments.robot.cosmos_utils import (
    load_dataset_stats, init_t5_text_embeddings_cache, COSMOS_IMAGE_SIZE,
)

# Destination for the rendered MP4 and frame PNGs. Defaults to the original
# author's path; set COSMOS_VID2WORLD_OUTPUT_DIR to write elsewhere.
OUTPUT_DIR = Path(
    os.environ.get(
        "COSMOS_VID2WORLD_OUTPUT_DIR",
        "/data/cameron/vidgen/cosmos-policy/rollout_outputs",
    )
)
# parents=True so a missing parent directory does not abort the script at import.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def _find_base_checkpoint():
    """Locate the cached base Cosmos-Predict2-2B-Video2World 480p/16fps checkpoint.

    Returns:
        Path (str) of the first match, in sorted order, in the Hugging Face
        hub cache.

    Raises:
        FileNotFoundError: If no cached snapshot contains the checkpoint.
    """
    pattern = os.path.expanduser(
        "~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Video2World/snapshots/*/model-480p-16fps.pt"
    )
    # glob() returns matches in arbitrary order; sort so the pick is
    # reproducible when several snapshots are cached.
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(
            f"No base Video2World checkpoint matches {pattern!r}; download "
            "nvidia/Cosmos-Predict2-2B-Video2World into the Hugging Face cache first."
        )
    return matches[0]


def _build_data_batch(video, text_emb, height, width, fps):
    """Build the input batch for ``DiffusionModel.generate_samples_from_batch``.

    Args:
        video: uint8 tensor of shape (1, 3, T, height, width) whose first frame
            holds the conditioning image.
        text_emb: T5 embedding of the task description.
        height: Frame height in pixels (used for the padding mask).
        width: Frame width in pixels (used for the padding mask).
        fps: Frame rate passed to the model as conditioning.

    Returns:
        dict of CUDA tensors and scalars. Every policy latent index is set to
        -1 (disabled) so that the keys the policy wrapper expects are present.
    """
    # A single tensor is shared by all disabled index keys.
    disabled_idx = torch.tensor([-1], dtype=torch.int64).cuda()
    batch = {
        "dataset_name": "video_data",
        "video": video.cuda(),
        "t5_text_embeddings": text_emb.to(dtype=torch.bfloat16).cuda(),
        "fps": torch.tensor([float(fps)], dtype=torch.bfloat16).cuda(),
        "padding_mask": torch.zeros(1, 1, height, width, dtype=torch.bfloat16).cuda(),
        "num_conditional_frames": 1,  # Condition on first latent frame only
        "proprio": None,
    }
    for key in (
        "current_proprio_latent_idx",
        "current_wrist_image_latent_idx",
        "current_wrist_image2_latent_idx",
        "current_image_latent_idx",
        "current_image2_latent_idx",
        "action_latent_idx",
        "future_proprio_latent_idx",
        "future_wrist_image_latent_idx",
        "future_wrist_image2_latent_idx",
        "future_image_latent_idx",
        "future_image2_latent_idx",
        "value_latent_idx",
    ):
        batch[key] = disabled_idx
    return batch


def _save_outputs(video_np, fps):
    """Save the video as an MP4 plus first/middle/last PNG frames in OUTPUT_DIR.

    Args:
        video_np: uint8 array indexed as (frame, height, width, channel).
        fps: Frame rate written into the MP4.
    """
    vid_path = str(OUTPUT_DIR / "cosmos_vid2world_sample_rollout.mp4")
    imageio.mimwrite(vid_path, list(video_np), fps=fps, quality=8)
    num_frames = video_np.shape[0]
    print(f"\n  Saved video: {vid_path} ({num_frames} frames @ {fps}fps)")

    for label, idx in [("first", 0), ("middle", num_frames // 2), ("last", num_frames - 1)]:
        frame_path = str(OUTPUT_DIR / f"cosmos_vid2world_sample_{label}.png")
        Image.fromarray(video_np[idx]).save(frame_path)
        print(f"  Saved {label} frame: {frame_path}")


def main():
    """Generate a video from a single LIBERO frame with the base Video2World model.

    Steps:
      1. Load the base Video2World checkpoint under the LIBERO policy config.
      2. Condition on the sample observation's primary image and the task's
         cached T5 embedding.
      3. Sample with the base ``DiffusionModel`` sampler, skipping the
         policy-specific logic.
      4. Write an MP4 and first/middle/last PNG frames to ``OUTPUT_DIR``.
    """
    fps = 16  # Conditioning fps fed to the model and frame rate of the saved MP4.

    print("=" * 60)
    print("Cosmos Video2World - Native Video Generation from LIBERO frame")
    print("=" * 60)

    # Load the model with the policy config (state_t=9 -> 33 pixel frames),
    # but use the BASE Video2World checkpoint for video generation.
    config_name = "cosmos_predict2_2b_480p_libero__inference_only"
    base_ckpt = _find_base_checkpoint()

    print("\nLoading base Video2World model...")
    print(f"  Config: {config_name}")
    print(f"  Checkpoint: {base_ckpt}")

    model, _config = load_model_from_checkpoint(
        experiment_name=config_name,
        s3_checkpoint_dir=base_ckpt,
        config_file="cosmos_policy/config/config.py",
        load_ema_to_reg=False,
        instantiate_ema=False,
    )
    model.eval()
    model = model.to("cuda")

    state_t = model.config.state_t  # 9 for the policy config
    # Assumes 4x temporal compression after the first frame.
    # NOTE(review): confirm against the tokenizer config.
    num_pixel_frames = (state_t - 1) * 4 + 1  # 33 when state_t == 9
    print(f"  state_t={state_t}, num_pixel_frames={num_pixel_frames}")
    print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

    # Load the sample observation. pickle can execute arbitrary code, so only
    # point this at trusted files.
    with open("cosmos_policy/experiments/robot/libero/sample_libero_10_observation.pkl", "rb") as f:
        obs = pickle.load(f)
    task_desc = "put both the alphabet soup and the tomato sauce in the basket"

    # Look up the task's T5 embedding in the precomputed cache.
    t5_path = "nvidia/Cosmos-Policy-LIBERO-Predict2-2B/libero_t5_embeddings.pkl"
    init_t5_text_embeddings_cache(t5_path)
    from cosmos_policy.experiments.robot.cosmos_utils import get_t5_embedding_from_cache
    text_emb = get_t5_embedding_from_cache(task_desc)

    # The policy model was trained with resize_online at 224px, but the base
    # Video2World model was trained at higher resolution, so resize to
    # 480x480 (closer to 480p).
    img = obs["primary_image"]  # presumably (256, 256, 3) uint8
    gen_size = 480
    img_resized = np.array(Image.fromarray(img).resize((gen_size, gen_size), Image.BILINEAR))

    # Video tensor laid out as (B, C, T, H, W): the first frame is the image,
    # the remaining frames are zeros.
    height = width = gen_size
    video = torch.zeros(1, 3, num_pixel_frames, height, width, dtype=torch.uint8)
    video[0, :, 0, :, :] = torch.from_numpy(img_resized).permute(2, 0, 1)  # HWC -> CHW

    data_batch = _build_data_batch(video, text_emb, height, width, fps)

    print(f"\n  Input video shape: {video.shape}")
    print("  Conditioning on 1 latent frame (first image)")
    print(
        f"  Generating {state_t - 1} latent frames ({state_t} incl. conditioning)"
        f" = {num_pixel_frames} pixel frames..."
    )

    # Call the BASE class sampler on this model instance so that the
    # policy-specific logic is skipped.
    from cosmos_policy._src.predict2.models.text2world_model import DiffusionModel
    with torch.inference_mode():
        sample = DiffusionModel.generate_samples_from_batch(
            model,
            data_batch,
            n_sample=1,
            num_steps=35,  # Base model default
            seed=42,
            is_negative_prompt=False,
        )
        # If the sampler returns a tuple, keep only its first element.
        if isinstance(sample, tuple):
            sample = sample[0]
        print(f"  Generated latent shape: {sample.shape}")

        video_out = model.decode(sample)
        print(f"  Decoded video shape: {video_out.shape}")

    # Treat the decoded output as [-1, 1] and map it to uint8 [0, 255] frames
    # laid out as (T, H, W, 3).
    video_np = video_out[0].permute(1, 2, 3, 0).cpu().float().numpy()
    video_np = np.clip((video_np + 1) * 127.5, 0, 255).astype(np.uint8)
    print(f"  Video numpy shape: {video_np.shape}")

    _save_outputs(video_np, fps)

    print("\n" + "=" * 60)
    print("Done!")
    print("=" * 60)


if __name__ == "__main__":
    # Run the generation pipeline only when executed as a script.
    main()
