"""Compare base SVD video generation at native vs training resolution on LIBERO."""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import torch
import numpy as np
from PIL import Image
import imageio
from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel

print("Loading SVD pipeline...")
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "checkpoints/stable-video-diffusion-img2vid-xt-1-1",
    subfolder="unet", torch_dtype=torch.float16, variant="fp16"
)
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "checkpoints/stable-video-diffusion-img2vid-xt-1-1",
    unet=unet, torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")
pipe.enable_model_cpu_offload()  # save memory for native res
print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

img_path = "dataset/val_single/frame0.png"
image = Image.open(img_path)
print(f"Input image: {image.size}")

out_dir = "test_output"
os.makedirs(out_dir, exist_ok=True)

# Resolution sweep: each entry is (run tag, frame width, frame height, frame count).
_native = ("native_1024x576", 1024, 576, 25)  # SVD's native training resolution
_half = ("half_576x320", 576, 320, 25)        # roughly half resolution
configs = [_native, _half]

for name, w, h, nf in configs:
    print(f"\n--- Generating {name} ({w}x{h}, {nf} frames) ---")
    # LANCZOS: the default resample filter aliases when downscaling, which
    # would confound a quality comparison between resolutions.
    img_resized = image.resize((w, h), Image.LANCZOS)

    # inference_mode avoids autograd bookkeeping entirely (less memory than no_grad).
    with torch.inference_mode():
        frames = pipe(
            img_resized,
            height=h, width=w,
            num_frames=nf,
            decode_chunk_size=4,  # decode a few latents at a time to cap VAE VRAM
            num_inference_steps=25,
        ).frames[0]

    # Save the clip plus first/middle/last stills for quick side-by-side viewing.
    frames_np = [np.array(f) for f in frames]
    imageio.mimwrite(f"{out_dir}/{name}.mp4", frames_np, fps=8, quality=8)
    Image.fromarray(frames_np[0]).save(f"{out_dir}/{name}_first.png")
    Image.fromarray(frames_np[len(frames_np)//2]).save(f"{out_dir}/{name}_middle.png")
    Image.fromarray(frames_np[-1]).save(f"{out_dir}/{name}_last.png")
    print(f"  Saved {len(frames_np)} frames to {out_dir}/{name}.mp4")
    # Release cached allocations so the next (larger) config starts clean.
    torch.cuda.empty_cache()

print("\nDone!")
