"""Alignment check: plot GT video vs generated video with PARA keypoints on both.
Left = GT frames with GT keypoints (cyan) + predicted keypoints (red)
Right = SVD generated frames with predicted keypoints (red)
"""
import os
# Expose only GPU index 4 to this process. The variable is set before
# `import torch` further down, so it is already in the environment when
# CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import sys
# Put the project checkouts at the front of the import path. The
# `model_svd`, `train_svd_para`, `data` and `svd` imports below are
# presumably resolved from these directories — verify against the checkouts.
sys.path.insert(0, "/data/cameron/para_videopolicy")
sys.path.insert(0, "/data/cameron/vidgen/svd_motion_lora/Motion-LoRA")

import torch
import torch.nn.functional as F
import numpy as np
import cv2
import imageio
from pathlib import Path
from PIL import Image

from model_svd import SVDParaPredictor
import model_svd as model_module
from train_svd_para import (extract_pred_2d_and_height, PARA_OUT_SIZE,
                             N_HEIGHT_BINS, N_WINDOW, SVD_SIZE)
from data import CachedTrajectoryDataset
from svd.pipelines import StableVideoDiffusionPipeline

# PARA checkpoint; main() reads its "stats" and "para_heads_state_dict" entries.
CKPT = "/data/cameron/para_videopolicy/checkpoints/svd_para_ood_objpos/checkpoint_14000.pt"
# Passed to SVDParaPredictor as svd_base / svd_unet; SVD_BASE is also given to
# StableVideoDiffusionPipeline.from_pretrained. Both are relative paths, so
# they presumably depend on the working directory — verify before running.
SVD_BASE = "checkpoints/stable-video-diffusion-img2vid-xt-1-1"
SVD_UNET = "output_libero_7f/checkpoint-46000/unet"
# Passed as cache_root to CachedTrajectoryDataset.
CACHE_ROOT = "/data/libero/ood_objpos_task0"
# Output directory for the video and PNGs (relative to the working directory).
OUT_DIR = "vis_alignment"
# Image size requested from the dataset; also the denominator of the
# PARA-output -> image coordinate scale computed in main().
IMAGE_SIZE = 448
DEVICE = "cuda"

# Runs at import time so main() can write into OUT_DIR without further checks.
os.makedirs(OUT_DIR, exist_ok=True)

def draw_keypoint(img, x, y, color, label="", radius=8):
    """Mark a point on ``img`` with a ring, a centre dot and a crosshair.

    Args:
        img: Image array that is drawn on in place.
        x, y: Point position in pixels; truncated to integers with ``int``.
        color: Colour tuple forwarded unchanged to the OpenCV drawing calls.
        label: Optional text placed to the upper right of the marker.
        radius: Radius of the outer ring in pixels.

    Returns:
        The same ``img`` object that was passed in.
    """
    cx = int(x)
    cy = int(y)
    centre = (cx, cy)
    # The crosshair arms reach three pixels beyond the ring.
    arm = radius + 3

    # Outer ring (outline, thickness 2) and filled centre dot.
    cv2.circle(img, centre, radius, color, 2)
    cv2.circle(img, centre, 2, color, -1)

    # Horizontal then vertical crosshair line.
    cv2.line(img, (cx - arm, cy), (cx + arm, cy), color, 1)
    cv2.line(img, (cx, cy - arm), (cx, cy + arm), color, 1)

    if not label:
        return img

    text_origin = (cx + radius + 2, cy - 2)
    cv2.putText(img, label, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)
    return img

