"""Test base SVD model on a LIBERO start frame - verify plausible output before training."""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import torch
import numpy as np
from PIL import Image
from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel

print("Loading SVD pipeline...")
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "checkpoints/stable-video-diffusion-img2vid-xt-1-1",
    subfolder="unet", torch_dtype=torch.float16, variant="fp16"
)
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "checkpoints/stable-video-diffusion-img2vid-xt-1-1",
    unet=unet, torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")
print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Load LIBERO validation image
img_path = "dataset/validation_images/libero_0000.png"
# Force 3-channel RGB before resizing: a palettized or RGBA PNG would
# otherwise reach the pipeline's image encoder with the wrong channel count.
image = Image.open(img_path).convert("RGB").resize((512, 512))
print(f"Input image: {img_path} -> {image.size}")

# Generate video
print("Generating video (25 frames)...")
# Fixed seed so repeated runs of this sanity check produce the same clip,
# making before/after-training comparisons meaningful. NOTE(review): assumes
# the project pipeline accepts a diffusers-style `generator` kwarg -- confirm.
generator = torch.Generator(device="cuda").manual_seed(42)
with torch.inference_mode():
    frames = pipe(
        image,
        num_frames=25,
        decode_chunk_size=4,  # decode latents in chunks of 4 to bound VRAM use
        num_inference_steps=25,
        generator=generator,
    ).frames[0]

print(f"Generated {len(frames)} frames")

# Persist a few stills plus the full clip so the result can be eyeballed.
import imageio  # needed only for the mp4 export below

out_dir = "/data/cameron/vidgen/svd_motion_lora/Motion-LoRA/test_output"
os.makedirs(out_dir, exist_ok=True)

# First / middle / last frames are enough to spot temporal drift at a glance.
for tag, idx in (("first", 0), ("middle", len(frames) // 2), ("last", -1)):
    frames[idx].save(os.path.join(out_dir, f"base_svd_{tag}.png"))

# Full clip as mp4 for judging motion quality.
frames_np = [np.array(f) for f in frames]
imageio.mimwrite(os.path.join(out_dir, "base_svd_libero.mp4"), frames_np, fps=8, quality=8)
print(f"Saved to {out_dir}/")
