"""Explore SVD video generation speed: denoising steps × quantization.

Takes GT trajectories, picks start and middle conditioning frames,
generates videos under different configs, saves with timing in filename.

Configs (2 frames × 5 settings = 10 videos), all fp16, varying denoising steps:
- baseline_fp16_25steps: 25 steps (baseline)
- fp16_13steps: 13 steps (~half)
- fp16_6steps: 6 steps (~quarter)
- fp16_3steps: 3 steps
- fp16_1step: 1 step
"""

import os
# Pin this process to GPU 4. This must happen before torch is imported
# below — CUDA device enumeration reads the variable at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import sys
# Make the project root and this script's directory importable (svd.* package).
sys.path.insert(0, "/data/cameron/para_videopolicy")
sys.path.insert(0, os.path.dirname(__file__))

import time
import numpy as np
import torch
import cv2
import imageio
import subprocess
from pathlib import Path
from PIL import Image

from svd.models import UNetSpatioTemporalConditionModel
from svd.pipelines import StableVideoDiffusionPipeline
from diffusers import AutoencoderKLTemporalDecoder
from transformers import CLIPVisionModelWithProjection

# Frame size (height, width) passed to the SVD pipeline.
SVD_H, SVD_W = 320, 576
# Frames per generated video; also the GT comparison window length.
N_WINDOW = 7
DEVICE = "cuda"
# Fine-tuned UNet checkpoint and the SVD base model directory.
CKPT_DIR = "output_svd_v3_stage2_joint/checkpoint-3000"
SVD_BASE = "checkpoints/stable-video-diffusion-img2vid-xt-1-1"
# Output directory for the generated videos and results.json.
OUT_DIR = "explore_svd_speed"


def load_gt_trajectory(cache_root="/data/libero/ood_objpos_v3", task_id=0, demo_id=0):
    """Load all cached GT frames of one demo as a list of numpy arrays.

    Frames are read from
    <cache_root>/libero_spatial/task_<task_id>/demo_<demo_id>/frames,
    sorted by filename so temporal order is preserved.
    """
    frames_dir = (Path(cache_root) / "libero_spatial"
                  / f"task_{task_id}" / f"demo_{demo_id}" / "frames")
    loaded = []
    for png_path in sorted(frames_dir.glob("*.png")):
        loaded.append(np.array(Image.open(png_path)))
    return loaded


def save_video(frames_pil, path, fps=7):
    """Save a list of PIL images as an H.264 mp4.

    Writes an intermediate mp4 via imageio, then re-encodes with ffmpeg
    (libx264, +faststart) for broad playback compatibility.

    Args:
        frames_pil: list of PIL.Image frames.
        path: destination .mp4 path (str or Path).
        fps: playback frame rate.
    """
    path = Path(path)
    # Build the temp name from the stem. The old str.replace('.mp4', ...)
    # was a no-op for paths without ".mp4", making tmp == path, so the
    # final os.remove() would delete the output itself.
    tmp = str(path.with_name(path.stem + "_raw.mp4"))
    frames_np = [np.array(f) for f in frames_pil]
    try:
        imageio.mimwrite(tmp, frames_np, fps=fps, quality=8)
        result = subprocess.run(["ffmpeg", "-y", "-i", tmp, "-c:v", "libx264",
                                 "-preset", "ultrafast", "-crf", "23",
                                 "-movflags", "+faststart", str(path)],
                                capture_output=True)
        if result.returncode != 0:
            # Best-effort encode: warn instead of silently producing nothing.
            print(f"  [save_video] ffmpeg failed for {path}: "
                  f"{result.stderr.decode(errors='replace')[-300:]}")
    finally:
        # Always clean up the intermediate file.
        if os.path.exists(tmp):
            os.remove(tmp)