def main():
    """Render a GT-vs-generated alignment video plus a few preview PNGs.

    For each selected dataset sample, the PARA heads predict a 2D keypoint
    trajectory from the sample's RGB input and the SVD pipeline generates a
    video from that same image. Every timestep becomes one side-by-side
    frame: the left panel is the GT frame with the GT keypoint (cyan) and
    the predicted keypoint (red); the right panel is the generated frame
    with the predicted keypoint (red). A black separator frame follows each
    sample.

    Outputs:
        ``{OUT_DIR}/alignment_check.mp4`` containing all frames, and
        ``{OUT_DIR}/sample_{i}.png`` for the first five frames of the video.

    Sample indices that lie beyond the end of the dataset are skipped (and
    reported) instead of raising ``IndexError``.
    """
    # Load checkpoint + stats.
    # NOTE: torch.load unpickles the file; only point CKPT at trusted checkpoints.
    ckpt = torch.load(CKPT, map_location=DEVICE)
    stats = ckpt["stats"]
    # Override model_svd's module-level value ranges with those stored in the
    # checkpoint before the model is built.
    model_module.MIN_HEIGHT = stats["min_height"]
    model_module.MAX_HEIGHT = stats["max_height"]
    model_module.MIN_GRIPPER = stats["min_gripper"]
    model_module.MAX_GRIPPER = stats["max_gripper"]
    model_module.MIN_ROT = stats["min_rot"]
    model_module.MAX_ROT = stats["max_rot"]

    # Build models
    print("Loading SVD+PARA model...")
    model = SVDParaPredictor(svd_base=SVD_BASE, svd_unet=SVD_UNET, device=DEVICE).to(DEVICE)
    model.para_heads.load_state_dict(ckpt["para_heads_state_dict"])
    model.eval()

    print("Loading SVD pipeline...")
    # The pipeline reuses the UNet instance held by the PARA model, so the
    # generated video and the keypoint prediction come from the same weights.
    pipe = StableVideoDiffusionPipeline.from_pretrained(
        SVD_BASE, unet=model.svd_extractor.unet, torch_dtype=torch.float16, variant="fp16")
    pipe.to(DEVICE)

    # Load dataset (same as training)
    dataset = CachedTrajectoryDataset(
        cache_root=CACHE_ROOT, benchmark_name="libero_spatial",
        task_ids=[0], image_size=IMAGE_SIZE, n_window=N_WINDOW, frame_stride=3,
    )
    print(f"Dataset: {len(dataset)} samples")

    # Factor between PARA output coordinates and IMAGE_SIZE pixel coordinates.
    coord_scale = PARA_OUT_SIZE / IMAGE_SIZE

    # ImageNet normalisation constants, used to undo the normalisation for display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Pick a few samples, dropping any index the dataset does not contain.
    requested_indices = [0, 100, 500, 1000, 2000]
    n_samples = len(dataset)
    sample_indices = [i for i in requested_indices if i < n_samples]
    skipped = [i for i in requested_indices if i >= n_samples]
    if skipped:
        print(f"Skipping out-of-range sample indices: {skipped}")
    if not sample_indices:
        print("No valid sample indices; nothing to render.")
        return

    all_frames = []

    for si, idx in enumerate(sample_indices):
        print(f"\n[{si+1}/{len(sample_indices)}] Sample {idx}...")
        sample = dataset[idx]

        # GT data
        rgb_tensor = sample['rgb']  # (3, 448, 448) ImageNet-normalized
        gt_traj_2d = sample['trajectory_2d'].numpy()  # (N_WINDOW, 2) in 448 space
        gt_frames_raw = sample['rgb_frames_raw'].numpy()  # (N_WINDOW, 448, 448, 3) in [0,1]
        # Panel size comes from the GT frames so both panels and the separator
        # always concatenate, whatever IMAGE_SIZE is set to.
        frame_h, frame_w = gt_frames_raw.shape[1:3]

        # Undo ImageNet normalization for display
        rgb_display = rgb_tensor.numpy().transpose(1, 2, 0) * std + mean
        rgb_display = (np.clip(rgb_display, 0, 1) * 255).astype(np.uint8)

        # Run PARA model
        with torch.no_grad():
            vol_logits, _, _, _ = model(rgb_tensor.unsqueeze(0).to(DEVICE))
        pred_2d, _ = extract_pred_2d_and_height(vol_logits, stats["min_height"], stats["max_height"])
        pred_2d_para = pred_2d[0].cpu().numpy()  # (N_WINDOW, 2) in PARA_OUT_SIZE coords
        pred_2d_img = pred_2d_para / coord_scale  # scale to image space

        # Generate SVD video from first frame
        first_frame_pil = Image.fromarray(rgb_display).resize((576, 320))
        with torch.inference_mode():
            vid_frames = pipe(first_frame_pil, height=320, width=576,
                              num_frames=max(N_WINDOW + 1, 7), decode_chunk_size=4,
                              num_inference_steps=25).frames[0]
        vid_frames_np = [np.array(f) for f in vid_frames]

        # Create visualization for each timestep
        for t in range(N_WINDOW):
            # --- LEFT: GT frame with GT + predicted keypoints ---
            # Clip before the uint8 cast so values slightly outside [0, 1] do
            # not wrap around; force a C-contiguous buffer for OpenCV drawing.
            left = np.ascontiguousarray(
                (np.clip(gt_frames_raw[t], 0.0, 1.0) * 255).astype(np.uint8))
            # GT keypoint (cyan)
            gt_x, gt_y = gt_traj_2d[t, 0], gt_traj_2d[t, 1]
            draw_keypoint(left, gt_x, gt_y, (0, 255, 255), "GT")
            # Predicted keypoint (red)
            pred_x, pred_y = pred_2d_img[t, 0], pred_2d_img[t, 1]
            draw_keypoint(left, pred_x, pred_y, (255, 0, 0), "PRED")
            # Label
            cv2.putText(left, f"GT Frame t+{t} (stride=3)", (5, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
            cv2.putText(left, f"Sample {idx}", (5, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 0), 1)

            # --- RIGHT: SVD generated frame with predicted keypoint ---
            # Reuse the last generated frame if the video is shorter than N_WINDOW.
            svd_idx = min(t, len(vid_frames_np) - 1)
            svd_frame = vid_frames_np[svd_idx]
            # Resize SVD frame to the GT frame size for comparison
            right = cv2.resize(svd_frame, (frame_w, frame_h))
            # Draw predicted keypoint (red)
            draw_keypoint(right, pred_x, pred_y, (255, 0, 0), "PRED")
            cv2.putText(right, f"SVD Generated t+{t}", (5, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

            # Combine
            combined = np.concatenate([left, right], axis=1)
            all_frames.append(combined)

        # Separator
        sep = np.zeros((frame_h, frame_w * 2, 3), dtype=np.uint8)
        cv2.putText(sep, f"--- Sample {idx} ---", (300, frame_h // 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
        all_frames.append(sep)

        torch.cuda.empty_cache()

    # Save
    vid_path = f"{OUT_DIR}/alignment_check.mp4"
    imageio.mimwrite(vid_path, all_frames, fps=2, quality=8)
    print(f"\nSaved: {vid_path} ({len(all_frames)} frames)")

    # Save individual frames for quick viewing. These are the first five
    # frames of the video, not one frame per sample.
    for i, frame in enumerate(all_frames[:5]):
        Image.fromarray(frame).save(f"{OUT_DIR}/sample_{i}.png")
    print(f"Saved sample PNGs to {OUT_DIR}/")

# Run the visualisation only when this file is executed as a script.
if __name__ == "__main__":
    main()
