"""
Test Cosmos Policy on our LIBERO dataset.
Loads the pretrained model, feeds start frames from the bundled sample observation and
from /data/libero, autoregressively generates predicted future images, and saves rollout
videos plus first/middle/last frames for visual inspection.
"""

import os
import sys
import types
import pickle
import numpy as np

# Select the GPU before anything initialises CUDA. Default to GPU 9, but respect a
# CUDA_VISIBLE_DEVICES value the caller already exported instead of clobbering it
# (on machines with fewer than 10 GPUs a forced "9" hides every device).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "9")

# Mock transformer_engine_torch (only needed for training/FusedAdam, not inference)
te_torch_mock = types.ModuleType("transformer_engine_torch")
sys.modules["transformer_engine_torch"] = te_torch_mock

# Mock transformer_engine.pytorch and submodules (the compiled .so is missing)
# We provide a real implementation of apply_rotary_pos_emb
def _apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=False, cu_seqlens=None):
    """Pure-PyTorch replacement for TE's ``apply_rotary_pos_emb``.

    Rotates the first ``freqs.shape[-1]`` channels of ``t`` using the
    half-split ("rotate_half") convention::

        out = t * cos(freqs) + rotate_half(t) * sin(freqs)
        rotate_half([x1, x2]) = [-x2, x1]

    Any trailing channels beyond the rotary width are passed through unchanged.

    Args:
        t: Input tensor, laid out (B, S, H, D) for ``"bshd"`` or (S, B, H, D)
            for ``"sbhd"``.
        freqs: Raw rotation angles, expected as (S, 1, 1, rot_dim).
        tensor_format: ``"sbhd"`` (default) or ``"bshd"``; for ``"bshd"`` a
            (S, 1, 1, D) ``freqs`` is moved to (1, S, 1, D) so it broadcasts
            against the sequence axis.
        fused: Accepted for signature compatibility; ignored.
        cu_seqlens: Accepted for signature compatibility; ignored.

    Returns:
        Tensor with the same shape and dtype as ``t``.
    """
    import torch

    # Move the sequence axis of freqs to position 1 when t is batch-first.
    if tensor_format == "bshd" and freqs.ndim == 4 and freqs.shape[1] == 1:
        freqs = freqs.transpose(0, 1)

    out_dtype = t.dtype
    n_rot = freqs.shape[-1]
    rot, rest = t[..., :n_rot], t[..., n_rot:]

    # Trig is evaluated in freqs' dtype, then cast to match the input.
    cos_ = freqs.cos().to(out_dtype)
    sin_ = freqs.sin().to(out_dtype)

    first, second = rot.chunk(2, dim=-1)
    rotated = torch.cat((second.neg(), first), dim=-1)
    out = rot * cos_ + rotated * sin_

    return out if rest.shape[-1] == 0 else torch.cat((out, rest), dim=-1)

# Build the mock transformer_engine.pytorch package tree (megatron-core 0.4.0
# compatible). Any entry already present in sys.modules is reused, not replaced.
import torch

for _te_name in (
    "transformer_engine.pytorch",
    "transformer_engine.pytorch.attention",
    "transformer_engine.pytorch.attention.rope",
    "transformer_engine.pytorch.distributed",
    "transformer_engine.pytorch.float8_tensor",
):
    sys.modules.setdefault(_te_name, types.ModuleType(_te_name))

# Route both rotary-embedding import locations to _apply_rotary_pos_emb, and
# expose an empty placeholder class for DotProductAttention.
for _te_name in ("transformer_engine.pytorch.attention.rope", "transformer_engine.pytorch.attention"):
    sys.modules[_te_name].apply_rotary_pos_emb = _apply_rotary_pos_emb
sys.modules["transformer_engine.pytorch.attention"].DotProductAttention = type("DotProductAttention", (), {})

class _Float8Tensor:
    """Empty placeholder published as transformer_engine.pytorch.float8_tensor.Float8Tensor."""


