"""Evaluate fine-tuned SVD checkpoint on LIBERO validation images."""
import os
# Pin the run to GPU 4 by default. This must happen before torch is imported
# to take effect. setdefault lets an explicit CUDA_VISIBLE_DEVICES from the
# caller's environment win instead of being silently overwritten.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4")

import torch
import numpy as np
from PIL import Image
import imageio
from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel

# Paths and generation settings for this evaluation run.
CKPT_DIR = "output_libero_10vid/checkpoint-500"
VAL_DIR = "dataset/val_10"
OUT_DIR = "eval_output"
WIDTH, HEIGHT = 576, 320
NUM_FRAMES = 25
NUM_EVAL = 5  # evaluate only the first N validation images (sorted by filename)
SEED = 0  # re-applied before every sample so reruns of a checkpoint are comparable
VIDEO_FPS = 8  # playback frame rate of the saved MP4 files


def _load_rgb(path, size):
    """Return the image at *path* converted to RGB and resized to *size*.

    *size* is ``(width, height)``, the order PIL's ``resize`` expects.
    PNG inputs may be RGBA or palette-mode. Converting first hands the
    pipeline a 3-channel image (presumably what the model was fine-tuned
    on; confirm against the training dataloader). The ``with`` block
    closes the underlying file handle once the resized copy exists.
    """
    with Image.open(path) as img:
        return img.convert("RGB").resize(size)


def _save_outputs(frames_np, idx):
    """Write the MP4 and first/middle/last key-frame PNGs for sample *idx*.

    *frames_np* is the list of per-frame arrays produced from the pipeline
    output. Files are named ``eval_{idx:02d}.mp4`` and
    ``eval_{idx:02d}_{first,mid,last}.png`` inside ``OUT_DIR``.
    Returns the MP4 path.
    """
    stem = os.path.join(OUT_DIR, f"eval_{idx:02d}")
    mp4_path = f"{stem}.mp4"
    imageio.mimwrite(mp4_path, frames_np, fps=VIDEO_FPS, quality=8)
    keyframes = {"first": 0, "mid": len(frames_np) // 2, "last": -1}
    for tag, pos in keyframes.items():
        Image.fromarray(frames_np[pos]).save(f"{stem}_{tag}.png")
    return mp4_path


os.makedirs(OUT_DIR, exist_ok=True)

# List the inputs before loading the model. A wrong VAL_DIR or an empty
# directory then fails immediately instead of after the expensive load.
val_images = sorted(f for f in os.listdir(VAL_DIR) if f.lower().endswith(".png"))
eval_images = val_images[:NUM_EVAL]
if not eval_images:
    raise SystemExit(f"No .png images found in {VAL_DIR}")

print("Loading fine-tuned model...")
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    CKPT_DIR, subfolder="unet", torch_dtype=torch.float16
)
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "checkpoints/stable-video-diffusion-img2vid-xt-1-1",
    unet=unet, torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")
print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

print(f"Evaluating on {len(eval_images)} of {len(val_images)} validation images...")

for i, img_name in enumerate(eval_images):
    image = _load_rgb(os.path.join(VAL_DIR, img_name), (WIDTH, HEIGHT))
    print(f"\n[{i}] {img_name} -> {WIDTH}x{HEIGHT}")

    # torch.manual_seed seeds the CPU and all CUDA RNGs. The pipeline's noise
    # sampling presumably draws from them when no generator is passed
    # (confirm), so each sample is reproducible across runs.
    torch.manual_seed(SEED)
    with torch.inference_mode():
        frames = pipe(
            image, height=HEIGHT, width=WIDTH,
            num_frames=NUM_FRAMES, decode_chunk_size=4,
            num_inference_steps=25,
        ).frames[0]

    frames_np = [np.array(f) for f in frames]
    mp4_path = _save_outputs(frames_np, i)
    print(f"  Saved {mp4_path}")
    # Release unused cached GPU memory between samples.
    torch.cuda.empty_cache()

print(f"\nDone! Videos at {OUT_DIR}/")
