"""Visualize SVD+PARA policy predictions on dataset trajectories.
Shows heatmap overlays and predicted vs GT pixel locations as a video."""

# NOTE: CUDA_VISIBLE_DEVICES must be assigned before torch is imported below
# (CUDA device visibility is fixed at initialization); pin this script to GPU 4.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import sys
# Make this script's sibling modules (model_svd, train_svd_para, data) and the
# Motion-LoRA checkout importable without installing them as packages.
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, "/data/cameron/vidgen/svd_motion_lora/Motion-LoRA")

import json  # NOTE(review): json appears unused in this file -- candidate for removal
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import imageio
from pathlib import Path
from PIL import Image

from model_svd import SVDParaPredictor
# Module handle kept so main() can patch the normalization globals at runtime.
import model_svd as model_module
from train_svd_para import (extract_pred_2d_and_height, PARA_OUT_SIZE,
                             N_HEIGHT_BINS, N_WINDOW, SVD_SIZE)  # NOTE(review): SVD_SIZE unused here
from data import CachedTrajectoryDataset

# --- Configuration ---------------------------------------------------------
CKPT = "/data/cameron/para_videopolicy/checkpoints/svd_para_ood_objpos/checkpoint_14000.pt"  # trained PARA heads + stats
CACHE_ROOT = "/data/libero/ood_objpos_task0"  # cached trajectory dataset root
OUT_DIR = "vis_rollout_output"                # all outputs (video + stills) are written here
IMAGE_SIZE = 448                              # square input/visualization resolution in pixels
DEVICE = "cuda"

os.makedirs(OUT_DIR, exist_ok=True)

def _apply_stats(stats):
    """Copy the checkpoint's normalization stats into model_svd's module globals.

    The model/decoding code reads these module-level values, so they must be
    populated before any forward pass (kept before model construction to
    preserve the original ordering).
    """
    model_module.MIN_HEIGHT = stats["min_height"]
    model_module.MAX_HEIGHT = stats["max_height"]
    model_module.MIN_GRIPPER = stats["min_gripper"]
    model_module.MAX_GRIPPER = stats["max_gripper"]
    model_module.MIN_ROT = stats["min_rot"]
    model_module.MAX_ROT = stats["max_rot"]


def _imagenet_normalize(frame_rgb):
    """ImageNet-normalize an HxWx3 uint8 RGB frame to a (1, 3, H, W) tensor on DEVICE."""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img = torch.from_numpy(frame_rgb).float().permute(2, 0, 1) / 255.0
    return ((img - mean) / std).unsqueeze(0).to(DEVICE)


