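"""Check temporal alignment: GT frames with GT keypoints (left) vs. the OOD SVD-generated video (right).

No PARA predictions are involved; this only checks whether the GT keypoints and the SVD video are
temporally aligned. The GT keypoint is also overlaid on each generated frame for reference."""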
import os
# Restrict this process to physical GPU 4. This is set before torch is imported below so the
# restriction is in place before CUDA is initialized; DEVICE="cuda" then refers to that GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import sys
# Make the project packages importable (must run before the `svd` / `data` imports below).
# Presumably `data` (CachedTrajectoryDataset) lives in para_videopolicy and `svd.*` in
# Motion-LoRA — NOTE(review): verify against the actual checkouts.
sys.path.insert(0, "/data/cameron/para_videopolicy")
sys.path.insert(0, "/data/cameron/vidgen/svd_motion_lora/Motion-LoRA")

import torch
import numpy as np
import cv2
import imageio
from pathlib import Path
from PIL import Image

from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel
from data import CachedTrajectoryDataset

# ---- Model weights --------------------------------------------------------------
# Base pipeline directory; its UNet is replaced by the fine-tuned one below in main().
SVD_BASE = "checkpoints/stable-video-diffusion-img2vid-xt-1-1"
# Fine-tuned UNet trained on the OOD object-position split.
SVD_UNET = "output_libero_ood_objpos/checkpoint-31500/unet"

# ---- Data / runtime -------------------------------------------------------------
CACHE_ROOT = "/data/libero/ood_objpos_task0"
N_WINDOW = 4        # frames per dataset sample (passed as n_window=)
IMAGE_SIZE = 448    # dataset image size (passed as image_size=)
DEVICE = "cuda"

# ---- Output (created eagerly at import time) ------------------------------------
OUT_DIR = "vis_ood_alignment"
os.makedirs(OUT_DIR, exist_ok=True)

def _to_uint8(frame):
    """Convert a float image expected in [0, 1] to a C-contiguous uint8 array.

    Values are clipped first: casting out-of-range floats to uint8 is not well defined
    (they may wrap). The result is forced C-contiguous and writable because the cv2
    drawing functions require that layout.
    """
    return np.ascontiguousarray((np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8))


def _draw_keypoint(img, xy, label, color=(0, 255, 255)):
    """Draw a ring plus a filled dot at pixel `xy`, and `label` at the top-left of `img` (in place)."""
    cv2.circle(img, xy, 10, color, 3)
    cv2.circle(img, xy, 3, color, -1)
    cv2.putText(img, label, (5, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)


def _separator(height, width, text):
    """Return a black (height, width, 3) uint8 frame with `text` centred on it in white."""
    sep = np.zeros((height, width, 3), dtype=np.uint8)
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2
    (text_w, text_h), _ = cv2.getTextSize(text, font, scale, thickness)
    origin = (max((width - text_w) // 2, 0), (height + text_h) // 2)
    cv2.putText(sep, text, origin, font, scale, (255, 255, 255), thickness)
    return sep


def main():
    """Render GT-vs-SVD comparison frames and save them to OUT_DIR as an MP4 and PNGs.

    For each selected dataset sample, the first GT frame conditions the SVD pipeline.
    Each of the N_WINDOW GT frames, with its GT keypoint drawn on it, is placed next to
    the generated frame with the same index, which carries the same GT keypoint overlay.
    A black separator frame follows each sample. Requested sample indices outside the
    dataset are skipped.
    """
    frame_stride = 3         # dataset frame stride; also used for the "t+N" labels
    svd_w, svd_h = 576, 320  # SVD generation resolution (width, height)

    # Load OOD SVD model
    print(f"Loading SVD UNet from: {SVD_UNET}")
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        SVD_UNET, subfolder="unet" if "unet" not in SVD_UNET else None,
        torch_dtype=torch.float16)
    pipe = StableVideoDiffusionPipeline.from_pretrained(
        SVD_BASE, unet=unet, torch_dtype=torch.float16, variant="fp16")
    pipe.to(DEVICE)

    # Load dataset
    dataset = CachedTrajectoryDataset(
        cache_root=CACHE_ROOT, benchmark_name="libero_spatial",
        task_ids=[0], image_size=IMAGE_SIZE, n_window=N_WINDOW, frame_stride=frame_stride,
    )
    print(f"Dataset: {len(dataset)} samples")

    # Drop indices the dataset doesn't have. Otherwise the script would crash with an
    # IndexError partway through, after already spending time on SVD generation.
    requested = [0, 200, 500, 1000, 2000]
    sample_indices = [i for i in requested if 0 <= i < len(dataset)]
    skipped = [i for i in requested if i not in sample_indices]
    if skipped:
        print(f"Skipping out-of-range sample indices: {skipped}")

    all_frames = []
    for si, idx in enumerate(sample_indices):
        print(f"\n[{si+1}/{len(sample_indices)}] Sample {idx}...")
        sample = dataset[idx]

        # GT data. Shapes per the original comments; NOTE(review): confirm against the dataset.
        gt_frames = sample['rgb_frames_raw'].numpy()  # assumed (N_WINDOW, H, W, 3) float in [0,1]
        gt_traj = sample['trajectory_2d'].numpy()      # assumed (N_WINDOW, 2) pixel coords in GT-frame space
        panel_h, panel_w = gt_frames.shape[1:3]

        # First frame for SVD conditioning
        first_pil = Image.fromarray(_to_uint8(gt_frames[0])).resize((svd_w, svd_h))

        # Generate SVD video
        with torch.inference_mode():
            vid_pil = pipe(first_pil, height=svd_h, width=svd_w,
                           num_frames=max(N_WINDOW + 1, 7), decode_chunk_size=4,
                           num_inference_steps=25).frames[0]
        vid_np = [np.array(f) for f in vid_pil]

        for t in range(N_WINDOW):
            gx, gy = int(gt_traj[t, 0]), int(gt_traj[t, 1])

            # LEFT: GT frame with GT keypoint
            left = _to_uint8(gt_frames[t])
            _draw_keypoint(left, (gx, gy), f"GT t+{t * frame_stride} kp=({gx},{gy})")

            # RIGHT: generated frame, resized to the GT panel size. This puts the GT keypoint
            # overlay in the same coordinate space and keeps the two panels concatenable.
            svd_frame = vid_np[min(t, len(vid_np) - 1)]
            right = cv2.resize(svd_frame, (panel_w, panel_h))
            _draw_keypoint(right, (gx, gy), f"SVD Gen t+{t * frame_stride} (GT kp overlay)")

            all_frames.append(np.concatenate([left, right], axis=1))

        all_frames.append(_separator(panel_h, 2 * panel_w, f"--- Sample {idx} ---"))
        torch.cuda.empty_cache()

    if not all_frames:
        print("No valid samples; nothing to write.")
        return

    vid_path = f"{OUT_DIR}/ood_alignment.mp4"
    imageio.mimwrite(vid_path, all_frames, fps=2, quality=8)
    print(f"\nSaved: {vid_path} ({len(all_frames)} frames)")

    # Also dump the first few frames as PNGs for quick inspection.
    for i, frame in enumerate(all_frames[:8]):
        Image.fromarray(frame).save(f"{OUT_DIR}/frame_{i}.png")
    print(f"Saved PNGs to {OUT_DIR}/")

# Run the visualization only when executed as a script (not on import).
if __name__ == "__main__":
    main()