setattr(sys.modules["transformer_engine.pytorch.float8_tensor"], "Float8Tensor", _Float8Tensor)

class _RMSNorm(torch.nn.Module):
    """Pure-PyTorch stand-in for ``transformer_engine.pytorch.RMSNorm``.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * weight``. The reduction
    and scaling run in at least float32 (float64 inputs stay float64), and
    the result is cast back to the input dtype. This has two effects:

    - bf16/fp16 activations are normalised without low-precision rounding in
      the mean-square.
    - The output dtype follows the input even though ``weight`` is created
      as float32.

    Args:
        hidden_size: Size of the last (normalised) dimension.
        eps: Added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Upcast half-precision inputs only; never downcast float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)
        inv_rms = torch.rsqrt(x_c.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_c * inv_rms * self.weight.to(compute_dtype)).to(x.dtype)

    def reset_parameters(self):
        """Reset the scale to all ones."""
        torch.nn.init.ones_(self.weight)
# Publish the pure-PyTorch RMSNorm in place of TE's compiled implementation.
setattr(sys.modules["transformer_engine.pytorch"], "RMSNorm", _RMSNorm)

# API-compatibility shim for megatron.core.parallel_state: when this build
# provides is_unitialized() (sic - that is the spelling megatron uses) but no
# is_initialized(), add is_initialized() as its negation.
import megatron.core.parallel_state as _mps

if hasattr(_mps, 'is_unitialized') and not hasattr(_mps, 'is_initialized'):
    def _mps_is_initialized():
        return not _mps.is_unitialized()

    _mps.is_initialized = _mps_is_initialized

# Mock flash_attn (needed for Qwen VL encoder import but we use T5 for LIBERO).
# Every stubbed callable accepts any arguments and returns None.
def _dummy(*a, **k):
    return None


flash_attn_mock = types.ModuleType("flash_attn")
flash_attn_mock.__version__ = "2.7.0"
for _fn_name in ("flash_attn_varlen_func", "flash_attn_func"):
    setattr(flash_attn_mock, _fn_name, _dummy)
sys.modules["flash_attn"] = flash_attn_mock

# Submodules that may be imported directly. Each gets the same no-op stubs,
# plus apply_rotary_emb explicitly set to None.
_FLASH_ATTN_STUB_NAMES = (
    "pad_input", "unpad_input", "flash_attn_varlen_func",
    "index_first_axis", "flash_attn_func", "rms_norm",
)
for _submod_name in (
    "flash_attn.bert_padding", "flash_attn.flash_attn_interface",
    "flash_attn.layers.rotary", "flash_attn.ops.rms_norm",
    "flash_attn.ops", "flash_attn.layers",
):
    _stub = types.ModuleType(_submod_name)
    for _attr in _FLASH_ATTN_STUB_NAMES:
        setattr(_stub, _attr, _dummy)
    _stub.apply_rotary_emb = None
    sys.modules[_submod_name] = _stub

# Patch importlib so availability checks (e.g. in transformers) believe
# flash_attn 2 is installed, even though only a stub module is present.
import importlib
import importlib.machinery
import importlib.util
import importlib.metadata

_orig_util_find = importlib.util.find_spec
def _patched_find_spec(name, *args, **kwargs):
    """Report a spec for ``flash_attn``; defer every other lookup to importlib.

    Returns a real ``ModuleSpec`` rather than a bare truthy value. Callers that
    only test ``is not None`` still work, and callers that read ``.name`` /
    ``.origin`` / ``.loader`` no longer crash. A stub created with
    ``types.ModuleType`` has ``__spec__ = None``, which makes the stock
    find_spec raise ValueError for a name already in sys.modules.
    """
    if name == "flash_attn":
        return importlib.machinery.ModuleSpec(name, None)
    return _orig_util_find(name, *args, **kwargs)
importlib.util.find_spec = _patched_find_spec

_orig_metadata_version = importlib.metadata.version
def _patched_metadata_version(name):
    """Report a fake flash_attn version; defer other distributions to importlib.metadata."""
    # Distribution names compare case-insensitively with '-' and '_'
    # equivalent, so accept both "flash_attn" and "flash-attn".
    if name.replace("-", "_").lower() == "flash_attn":
        return "2.7.0"
    return _orig_metadata_version(name)
importlib.metadata.version = _patched_metadata_version

import torch
import h5py
from PIL import Image
from pathlib import Path
from dataclasses import dataclass
from typing import Optional

# get_checkpoint_path is intentionally NOT patched: the real Hugging Face download
# handles checkpoint retrieval (requires an HF token with access to the gated repo).


@dataclass
class PolicyEvalConfig:
    """Minimal config matching the full PolicyEvalConfig but without libero dependency.

    An instance is passed as ``cfg`` to ``cosmos_utils.get_action``, which is
    where most of these flags are interpreted. This script reads only a few
    fields directly:

    - the checkpoint, statistics and embedding paths in ``main``;
    - ``num_denoising_steps_action``, ``trained_with_image_aug`` and
      ``use_jpeg_compression`` in the rollout code.

    The remaining fields are presumably kept so the object is accepted
    wherever the full config is expected. Confirm before removing any.

    Field order defines the positional constructor signature, so do not
    reorder fields.
    """
    # --- Model / checkpoint selection ---
    suite: str = "libero"
    model_family: str = "cosmos"
    config: str = ""  # experiment name handed to load_model_from_checkpoint in main()
    ckpt_path: str = ""  # used by main() when no cached .pt checkpoint is found
    planning_model_config_name: str = ""
    planning_model_ckpt_path: str = ""
    config_file: str = "cosmos_policy/config/config.py"
    # --- Observation inputs ---
    use_third_person_image: bool = True
    num_third_person_images: int = 1
    use_wrist_image: bool = True
    num_wrist_images: int = 1
    use_proprio: bool = True
    flip_images: bool = True
    use_variance_scale: bool = False
    use_jpeg_compression: bool = True  # run_video_rollout turns this off for model-generated frames
    # --- Autoregressive prediction toggles ---
    ar_future_prediction: bool = False
    ar_value_prediction: bool = False
    ar_qvalue_prediction: bool = False
    # --- Denoising step counts per prediction ---
    num_denoising_steps_action: int = 5
    num_denoising_steps_future_state: int = 1
    num_denoising_steps_value: int = 1
    # --- Normalization, statistics and text embeddings ---
    unnormalize_actions: bool = True
    normalize_proprio: bool = True
    dataset_stats_path: str = ""  # loaded via load_dataset_stats() in main()
    t5_text_embeddings_path: str = ""  # loaded via init_t5_text_embeddings_cache() in main()
    trained_with_image_aug: bool = True  # run_video_rollout turns this off to skip the center crop
    # --- Action chunking ---
    chunk_size: int = 16
    num_open_loop_steps: int = 16
    # --- Determinism ---
    deterministic: bool = True
    deterministic_reset: bool = False
    deterministic_reset_seed: Optional[int] = None
    # --- Ensembling and search ---
    use_ensemble_future_state_predictions: bool = False
    num_future_state_predictions_in_ensemble: int = 3
    future_state_ensemble_aggregation_scheme: str = "average"
    use_ensemble_value_predictions: bool = False
    num_value_predictions_in_ensemble: int = 5
    value_ensemble_aggregation_scheme: str = "average"
    search_depth: int = 1
    mask_current_state_action_for_value_prediction: bool = False
    mask_future_state_for_qvalue_prediction: bool = False
    num_queries_best_of_n: int = 1
    # --- Parallel inference ---
    use_parallel_inference: bool = False
    available_gpus: str = "0"
    parallel_timeout: int = 15
    # --- LIBERO evaluation settings (not read directly by this script) ---
    task_suite_name: str = "libero_spatial"
    num_trials_per_task: int = 50
    initial_states_path: str = "DEFAULT"
    env_img_res: int = 256
    # --- Logging ---
    local_log_dir: str = "./experiments/logs"
    run_id_note: Optional[str] = None
    use_wandb: bool = False
    wandb_entity: str = ""
    wandb_project: str = ""
    # --- Miscellaneous ---
    seed: int = 7
    randomize_seed: bool = False
    data_collection: bool = False


from cosmos_policy.experiments.robot.cosmos_utils import (
    get_action,
    get_model,
    load_dataset_stats,
    init_t5_text_embeddings_cache,
)

# Directory that receives every saved image and video. It is created at import
# time, including missing parent directories, so later saves into it do not
# fail on a fresh machine.
OUTPUT_DIR = Path("/data/cameron/vidgen/cosmos-policy/rollout_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def _prepare_libero_image(img, size=256):
    """Flip a raw LIBERO camera frame right-side up and resize it to ``size`` x ``size``."""
    # LIBERO simulator renders images upside-down - flip to right-side up.
    # NOTE(review): some LIBERO pipelines rotate 180 degrees (img[::-1, ::-1])
    # instead of only flipping vertically; confirm this matches the convention
    # the checkpoint was trained with.
    upright = np.flipud(img).copy()
    return np.array(Image.fromarray(upright).resize((size, size), Image.BILINEAR))


def load_observation_from_hdf5(hdf5_path, demo_idx=0, frame_idx=0, image_size=256):
    """Load one observation from a raw LIBERO HDF5 demo file.

    Args:
        hdf5_path: Path to a LIBERO ``*.hdf5`` demo file.
        demo_idx: Demo group to read (``data/demo_{demo_idx}``).
        frame_idx: Timestep within the demo.
        image_size: Side length the square camera images are resized to.
            Defaults to 256, the resolution this script feeds Cosmos.

    Returns:
        A dict with:

        - ``primary_image``: agentview frame as a uint8
          (image_size, image_size, 3) array.
        - ``wrist_image``: eye-in-hand frame, same format.
        - ``proprio``: ``robot_states`` row cast to float64.
    """
    demo_key = f"data/demo_{demo_idx}"
    with h5py.File(hdf5_path, "r") as f:
        primary_raw = f[f"{demo_key}/obs/agentview_rgb"][frame_idx]
        wrist_raw = f[f"{demo_key}/obs/eye_in_hand_rgb"][frame_idx]
        # Proprio comes straight from robot_states, assumed (per the original
        # comment) to be gripper_states(2) + ee_pos(3) + ee_quat(4) = 9 values.
        robot_states = f[f"{demo_key}/robot_states"][frame_idx]

    # Dataset slices are in-memory numpy arrays, so processing after the file
    # is closed is safe.
    return {
        "primary_image": _prepare_libero_image(primary_raw, image_size),
        "wrist_image": _prepare_libero_image(wrist_raw, image_size),
        "proprio": robot_states.astype(np.float64),
    }


def run_sample_test(cfg, model, dataset_stats):
    """Run one inference on the bundled LIBERO-10 sample observation.

    Saves the input primary/wrist images and any predicted future images to
    ``OUTPUT_DIR`` under the ``sample_`` prefix.

    Args:
        cfg: Policy config forwarded to ``get_action``.
        model: Loaded Cosmos Policy model.
        dataset_stats: Dataset statistics forwarded to ``get_action``.

    Returns:
        The dict returned by ``get_action``.
    """
    print("\n=== Testing with sample observation ===")
    # Path is relative to the working directory. The pickle is assumed to be
    # the trusted file bundled with the repo; never point this at untrusted data.
    sample_path = "cosmos_policy/experiments/robot/libero/sample_libero_10_observation.pkl"
    with open(sample_path, "rb") as f:
        observation = pickle.load(f)

    task_description = "put both the alphabet soup and the tomato sauce in the basket"

    print(f"  Primary image: {observation['primary_image'].shape}")
    print(f"  Wrist image: {observation['wrist_image'].shape}")
    print(f"  Proprio: {observation['proprio'].shape}")

    print("  Running inference...")
    action_return_dict = get_action(
        cfg, model, dataset_stats, observation, task_description,
        num_denoising_steps_action=cfg.num_denoising_steps_action,
        generate_future_state_and_value_in_parallel=True,
    )

    print(f"  Generated action chunk: {len(action_return_dict['actions'])} steps")
    # The value head may be absent or None; report N/A instead of crashing.
    value = action_return_dict.get("value_prediction")
    print(f"  Value prediction: {value:.3f}" if value is not None else "  Value prediction: N/A")

    # Save images
    prefix = OUTPUT_DIR / "sample"
    Image.fromarray(observation["primary_image"]).save(f"{prefix}_input_primary.png")
    Image.fromarray(observation["wrist_image"]).save(f"{prefix}_input_wrist.png")

    future_preds = action_return_dict["future_image_predictions"]
    if "future_image" in future_preds:
        Image.fromarray(future_preds["future_image"]).save(f"{prefix}_future_primary.png")
        print("  Saved future primary image")
    if "future_wrist_image" in future_preds:
        Image.fromarray(future_preds["future_wrist_image"]).save(f"{prefix}_future_wrist.png")
        print("  Saved future wrist image")

    print(f"  Output saved to {prefix}_*.png")
    return action_return_dict


def run_libero_test(cfg, model, dataset_stats, num_tasks=3):
    """Run single-step inference on first/middle/last frames of LIBERO demos.

    Processes the first ``num_tasks`` HDF5 files in
    /data/libero/libero_spatial, sorted by filename (``None`` processes all).
    For each file:

    - reads ``demo_0``;
    - runs ``get_action`` on frames ``0``, ``N // 2`` and ``N - 1``;
    - saves the input and predicted future images to ``OUTPUT_DIR``.

    Args:
        cfg: Policy config forwarded to ``get_action``.
        model: Loaded Cosmos Policy model.
        dataset_stats: Dataset statistics forwarded to ``get_action``.
        num_tasks: Number of HDF5 files to process (default 3).
    """
    print("\n=== Testing with our LIBERO dataset ===")

    libero_dir = Path("/data/libero/libero_spatial")
    hdf5_files = sorted(libero_dir.glob("*.hdf5"))
    if not hdf5_files:
        print("  No HDF5 files found in /data/libero/libero_spatial/")
        return

    # Task descriptions come from the parsed dataset when available.
    parsed_dir = Path("/data/libero/parsed_libero/libero_spatial")

    for i, hdf5_file in enumerate(hdf5_files[:num_tasks]):
        # NOTE(review): assumes task_{i} in parsed_dir corresponds to the i-th
        # HDF5 file in sorted order - confirm.
        desc_file = parsed_dir / f"task_{i}_description.txt"
        if desc_file.exists():
            task_name = desc_file.read_text().strip()
        else:
            task_name = hdf5_file.stem.replace("_demo", "").replace("_", " ")

        print(f"\n  Task {i}: {task_name}")
        print(f"  File: {hdf5_file.name}")

        with h5py.File(str(hdf5_file), "r") as f:
            num_frames = f["data/demo_0/obs/agentview_rgb"].shape[0]
        print(f"  Total frames in demo_0: {num_frames}")
        if num_frames == 0:
            # An empty demo has no valid frame index (N - 1 would be -1).
            print("  demo_0 has no frames; skipping")
            continue

        safe_name = hdf5_file.stem[:50]
        frame_plan = (("first", 0), ("middle", num_frames // 2), ("last", num_frames - 1))

        for frame_label, frame_idx in frame_plan:
            print(f"  Processing {frame_label} frame (idx={frame_idx})...")

            obs = load_observation_from_hdf5(str(hdf5_file), demo_idx=0, frame_idx=frame_idx)

            prefix = OUTPUT_DIR / f"libero_{safe_name}_f{frame_idx}"
            Image.fromarray(obs["primary_image"]).save(f"{prefix}_input_primary.png")
            Image.fromarray(obs["wrist_image"]).save(f"{prefix}_input_wrist.png")

            action_return_dict = get_action(
                cfg, model, dataset_stats, obs, task_name,
                num_denoising_steps_action=cfg.num_denoising_steps_action,
                generate_future_state_and_value_in_parallel=True,
            )

            future_preds = action_return_dict["future_image_predictions"]
            if "future_image" in future_preds:
                Image.fromarray(future_preds["future_image"]).save(f"{prefix}_future_primary.png")
            if "future_wrist_image" in future_preds:
                Image.fromarray(future_preds["future_wrist_image"]).save(f"{prefix}_future_wrist.png")

            val = action_return_dict.get("value_prediction", None)
            print(f"    Value: {val:.3f}" if val is not None else "    Value: N/A")
            print(f"    Actions: {len(action_return_dict['actions'])} steps")

    print(f"\nAll outputs saved to {OUTPUT_DIR}")


def _center_crop_like_model(img, out_size=224, area_fraction=0.9):
    """Resize ``img`` to ``out_size`` and apply a centered area crop.

    Steps:

    1. Resize to ``out_size`` x ``out_size``.
    2. Center-crop to a square covering ``area_fraction`` of that area
       (side ``int(out_size * sqrt(area_fraction))``, i.e. 212 for the
       defaults).
    3. Resize back to ``out_size``.

    This is intended to match the 90%-area center crop that
    ``prepare_images_for_model`` applies to raw env images, so a ground-truth
    frame lines up with the model's view (not verified here).
    """
    import torchvision.transforms.functional as TF

    resized = np.array(Image.fromarray(img).resize((out_size, out_size), Image.BILINEAR))
    crop_size = int(out_size * area_fraction ** 0.5)
    chw = torch.from_numpy(resized).permute(2, 0, 1)  # (3, H, W)
    chw = TF.center_crop(chw, crop_size)
    chw = TF.resize(chw, [out_size, out_size], antialias=True)
    return chw.permute(1, 2, 0).numpy().astype(np.uint8)


def run_video_rollout(cfg, model, dataset_stats, num_rollout_steps=20, num_tasks=2):
    """Generate autoregressive video rollouts by feeding predicted futures back in.

    The model predicts one future frame per query, so each step's predicted
    primary/wrist images become the next step's observation.

    - Step 1 uses ``cfg`` unchanged, since its input is a raw environment
      image.
    - Later steps use a copy of ``cfg`` with ``trained_with_image_aug`` and
      ``use_jpeg_compression`` disabled. Re-applying the train-time center
      crop to generated frames every step would cause a progressive zoom,
      and generated frames should not be re-compressed.

    Rollouts run for the bundled sample observation and for frame 0 of
    ``demo_0`` in the first ``num_tasks`` LIBERO-spatial HDF5 files. Each
    rollout is saved to ``OUTPUT_DIR`` as an MP4 plus first/middle/last PNG
    frames.

    Args:
        cfg: Policy config (a ``PolicyEvalConfig``).
        model: Loaded Cosmos Policy model.
        dataset_stats: Dataset statistics forwarded to ``get_action``.
        num_rollout_steps: Model queries per rollout. The video holds this
            many predicted frames plus the cropped start frame.
        num_tasks: Number of LIBERO HDF5 files to roll out (default 2).
    """
    import dataclasses

    import imageio

    cfg_nocrop = dataclasses.replace(cfg, trained_with_image_aug=False, use_jpeg_compression=False)

    print(f"\n=== Generating video rollouts ({num_rollout_steps} steps each) ===")

    def _query(step_cfg, obs, task_desc, step):
        """Run one get_action call; return (future primary image, future wrist image)."""
        result = get_action(
            step_cfg, model, dataset_stats, obs, task_desc,
            num_denoising_steps_action=cfg.num_denoising_steps_action,
            generate_future_state_and_value_in_parallel=True,
        )
        future = result["future_image_predictions"]
        future_img = future["future_image"]
        future_wrist = future["future_wrist_image"]
        # Report N/A rather than a fake 0 (or a crash) when no value is returned.
        val = result.get("value_prediction")
        val_str = f"{val:.3f}" if val is not None else "N/A"
        print(f"    Step {step}/{num_rollout_steps}  value={val_str}")
        return future_img, future_wrist

    def _to_model_input(img):
        """Bring a generated frame back to the 256x256 input resolution used for env images."""
        if img.shape[:2] == (256, 256):
            return img
        return np.array(Image.fromarray(img).resize((256, 256), Image.BILINEAR))

    def _rollout(initial_obs, task_desc, name_prefix):
        """Run one AR rollout from ``initial_obs`` and save its video and key frames."""
        future_img, future_wrist = _query(cfg, initial_obs, task_desc, 1)

        # Frames are collected at the model's output resolution (presumably
        # 224x224). The GT start frame gets the same crop the model applies to
        # raw env images so it matches the model's view.
        frames = [_center_crop_like_model(initial_obs["primary_image"]), future_img.copy()]

        for step in range(2, num_rollout_steps + 1):
            # NOTE(review): proprio is held at the initial frame's value for
            # every step because no predicted proprio is fed back.
            current_obs = {
                "primary_image": _to_model_input(future_img),
                "wrist_image": _to_model_input(future_wrist),
                "proprio": initial_obs["proprio"],
            }
            future_img, future_wrist = _query(cfg_nocrop, current_obs, task_desc, step)
            frames.append(future_img.copy())

        vid_path = str(OUTPUT_DIR / f"{name_prefix}_rollout.mp4")
        imageio.mimwrite(vid_path, frames, fps=4, quality=8)
        print(f"  Saved video: {vid_path} ({len(frames)} frames)")

        for label, idx in (("first", 0), ("middle", len(frames) // 2), ("last", len(frames) - 1)):
            Image.fromarray(frames[idx]).save(str(OUTPUT_DIR / f"{name_prefix}_rollout_{label}.png"))

    # --- 1) Sample observation rollout ---
    print("\n  Video rollout: sample observation")
    sample_path = "cosmos_policy/experiments/robot/libero/sample_libero_10_observation.pkl"
    with open(sample_path, "rb") as f:
        obs = pickle.load(f)
    _rollout(obs, "put both the alphabet soup and the tomato sauce in the basket", "sample")

    # --- 2) LIBERO dataset rollouts ---
    libero_dir = Path("/data/libero/libero_spatial")
    hdf5_files = sorted(libero_dir.glob("*.hdf5"))
    parsed_dir = Path("/data/libero/parsed_libero/libero_spatial")

    for i, hdf5_file in enumerate(hdf5_files[:num_tasks]):
        desc_file = parsed_dir / f"task_{i}_description.txt"
        if desc_file.exists():
            task_name = desc_file.read_text().strip()
        else:
            task_name = hdf5_file.stem.replace("_demo", "").replace("_", " ")
        safe_name = hdf5_file.stem[:40]

        print(f"\n  Video rollout: Task {i} - {task_name}")
        initial_obs = load_observation_from_hdf5(str(hdf5_file), demo_idx=0, frame_idx=0)
        _rollout(initial_obs, task_name, f"libero_{safe_name}")

    print(f"\nAll video rollouts saved to {OUTPUT_DIR}")


def _hf_hub_cache_dir():
    """Return the Hugging Face hub cache directory, honouring the standard env vars.

    Resolution order:

    1. ``HF_HUB_CACHE``.
    2. ``$HF_HOME/hub``.
    3. ``$XDG_CACHE_HOME/huggingface/hub``.
    4. ``~/.cache/huggingface/hub``.
    """
    hub_cache = os.environ.get("HF_HUB_CACHE")
    if hub_cache:
        return os.path.expanduser(hub_cache)
    hf_home = os.environ.get("HF_HOME")
    if not hf_home:
        xdg_cache = os.environ.get("XDG_CACHE_HOME") or os.path.join(os.path.expanduser("~"), ".cache")
        hf_home = os.path.join(xdg_cache, "huggingface")
    return os.path.join(os.path.expanduser(hf_home), "hub")


def _find_cached_checkpoint(repo_id):
    """Return the newest cached ``<repo-name>.pt`` for HF repo ``repo_id``, or None."""
    import glob

    repo_dir = "models--" + repo_id.replace("/", "--")
    filename = repo_id.rsplit("/", 1)[-1] + ".pt"
    pattern = os.path.join(_hf_hub_cache_dir(), repo_dir, "snapshots", "*", filename)
    candidates = glob.glob(pattern)
    # Several snapshots may be cached and glob order is arbitrary, so pick the
    # most recently modified one deterministically.
    return max(candidates, key=os.path.getmtime) if candidates else None


def main():
    """Load the LIBERO Cosmos Policy checkpoint and generate autoregressive video rollouts.

    Prefers a ``.pt`` checkpoint already present in the local Hugging Face hub
    cache (newest snapshot wins). Otherwise it passes ``cfg.ckpt_path`` to the
    loader unchanged.
    """
    print("=" * 60)
    print("Cosmos Policy - LIBERO Video Rollout Test")
    print("=" * 60)

    cfg = PolicyEvalConfig(
        config="cosmos_predict2_2b_480p_libero__inference_only",
        ckpt_path="nvidia/Cosmos-Policy-LIBERO-Predict2-2B",
        config_file="cosmos_policy/config/config.py",
        dataset_stats_path="nvidia/Cosmos-Policy-LIBERO-Predict2-2B/libero_dataset_statistics.json",
        t5_text_embeddings_path="nvidia/Cosmos-Policy-LIBERO-Predict2-2B/libero_t5_embeddings.pkl",
        use_wrist_image=True,
        use_proprio=True,
        normalize_proprio=True,
        unnormalize_actions=True,
        chunk_size=16,
        num_open_loop_steps=16,
        trained_with_image_aug=True,
        use_jpeg_compression=True,
        flip_images=True,
        num_denoising_steps_action=5,
        num_denoising_steps_future_state=1,
        num_denoising_steps_value=1,
    )

    print("\nLoading dataset statistics...")
    dataset_stats = load_dataset_stats(cfg.dataset_stats_path)

    print("Loading T5 text embeddings cache...")
    init_t5_text_embeddings_cache(cfg.t5_text_embeddings_path)

    print("Loading model...")
    ckpt_path = _find_cached_checkpoint(cfg.ckpt_path) or cfg.ckpt_path
    print(f"  Checkpoint: {ckpt_path}")

    from cosmos_policy._src.predict2.utils.model_loader import load_model_from_checkpoint
    model, _cosmos_config = load_model_from_checkpoint(
        experiment_name=cfg.config,
        s3_checkpoint_dir=ckpt_path,
        config_file=cfg.config_file,
        load_ema_to_reg=False,
        instantiate_ema=False,  # Don't create EMA copy - saves memory
    )
    model.eval()
    model = model.to("cuda")
    print("Model loaded successfully!")
    print(f"  GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    print(f"  GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

    # Generate autoregressive video rollouts
    run_video_rollout(cfg, model, dataset_stats, num_rollout_steps=20)

    print("\n" + "=" * 60)
    print("Done! Check outputs in:", OUTPUT_DIR)
    print("=" * 60)


if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())
