"""Sanity check: load one sample from keygrip RealTrajectoryDataset, load actual 4 consecutive
video frames from that episode, draw trajectory keypoints on each frame, save to out/vis_traj_sanity.mp4.
  python scripts/test_dataloader.py
"""
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

KEYGRIP_VOLUME = REPO_ROOT.resolve().parents[1] / "keygrip" / "volume_dino_tracks"
sys.path.insert(0, str(KEYGRIP_VOLUME))

import numpy as np
import torch
import cv2

from data import RealTrajectoryDataset

N_FRAMES = 4
# keygrip/scratch/parsed_pickplace_exp1_feb9 under data/cameron (parents[2]=/data)
DEFAULT_DATASET_ROOT = REPO_ROOT.resolve().parents[2] / "cameron" / "keygrip" / "scratch" / "parsed_pickplace_exp1_feb9"
IMAGE_SIZE = 256


def load_frame(episode_dir: Path, frame_idx: int, image_size: int) -> np.ndarray:
    """Load a single episode frame as an (H, W, 3) uint8 RGB array.

    The frame file is ``<frame_idx zero-padded to 6 digits>.png`` inside
    *episode_dir*.  The image is resized to ``image_size x image_size`` with
    bilinear interpolation so it matches the dataset resolution.

    Raises:
        FileNotFoundError: if the PNG does not exist.
        RuntimeError: if cv2 fails to decode the file.
    """
    path = episode_dir / f"{frame_idx:06d}.png"
    if not path.exists():
        raise FileNotFoundError(path)
    bgr = cv2.imread(str(path))
    if bgr is None:
        raise RuntimeError(f"Failed to read {path}")
    frame = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if frame.shape[:2] != (image_size, image_size):
        frame = cv2.resize(frame, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    return frame


def main():
    """Load one dataset sample, overlay its 2D trajectory keypoints on the
    matching consecutive video frames, and write the visualization to
    ``out/vis_traj_sanity.mp4``.

    Raises:
        ValueError: if the sample's trajectory has fewer than N_FRAMES points.
    """
    dataset_root = Path(DEFAULT_DATASET_ROOT).resolve()
    print("Creating RealTrajectoryDataset...")
    dataset = RealTrajectoryDataset(dataset_root=str(dataset_root), image_size=IMAGE_SIZE)
    print(f"  {len(dataset)} samples")

    sample = dataset[0]
    episode_dir = Path(sample["episode_dir"])
    frame_idx = sample["frame_idx"]
    # (N_WINDOW, 2) — indexed with .item() below, so presumably a torch tensor.
    trajectory_2d = sample["trajectory_2d"]
    # Guard: fail with a clear message instead of an IndexError inside the
    # drawing loop when the window is shorter than N_FRAMES.
    if trajectory_2d.shape[0] < N_FRAMES:
        raise ValueError(
            f"Sample trajectory has {trajectory_2d.shape[0]} points; "
            f"need at least {N_FRAMES}"
        )
    traj_2d = trajectory_2d[:N_FRAMES]  # (N_FRAMES, 2)

    # Load the actual N_FRAMES consecutive frames at the dataset resolution.
    gt_np = np.stack(
        [load_frame(episode_dir, frame_idx + i, IMAGE_SIZE) for i in range(N_FRAMES)],
        axis=0,
    )  # (T, H, W, C) uint8
    _, H, W, _ = gt_np.shape

    for t in range(N_FRAMES):
        # Clamp the keypoint into the frame, then draw a filled disc of
        # radius 6.  cv2.circle replaces the hand-rolled O(r^2) pixel loop;
        # (0, 255, 255) fills the same per-channel values as before (array is
        # RGB, so this is yellow, not cyan).
        x = max(0, min(W - 1, int(round(traj_2d[t, 0].item()))))
        y = max(0, min(H - 1, int(round(traj_2d[t, 1].item()))))
        cv2.circle(gt_np[t], (x, y), 6, (0, 255, 255), thickness=-1)
    print(f"Loaded frames {frame_idx}..{frame_idx + N_FRAMES - 1} from {episode_dir.name}; drew keypoints on each frame.")

    out_dir = REPO_ROOT / "out"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "vis_traj_sanity.mp4"

    # Lazy import: torchvision is heavy and only needed for the video write.
    import torchvision
    torchvision.io.write_video(str(out_path), torch.from_numpy(gt_np), fps=4)
    print(f"Saved {out_path}")


if __name__ == "__main__":
    main()
