"""Visualize SVD+PARA: generated video + heatmap predictions side by side."""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import sys
sys.path.insert(0, "/data/cameron/para_videopolicy")
sys.path.insert(0, "/data/cameron/vidgen/svd_motion_lora/Motion-LoRA")

import torch
import torch.nn.functional as F
import numpy as np
import cv2
import imageio
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont

from model_svd import SVDParaPredictor
import model_svd as model_module
from train_svd_para import (extract_pred_2d_and_height, PARA_OUT_SIZE,
                             N_HEIGHT_BINS, N_WINDOW, SVD_SIZE)
from data import CachedTrajectoryDataset
from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel

# --- Run configuration -------------------------------------------------------
# Training checkpoint; main() reads its "stats" and "para_heads_state_dict" entries.
CKPT = "/data/cameron/para_videopolicy/checkpoints/svd_para_ood_objpos/checkpoint_14000.pt"
# Base SVD weights: passed to both SVDParaPredictor and the video pipeline.
# NOTE(review): relative path — presumably resolved against the directory the
# script is launched from; confirm the expected working directory.
SVD_BASE = "checkpoints/stable-video-diffusion-img2vid-xt-1-1"
# UNet weights handed to SVDParaPredictor (same relative-path caveat as SVD_BASE).
SVD_UNET = "output_libero_7f/checkpoint-46000/unet"
# Dataset root; main() reads <CACHE_ROOT>/libero_spatial/task_0/demo_0.
CACHE_ROOT = "/data/libero/ood_objpos_task0"
# Directory (relative) that receives the output video and PNG snapshots.
OUT_DIR = "vis_rollout_output"
# Side length in pixels that heatmaps are upsampled to; together with
# PARA_OUT_SIZE it defines the heatmap-grid -> image coordinate scale.
IMAGE_SIZE = 448
# Torch device used for the checkpoint, the policy model and the pipeline.
DEVICE = "cuda"
# Number of frames requested from the SVD pipeline per start frame.
NUM_VID_FRAMES = 7

# Side effect at import time: make sure the output directory exists.
os.makedirs(OUT_DIR, exist_ok=True)

def add_text(img_np, text, pos=(10, 20), color=(255, 255, 255), size=0.5):
    """Return a copy of ``img_np`` with ``text`` rendered at ``pos``.

    The input array is left untouched. The text is drawn in two passes with
    ``cv2.FONT_HERSHEY_SIMPLEX`` at scale ``size``: a thick black pass first,
    then a thin pass in ``color`` on top, which gives the label a dark outline.
    """
    canvas = img_np.copy()
    passes = (
        ((0, 0, 0), 3),  # outline
        (color, 1),      # foreground
    )
    for stroke_color, thickness in passes:
        cv2.putText(canvas, text, pos, cv2.FONT_HERSHEY_SIMPLEX, size,
                    stroke_color, thickness)
    return canvas

def _load_rgb_frame(path):
    """Open ``path`` and return it as an RGB PIL image.

    The file handle is closed before returning. Converting to RGB guarantees
    three channels, so RGBA / grayscale / palette PNGs do not break the
    3-channel normalization and overlay code downstream.
    """
    with Image.open(path) as im:
        return im.convert("RGB")


def _normalize_frame(frame, mean, std):
    """Turn an HxWx3 uint8 array into a normalized 1x3xHxW float tensor on DEVICE.

    ``mean`` and ``std`` are (3, 1, 1) CPU tensors; the arithmetic is done on
    the CPU and only the final tensor is moved to ``DEVICE``.
    """
    chw = torch.from_numpy(frame).float().permute(2, 0, 1) / 255.0
    return ((chw - mean) / std).unsqueeze(0).to(DEVICE)