def _render_window(model, frame_rgb, gt_pixels, stats, coord_scale, start_idx, stride):
    """Run the model on one frame and return N_WINDOW annotated overlay images.

    Args:
        model: SVDParaPredictor in eval mode.
        frame_rgb: HxWx3 uint8 RGB input frame (assumed H = W = IMAGE_SIZE,
            matching the dataset's image_size -- TODO confirm).
        gt_pixels: (N_WINDOW, 2) ground-truth pixel trajectory in image coords.
        stats: normalization stats dict from the checkpoint.
        coord_scale: PARA_OUT_SIZE / IMAGE_SIZE, used to map predictions back
            to image coordinates.
        start_idx, stride: frame index and temporal stride, used only in labels.

    Returns:
        List of N_WINDOW uint8 RGB images: heatmap overlay plus GT (cyan) and
        predicted (red) markers.
    """
    img_tensor = _imagenet_normalize(frame_rgb)

    with torch.no_grad():
        vol_logits, _grip_logits, _rot_logits, _feats = model(img_tensor)

    # Predicted 2D trajectory is in PARA_OUT_SIZE coords; map to image coords.
    pred_2d, _pred_height = extract_pred_2d_and_height(
        vol_logits, stats["min_height"], stats["max_height"]
    )
    pred_2d_img = pred_2d[0].cpu().numpy() / coord_scale  # (N_WINDOW, 2)

    # Per-timestep spatial heatmap: softmax over the joint height x spatial
    # bins, then collapse the height axis with a max.
    heatmaps = F.softmax(vol_logits[:1].reshape(1, N_WINDOW, -1), dim=2)
    heatmaps = heatmaps.view(1, N_WINDOW, N_HEIGHT_BINS, PARA_OUT_SIZE, PARA_OUT_SIZE)
    heatmap_2d = heatmaps.max(dim=2)[0][0]  # (N_WINDOW, P, P)

    overlays = []
    for t in range(N_WINDOW):
        heat = heatmap_2d[t].cpu().numpy()
        # Min-max normalize (epsilon guards a perfectly flat map), upsample,
        # and colorize. applyColorMap emits BGR; everything else here is RGB.
        heat_norm = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
        heat_up = cv2.resize(heat_norm, (IMAGE_SIZE, IMAGE_SIZE),
                             interpolation=cv2.INTER_LINEAR)
        heat_color = cv2.applyColorMap((heat_up * 255).astype(np.uint8),
                                       cv2.COLORMAP_JET)
        heat_color = cv2.cvtColor(heat_color, cv2.COLOR_BGR2RGB)

        overlay = (frame_rgb * 0.5 + heat_color * 0.5).astype(np.uint8)

        # Marker colors are RGB tuples because these arrays are RGB.
        gt_x, gt_y = int(gt_pixels[t, 0]), int(gt_pixels[t, 1])
        cv2.circle(overlay, (gt_x, gt_y), 8, (0, 255, 255), 2)    # GT: cyan ring
        cv2.circle(overlay, (gt_x, gt_y), 2, (0, 255, 255), -1)   # GT: center dot

        pred_x, pred_y = int(pred_2d_img[t, 0]), int(pred_2d_img[t, 1])
        cv2.circle(overlay, (pred_x, pred_y), 8, (255, 0, 0), 2)   # pred: red ring
        cv2.circle(overlay, (pred_x, pred_y), 2, (255, 0, 0), -1)  # pred: center dot

        cv2.putText(overlay, f"t+{t*stride} GT=cyan PRED=red", (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        cv2.putText(overlay, f"frame {start_idx}", (10, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1)

        overlays.append(overlay)
    return overlays


def main():
    """Render model-vs-GT trajectory overlays for demo_0; save a video and stills."""
    # --- Checkpoint ---------------------------------------------------------
    ckpt = torch.load(CKPT, map_location=DEVICE)
    stats = ckpt["stats"]
    _apply_stats(stats)

    # --- Model --------------------------------------------------------------
    print("Loading SVD+PARA model...")
    model = SVDParaPredictor(
        svd_base="/data/cameron/vidgen/svd_motion_lora/Motion-LoRA/checkpoints/stable-video-diffusion-img2vid-xt-1-1",
        svd_unet="/data/cameron/vidgen/svd_motion_lora/Motion-LoRA/output_libero_7f/checkpoint-46000/unet",
        device=DEVICE,
    ).to(DEVICE)
    model.para_heads.load_state_dict(ckpt["para_heads_state_dict"])
    model.eval()
    print(f"  Loaded checkpoint: {CKPT}")

    # --- Dataset (sanity check only; frames are read straight from disk) ----
    dataset = CachedTrajectoryDataset(
        cache_root=CACHE_ROOT, benchmark_name="libero_spatial",
        task_ids=[0], image_size=IMAGE_SIZE, n_window=N_WINDOW, frame_stride=3,
    )
    print(f"  Dataset: {len(dataset)} samples")

    coord_scale = PARA_OUT_SIZE / IMAGE_SIZE

    # Step through demo_0 directly from the cache directory.
    demo_dir = Path(CACHE_ROOT) / "libero_spatial" / "task_0" / "demo_0"
    frame_files = sorted((demo_dir / "frames").glob("*.png"))
    pix_uv = np.load(demo_dir / "pix_uv.npy")  # per-frame GT pixel (u, v)
    print(f"  Demo 0: {len(frame_files)} frames")

    all_vis_frames = []

    stride = 3  # temporal stride, matching training
    for start_idx in range(0, len(frame_files) - N_WINDOW * stride, stride * 2):
        # .convert("RGB") guards against palette/RGBA PNGs breaking the
        # 3-channel normalization.
        frame_rgb = np.array(Image.open(frame_files[start_idx]).convert("RGB"))

        # GT pixel trajectory for this window, clamped at the demo's end.
        gt_pixels = np.array([
            pix_uv[min(start_idx + t * stride, len(pix_uv) - 1)]
            for t in range(N_WINDOW)
        ])  # (N_WINDOW, 2)

        all_vis_frames.extend(
            _render_window(model, frame_rgb, gt_pixels, stats,
                           coord_scale, start_idx, stride)
        )

        # Black separator frame between prediction windows.
        sep = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
        cv2.putText(sep, f"--- Frame {start_idx} ---", (IMAGE_SIZE//4, IMAGE_SIZE//2),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
        all_vis_frames.append(sep)

    # Fail with a clear message instead of letting imageio choke on an empty
    # frame list (happens when the demo is shorter than one window).
    if not all_vis_frames:
        raise SystemExit(
            f"No frames rendered: demo has {len(frame_files)} frames, "
            f"need more than {N_WINDOW * stride}."
        )

    # Save video
    vid_path = f"{OUT_DIR}/rollout_heatmaps.mp4"
    imageio.mimwrite(vid_path, all_vis_frames, fps=4, quality=8)
    print(f"\nSaved rollout video: {vid_path} ({len(all_vis_frames)} frames)")

    # Also save a few individual still frames, roughly evenly spaced.
    for i, idx in enumerate([0, len(all_vis_frames)//3, 2*len(all_vis_frames)//3]):
        if idx < len(all_vis_frames):
            Image.fromarray(all_vis_frames[idx]).save(f"{OUT_DIR}/frame_{i}.png")

    print(f"All outputs in {OUT_DIR}/")


if __name__ == "__main__":
    main()
