"""
Minimal image-to-video inference for UWM (Unified World Model).
Loads a single start frame, runs the pretrained model to sample the next observation
(video) and action, decodes the video, and saves the result.

Usage (from repo root with conda env uwm activated):
  export PYTHONPATH=/path/to/unified-world-model
  python experiments/uwm/run_image_to_video.py \
    --checkpoint path/to/models.pt \
    --image path/to/start_frame.png \
    --output out_video.mp4

If no --image is given, uses a random dummy image for testing.
Download DROID/LIBERO checkpoints from:
  https://drive.google.com/drive/folders/1M4AuVLMRpSwOf_YAp56bV9AqyZI9ul6g
"""

import argparse
import os
import sys
from pathlib import Path

import hydra
import numpy as np
import torch
import torchvision
from omegaconf import OmegaConf

# Repo root on path for configs and models (this file lives two levels below it)
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Hydra config path
# NOTE(review): chdir at import time is a global side effect — relative paths
# passed on the CLI are resolved against the repo root, not the caller's CWD.
os.chdir(REPO_ROOT)


def build_obs_dict_from_image(image_path: str, device: torch.device, num_frames: int = 2,
                              target_size=(180, 320)):
    """
    Build curr_obs_dict from a single image file.

    Uses the same frame for all 3 DROID camera views and repeats it across the
    observation history (``num_frames``).

    Args:
        image_path: Path to an RGB image. If None or nonexistent, a random
            dummy image is generated instead (useful for smoke tests).
        device: Device the observation tensors are moved to.
        num_frames: Observation history length; the single frame is repeated.
        target_size: (height, width) to resize to. Defaults to 180x320, the
            DROID shape_meta resolution.

    Returns:
        dict mapping each DROID view key to a (1, T, H, W, C) uint8 tensor.
    """
    target_h, target_w = target_size
    rgb_keys = ["exterior_image_1_left", "exterior_image_2_left", "wrist_image_left"]

    if image_path and Path(image_path).exists():
        # Lazy import: the dummy-image path below works without Pillow installed.
        from PIL import Image

        img = np.array(Image.open(image_path).convert("RGB"))
        # Resize to the target resolution if needed (PIL expects (width, height)).
        if img.shape[0] != target_h or img.shape[1] != target_w:
            img = np.array(
                Image.fromarray(img).resize((target_w, target_h), Image.BILINEAR)
            )
    else:
        # Dummy image for testing. np.random.randint's upper bound is
        # exclusive, so 256 is required to allow the full uint8 range.
        img = np.random.randint(0, 256, (target_h, target_w, 3), dtype=np.uint8)

    # (H, W, C) -> (1, T, H, W, C): add batch/time dims, then repeat over time.
    frame = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
    frame = frame.expand(1, num_frames, target_h, target_w, 3).contiguous()

    return {key: frame.to(device) for key in rgb_keys}


def run_inference(model, obs_dict, device):
    """Sample a (next_obs, action) pair from the model and decode the video.

    Runs ``model.sample_joint`` under ``torch.no_grad`` to draw a latent
    next-observation and an action, then inverts the VAE on the latent to
    recover a pixel-space video tensor.

    Returns:
        (decoded_video, action): the VAE-decoded next observation and the
        sampled action tensor.
    """
    model.eval()
    with torch.no_grad():
        latent, sampled_action = model.sample_joint(obs_dict)
    # Decode the latent video back to pixel space (inverse VAE pass).
    with torch.no_grad():
        video = model.obs_encoder.apply_vae(latent, inverse=True)
    return video, sampled_action


def save_video(tensor_video: torch.Tensor, path: str, fps: int = 4):
    """
    Save a sampled video tensor to disk, best effort.

    Args:
        tensor_video: (B, V, C, T, H, W) float tensor with values in [0, 1].
        path: Output video path (ideally with a .mp4 suffix).
        fps: Frames per second of the written video.

    Only the first batch element and first view are written. Tries torchvision
    first, then imageio/FFMPEG, and finally falls back to dumping individual
    PNG frames into a sibling directory.
    """
    # Use first batch, first view: (C, T, H, W)
    view = tensor_video[0, 0]
    # (C, T, H, W) -> (T, H, W, C), then quantize to uint8. Rounding (rather
    # than astype truncation) maps e.g. 0.999 to 255 instead of 254.
    frames_t = view.permute(1, 2, 3, 0).detach().cpu().clamp(0, 1)
    frames = (frames_t * 255).round().to(torch.uint8).numpy()
    path_obj = Path(path)
    try:
        torchvision.io.write_video(path, torch.from_numpy(frames), fps=fps)
        return
    except Exception:
        pass  # best effort: fall through to the next writer
    try:
        import imageio
        # imageio's FFMPEG writer needs a video suffix; force .mp4.
        p = path_obj if path_obj.suffix.lower() == ".mp4" else path_obj.with_suffix(".mp4")
        with imageio.get_writer(str(p), format="FFMPEG", fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)
    except Exception:
        # Last resort: dump individual PNG frames so the sample isn't lost.
        out_dir = path_obj.parent / (path_obj.stem + "_frames")
        out_dir.mkdir(parents=True, exist_ok=True)
        from PIL import Image
        for i, frame in enumerate(frames):
            Image.fromarray(frame).save(out_dir / f"frame_{i:03d}.png")
        print(f"Video write failed; saved {len(frames)} frames to {out_dir}/")


def main():
    """CLI entry point: load config + checkpoint, sample one video, save it.

    Returns 0 on success (propagated as the process exit code by the guard
    at the bottom of the file).
    """
    parser = argparse.ArgumentParser(description="UWM image-to-video inference")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to models.pt")
    parser.add_argument("--image", type=str, default=None, help="Start frame image (optional)")
    parser.add_argument("--output", type=str, default="uwm_out.mp4", help="Output video path")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    # Fall back to CPU whenever CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Load config via Hydra (use inference config to avoid hydra runtime interpolations)
    # NOTE: relies on the module-level os.chdir(REPO_ROOT) and the absolute
    # config_dir; compose happens inside the initialize context.
    with hydra.initialize_config_dir(config_dir=str(REPO_ROOT / "configs"), version_base=None):
        config = hydra.compose(config_name="inference_uwm.yaml")
    OmegaConf.resolve(config)
    from hydra.utils import instantiate
    model = instantiate(config.model)
    model = model.to(device)

    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # NOTE(review): torch.load without weights_only unpickles arbitrary objects;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model"], strict=True)
    print(f"Loaded checkpoint from {ckpt_path}, step {ckpt.get('step', '?')}")

    # Match the observation history length the encoder expects; fall back to 2
    # frames if the attribute is absent.
    num_frames = getattr(model.obs_encoder, "num_frames", 2)
    obs_dict = build_obs_dict_from_image(args.image, device, num_frames=num_frames)
    if args.image:
        print(f"Using start frame: {args.image}")
    else:
        print("Using random dummy image")

    # Sample (next_obs, action), decode to pixels, and write the video.
    decoded, action = run_inference(model, obs_dict, device)
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    save_video(decoded, str(out_path))
    print(f"Saved video to {out_path}")
    print(f"Sampled action shape: {action.shape}")

    return 0


if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())