def _heatmap_overlay(frame, heat, gt_xy, pred_xy):
    """Blend a min-max normalized heatmap over ``frame`` and mark GT / prediction.

    ``frame`` must be IMAGE_SIZE x IMAGE_SIZE x 3 (RGB); ``heat`` is a 2-D
    score map that is upsampled to IMAGE_SIZE. ``gt_xy`` and ``pred_xy`` are
    integer (x, y) pixel positions drawn as cyan and red circles respectively.
    Returns a new uint8 RGB array.
    """
    # Epsilon keeps a constant heatmap from dividing by zero.
    heat_norm = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    heat_up = cv2.resize(heat_norm, (IMAGE_SIZE, IMAGE_SIZE))
    heat_color = cv2.applyColorMap((heat_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # applyColorMap yields BGR; everything else here is RGB.
    heat_color = cv2.cvtColor(heat_color, cv2.COLOR_BGR2RGB)
    overlay = (frame * 0.5 + heat_color * 0.5).astype(np.uint8)
    cv2.circle(overlay, gt_xy, 8, (0, 255, 255), 2)
    cv2.circle(overlay, pred_xy, 8, (255, 0, 0), 2)
    return overlay


def main():
    """Render SVD-generated frames next to PARA heatmap predictions.

    For a handful of evenly spaced start frames of demo 0, this generates an
    SVD video from the start frame, runs the PARA policy on the same frame,
    and writes a side-by-side video (generated frame | heatmap overlay with
    ground-truth and predicted pixels) plus a few PNG snapshots to ``OUT_DIR``.

    Raises:
        FileNotFoundError: if the demo has no PNG frames (or ``pix_uv.npy`` is
            missing).
        ValueError: if a frame is not IMAGE_SIZE x IMAGE_SIZE.
    """
    vid_w, vid_h = 576, 320  # resolution fed to / returned by the SVD pipeline
    stride = 3               # trajectory step between consecutive PARA timesteps
    num_starts = 5           # number of start frames to visualize

    # Load the demo trajectory first so missing data fails before the
    # (slow) model loading instead of after it.
    demo_dir = Path(CACHE_ROOT) / "libero_spatial" / "task_0" / "demo_0"
    frame_files = sorted((demo_dir / "frames").glob("*.png"))
    if not frame_files:
        raise FileNotFoundError(f"No PNG frames found in {demo_dir / 'frames'}")
    pix_uv = np.load(demo_dir / "pix_uv.npy")
    print(f"Demo 0: {len(frame_files)} frames")

    # Load checkpoint and publish its normalization stats on the model module.
    ckpt = torch.load(CKPT, map_location=DEVICE)
    stats = ckpt["stats"]
    model_module.MIN_HEIGHT = stats["min_height"]
    model_module.MAX_HEIGHT = stats["max_height"]
    model_module.MIN_GRIPPER = stats["min_gripper"]
    model_module.MAX_GRIPPER = stats["max_gripper"]
    model_module.MIN_ROT = stats["min_rot"]
    model_module.MAX_ROT = stats["max_rot"]

    # Build PARA policy model
    print("Loading SVD+PARA policy model...")
    model = SVDParaPredictor(svd_base=SVD_BASE, svd_unet=SVD_UNET, device=DEVICE).to(DEVICE)
    model.para_heads.load_state_dict(ckpt["para_heads_state_dict"])
    model.eval()

    # Build SVD video pipeline (reuse the same UNet)
    print("Loading SVD video pipeline...")
    pipe = StableVideoDiffusionPipeline.from_pretrained(
        SVD_BASE, unet=model.svd_extractor.unet, torch_dtype=torch.float16, variant="fp16")
    pipe.to(DEVICE)

    # Heatmap-grid -> image pixel scale.
    coord_scale = PARA_OUT_SIZE / IMAGE_SIZE

    # Standard ImageNet channel statistics; built once, reused for every frame.
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Evenly spaced starting frames, leaving room for a full video after each.
    max_start = len(frame_files) - NUM_VID_FRAMES
    start_indices = np.linspace(0, max(0, max_start - 1), num_starts, dtype=int)

    all_output_frames = []

    for si, start_idx in enumerate(start_indices):
        start_idx = int(start_idx)
        print(f"\n[{si+1}/{len(start_indices)}] Frame {start_idx}...")
        rgb = _load_rgb_frame(frame_files[start_idx])
        frame = np.array(rgb)
        # The overlay blends `frame` with an IMAGE_SIZE heatmap, so check the
        # size now rather than after the expensive video generation.
        if frame.shape[:2] != (IMAGE_SIZE, IMAGE_SIZE):
            raise ValueError(
                f"Frame {frame_files[start_idx]} is {frame.shape[1]}x{frame.shape[0]}, "
                f"expected {IMAGE_SIZE}x{IMAGE_SIZE}")
        frame_pil = rgb.resize((vid_w, vid_h))

        # --- Generate SVD video ---
        print("  Generating video...")
        with torch.inference_mode():
            vid_frames = pipe(frame_pil, height=vid_h, width=vid_w,
                              num_frames=NUM_VID_FRAMES, decode_chunk_size=4,
                              num_inference_steps=25).frames[0]
        vid_frames_np = [np.array(f) for f in vid_frames]

        # --- Run PARA policy ---
        img_tensor = _normalize_frame(frame, norm_mean, norm_std)
        with torch.no_grad():
            vol_logits, _, _, _ = model(img_tensor)

        pred_2d, _ = extract_pred_2d_and_height(vol_logits, stats["min_height"], stats["max_height"])
        pred_2d_img = pred_2d[0].cpu().numpy() / coord_scale

        # Heatmaps: softmax over the whole (height, y, x) volume per timestep,
        # then collapse the height bins with a max.
        heatmaps = F.softmax(vol_logits[:1].reshape(1, N_WINDOW, -1), dim=2)
        heatmaps = heatmaps.view(1, N_WINDOW, N_HEIGHT_BINS, PARA_OUT_SIZE, PARA_OUT_SIZE)
        heatmap_2d = heatmaps.max(dim=2)[0][0].cpu().numpy()  # (N_WINDOW, P, P)

        # GT pixels at start_idx + t * stride, clamped to the last trajectory entry.
        gt_idx = np.minimum(start_idx + np.arange(N_WINDOW) * stride, len(pix_uv) - 1)
        gt_pixels = pix_uv[gt_idx]

        # --- Create side-by-side visualization ---
        # For each PARA timestep: [generated video frame | heatmap overlay]
        for t in range(min(N_WINDOW, NUM_VID_FRAMES)):
            # Left: SVD generated frame (repeat the last one if the pipeline
            # returned fewer frames than requested).
            vid_frame = vid_frames_np[min(t, len(vid_frames_np) - 1)]
            left = add_text(vid_frame, f"SVD Generated t+{t}", (10, 25), (0, 255, 0))

            # Right: heatmap overlay on input frame with GT and prediction.
            gt_xy = (int(gt_pixels[t, 0]), int(gt_pixels[t, 1]))
            pred_xy = (int(pred_2d_img[t, 0]), int(pred_2d_img[t, 1]))
            overlay = _heatmap_overlay(frame, heatmap_2d[t], gt_xy, pred_xy)

            # Resize overlay to match video frame size
            overlay_resized = cv2.resize(overlay, (vid_w, vid_h))
            overlay_resized = add_text(overlay_resized, f"PARA Heatmap t+{t} (cyan=GT red=PRED)",
                                       (10, 25), (255, 255, 0))

            # Combine side by side
            all_output_frames.append(np.concatenate([left, overlay_resized], axis=1))

        # Separator
        sep = np.zeros((vid_h, vid_w * 2, 3), dtype=np.uint8)
        sep = add_text(sep, f"--- Start frame {start_idx} ---",
                       (vid_w // 2, vid_h // 2), (255, 255, 255), 0.8)
        all_output_frames.append(sep)

        torch.cuda.empty_cache()

    # Save
    vid_path = f"{OUT_DIR}/video_and_heatmaps.mp4"
    imageio.mimwrite(vid_path, all_output_frames, fps=3, quality=8)
    print(f"\nSaved: {vid_path} ({len(all_output_frames)} frames)")

    # Save individual frames (set avoids writing the same index twice).
    n_out = len(all_output_frames)
    for i in sorted({0, n_out // 3, 2 * n_out // 3}):
        if i < n_out:
            Image.fromarray(all_output_frames[i]).save(f"{OUT_DIR}/sidebyside_{i}.png")

    print(f"All outputs in {OUT_DIR}/")

# Run the visualization only when this file is executed as a script.
if __name__ == "__main__":
    main()