def save_comparison_video(gt_frames, gen_frames_pil, path, label="", fps=7):
    """Save a side-by-side GT (left) vs generated (right) H.264 mp4.

    Both streams are resized to 448x448 and labeled; only the first
    min(len(gt), len(gen), N_WINDOW) frames are used.

    Args:
        gt_frames: list of numpy GT frames.
        gen_frames_pil: list of PIL generated frames.
        path: destination .mp4 path (str or Path).
        label: text burned into the generated half.
        fps: playback frame rate.
    """
    n = min(len(gt_frames), len(gen_frames_pil), N_WINDOW)
    combined = []
    for t in range(n):
        gt = cv2.resize(gt_frames[t], (448, 448))
        gen = cv2.resize(np.array(gen_frames_pil[t]), (448, 448))
        # Burn labels into each half so the panels are identifiable.
        cv2.putText(gt, "GT", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
        cv2.putText(gen, f"Gen {label}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        combined.append(np.concatenate([gt, gen], axis=1))

    path = Path(path)
    # Derive the temp name from the stem; the previous str.replace('.mp4', ...)
    # silently produced tmp == path for non-.mp4 paths, and the final
    # os.remove() then deleted the output video.
    tmp = str(path.with_name(path.stem + "_raw.mp4"))
    try:
        imageio.mimwrite(tmp, combined, fps=fps, quality=8)
        result = subprocess.run(["ffmpeg", "-y", "-i", tmp, "-c:v", "libx264",
                                 "-preset", "ultrafast", "-crf", "23",
                                 "-movflags", "+faststart", str(path)],
                                capture_output=True)
        if result.returncode != 0:
            # Best-effort encode: warn instead of silently dropping the video.
            print(f"  [save_comparison_video] ffmpeg failed for {path}: "
                  f"{result.stderr.decode(errors='replace')[-300:]}")
    finally:
        # Always remove the intermediate file, even if encoding failed.
        if os.path.exists(tmp):
            os.remove(tmp)


def main():
    """Sweep denoising-step counts for SVD generation and record timings.

    For each config: load the fine-tuned UNet + pipeline, generate an
    N_WINDOW-frame video from two conditioning frames (trajectory start and
    middle), save a side-by-side comparison video with the generation time
    embedded in the filename, and write a results.json summary.
    """
    import json  # stdlib; local import keeps the top-of-file env/torch ordering untouched

    device = torch.device(DEVICE)
    os.makedirs(OUT_DIR, exist_ok=True)

    # Load GT trajectory
    print("Loading GT trajectory...")
    gt_frames = load_gt_trajectory()
    n_frames = len(gt_frames)
    if n_frames == 0:
        # Fail fast with a clear message instead of an IndexError below.
        raise RuntimeError("No GT frames found; check the cache path in load_gt_trajectory()")
    print(f"  {n_frames} GT frames, shape: {gt_frames[0].shape}")

    # Conditioning frames: start (frame 0) and middle (frame n//2)
    cond_indices = {"start": 0, "middle": n_frames // 2}

    # Configs: (name, dtype, num_inference_steps)
    configs = [
        ("baseline_fp16_25steps", torch.float16, 25),
        ("fp16_13steps",          torch.float16, 13),
        ("fp16_6steps",           torch.float16, 6),
        ("fp16_3steps",           torch.float16, 3),
        ("fp16_1step",            torch.float16, 1),
    ]

    results = []

    for config_name, dtype, n_steps in configs:
        print(f"\n=== Config: {config_name} ===")
        t_load_start = time.time()

        # Fresh UNet + pipeline per config so dtype changes take effect cleanly.
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            f"{CKPT_DIR}/unet", torch_dtype=dtype).to(device)
        unet.eval()
        pipe = StableVideoDiffusionPipeline.from_pretrained(
            SVD_BASE, unet=unet, torch_dtype=dtype, variant="fp16")
        pipe.to(device)
        t_load = time.time() - t_load_start
        print(f"  Model loaded in {t_load:.1f}s")

        for frame_name, frame_idx in cond_indices.items():
            # Conditioning frame plus the GT window it should reproduce.
            cond_frame = gt_frames[frame_idx]
            gt_window = gt_frames[frame_idx:frame_idx + N_WINDOW]

            cond_pil = Image.fromarray(cond_frame).resize((SVD_W, SVD_H))

            # Time generation; synchronize so async CUDA work is included.
            torch.cuda.synchronize()
            t_start = time.time()
            with torch.inference_mode():
                gen_pil = pipe(cond_pil, height=SVD_H, width=SVD_W,
                              num_frames=N_WINDOW, decode_chunk_size=4,
                              num_inference_steps=n_steps).frames[0]
            torch.cuda.synchronize()
            t_gen = time.time() - t_start

            # Save comparison video with the timing embedded in the filename.
            t_ms = int(t_gen * 1000)
            fname = f"{frame_name}_{config_name}_{t_ms}ms.mp4"
            out_path = Path(OUT_DIR) / fname
            save_comparison_video(gt_window, gen_pil, out_path,
                                label=f"{config_name} ({t_gen:.2f}s)")

            print(f"  {frame_name}: {t_gen:.2f}s → {fname}")
            results.append({
                "frame": frame_name,
                "config": config_name,
                "dtype": str(dtype),
                "n_steps": n_steps,
                "gen_time_s": round(t_gen, 3),
                "filename": fname,
            })

        # Free GPU memory before loading the next config's weights.
        del pipe, unet
        torch.cuda.empty_cache()

    # Print summary table
    print(f"\n{'='*60}")
    print(f"{'Config':<25} {'Frame':<8} {'Steps':<6} {'Time (s)':<10}")
    print(f"{'-'*60}")
    for r in results:
        print(f"{r['config']:<25} {r['frame']:<8} {r['n_steps']:<6} {r['gen_time_s']:<10.3f}")

    # Persist machine-readable results alongside the videos.
    with open(f"{OUT_DIR}/results.json", "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nAll videos saved to {OUT_DIR}/")


# Script entry point: run the full sweep when executed directly.
if __name__ == "__main__":
    main()
