"""Visualize real dense trajectory dataset ground truth.

For one episode:
  1. Plot the 2D trajectory on the image (first window).
  2. After that window is closed, open a 3D matplotlib plot of the pixel-aligned volume:
     - For each pixel (subsampled), unproject into 3D using N height buckets (pixel + height + camera).
     - All non-target 3D points: semi-transparent blue.
     - Target 3D trajectory points: opaque red.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from pathlib import Path
import argparse

from data import N_WINDOW

# Virtual gripper keypoint in local gripper frame
# Five candidate keypoints on the gripper, expressed in millimeters in the
# gripper's local frame; dividing by 1000 converts them to meters.
KEYPOINTS_LOCAL_M_ALL = np.array([[13.25, -91.42, 15.9], [10.77, -99.6, 0], [13.25, -91.42, -15.9],
                                   [17.96, -83.96, 0], [22.86, -70.46, 0]]) / 1000.0
# Index of the single keypoint used for all projections in this script.
KP_INDEX = 3
# Selected keypoint in the gripper's local frame (meters).
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

# Number of height buckets for volume visualization (matches model volume idea)
N_HEIGHT_BINS = 32


def project_3d_to_2d(point_3d, camera_pose, cam_K):
    """Project a world-space 3D point through the camera into pixel coordinates.

    Returns the (u, v) pixel location as an array, or None when the point has
    non-positive depth (at or behind the camera plane).
    """
    homogeneous = np.concatenate([point_3d, [1.0]])
    in_camera = camera_pose @ homogeneous
    # Reject points at or behind the camera; they have no valid projection.
    if in_camera[2] <= 0:
        return None
    uvw = cam_K @ in_camera[:3]
    return uvw[:2] / uvw[2]


def recover_3d_from_direct_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Recover a world-space 3D point from a pixel and a height value.

    Casts the camera ray through the pixel and intersects it with the
    horizontal plane z = height. Returns None when the ray is (nearly)
    parallel to the plane or the intersection lies behind the camera.
    """
    cam_to_world = np.linalg.inv(camera_pose)
    origin = cam_to_world[:3, 3]
    # Back-project the pixel to a unit direction in camera coordinates.
    direction_cam = np.array([
        (kp_2d_image[0] - cam_K[0, 2]) / cam_K[0, 0],
        (kp_2d_image[1] - cam_K[1, 2]) / cam_K[1, 1],
        1.0,
    ])
    direction_cam = direction_cam / np.linalg.norm(direction_cam)
    direction = cam_to_world[:3, :3] @ direction_cam
    # Ray nearly parallel to the z-plane: no stable intersection.
    if abs(direction[2]) < 1e-6:
        return None
    t = (height - origin[2]) / direction[2]
    return origin + t * direction if t >= 0 else None


def load_episode_trajectory(episode_dir, n_window):
    """Load first-frame image, trajectory_2d, trajectory_3d, camera_pose, cam_K for one episode.

    Walks the numbered PNG frames in order, projecting the gripper keypoint
    into each frame's image until a frame is missing data or the keypoint
    falls behind the camera. Both trajectories are padded/truncated to exactly
    ``n_window`` waypoints.

    Returns a dict with the episode data, or None when the episode has no
    frames or no frame could be projected.
    """
    frame_files = sorted([f for f in episode_dir.glob("*.png") if f.stem.isdigit()])
    if not frame_files:
        return None

    start_img = plt.imread(frame_files[0])
    H, W = start_img.shape[:2]

    trajectory_2d = []
    trajectory_3d = []

    num_frames = min(len(frame_files), n_window)
    for frame_idx in range(num_frames):
        frame_str = frame_files[frame_idx].stem
        gripper_pose_path = episode_dir / f"{frame_str}_gripper_pose.npy"
        cam_K_norm_path = episode_dir / f"{frame_str}_cam_K.npy"
        camera_pose_path = episode_dir / f"{frame_str}_camera_pose.npy"
        # Require all per-frame files up front. The previous version appended
        # the 3D point before these checks, which could leave trajectory_3d
        # one element longer than trajectory_2d on an early break and produce
        # mismatched lengths after padding.
        if not (gripper_pose_path.exists() and cam_K_norm_path.exists()
                and camera_pose_path.exists()):
            break

        gripper_pose = np.load(gripper_pose_path)
        # Keypoint in world frame: rotate the local offset, then translate.
        kp_3d = gripper_pose[:3, :3] @ kp_local + gripper_pose[:3, 3]

        cam_K_norm = np.load(cam_K_norm_path)
        camera_pose = np.load(camera_pose_path)
        # Intrinsics are stored normalized; scale to pixel units.
        cam_K = cam_K_norm.copy()
        cam_K[0] *= W
        cam_K[1] *= H

        kp_2d = project_3d_to_2d(kp_3d, camera_pose, cam_K)
        if kp_2d is None:
            break
        # Append both points together so the trajectories stay in lockstep.
        trajectory_3d.append(kp_3d)
        trajectory_2d.append(kp_2d)

    if not trajectory_2d:
        return None

    trajectory_2d = np.array(trajectory_2d)
    trajectory_3d = np.array(trajectory_3d)

    # Use first frame camera for volume (same as 2D overlay)
    frame_str = frame_files[0].stem
    cam_K_norm = np.load(episode_dir / f"{frame_str}_cam_K.npy")
    camera_pose = np.load(episode_dir / f"{frame_str}_camera_pose.npy")
    cam_K = cam_K_norm.copy()
    cam_K[0] *= W
    cam_K[1] *= H

    # Pad with the last waypoint (or truncate) to exactly n_window entries.
    if len(trajectory_2d) < n_window:
        n_pad = n_window - len(trajectory_2d)
        trajectory_2d = np.concatenate([trajectory_2d, np.tile(trajectory_2d[-1:], (n_pad, 1))], axis=0)
        trajectory_3d = np.concatenate([trajectory_3d, np.tile(trajectory_3d[-1:], (n_pad, 1))], axis=0)
    else:
        trajectory_2d = trajectory_2d[:n_window]
        trajectory_3d = trajectory_3d[:n_window]

    return {
        'episode_id': episode_dir.name,
        'start_img': start_img,
        'H': H,
        'W': W,
        'trajectory_2d': trajectory_2d,
        'trajectory_3d': trajectory_3d,
        'camera_pose': camera_pose,
        'cam_K': cam_K,
    }


def build_volume_3d_points(H, W, camera_pose, cam_K, height_bucket_centers, pixel_step=16):
    """Unproject every (subsampled pixel, height bucket) pair into world space.

    Pixels are sampled on a ``pixel_step`` grid; each is lifted to 3D at every
    height bucket via ray-plane intersection. Returns an (N, 3) array of the
    valid points, or an empty (0, 3) array when none unproject.
    """
    points = [
        point
        for y in range(0, H, pixel_step)
        for x in range(0, W, pixel_step)
        for h in height_bucket_centers
        if (point := recover_3d_from_direct_keypoint_and_height(
            np.array([x, y], dtype=np.float64), float(h), camera_pose, cam_K
        )) is not None
    ]
    return np.array(points) if points else np.zeros((0, 3))


def main():
    """Visualize GT for episodes: a 2D trajectory overlay, then the 3D volume.

    Iterates episodes starting at --episode_index. For each episode the 2D
    window blocks until closed, then the 3D volume window opens.
    """
    parser = argparse.ArgumentParser(description='Visualize GT: 2D trajectory then 3D volume for one episode')
    parser.add_argument('--dataset_dir', '-d', default='scratch/parsed_school_cap', type=str,
                        help='Root dataset directory')
    parser.add_argument('--episode_index', type=int, default=0,
                        help='Episode index to start from (0-based)')
    parser.add_argument('--pixel_step', type=int, default=16,
                        help='Subsample pixels for volume (larger = fewer points)')
    parser.add_argument('--height_margin', type=float, default=0.01,
                        help='Margin (m) around trajectory z range for volume height buckets')
    args = parser.parse_args()

    dataset_root = Path(args.dataset_dir)
    if not dataset_root.exists():
        print(f'Dataset not found: {dataset_root}')
        return

    episode_dirs = sorted([d for d in dataset_root.iterdir() if d.is_dir() and 'episode' in d.name])
    if len(episode_dirs) == 0:
        print(f'No episodes in {dataset_root}')
        return

    # Honor --episode_index (it was previously parsed but never used) by
    # starting the sweep there, clamped to the available episodes.
    start_idx = min(args.episode_index, len(episode_dirs) - 1)
    for idx in range(start_idx, len(episode_dirs)):
        episode_dir = episode_dirs[idx]
        data = load_episode_trajectory(episode_dir, N_WINDOW)
        if data is None:
            # Skip unreadable episodes instead of aborting the whole run.
            print(f'Could not load episode {episode_dir.name}')
            continue

        H, W = data['H'], data['W']
        trajectory_2d = data['trajectory_2d']
        trajectory_3d = data['trajectory_3d']
        start_img = data['start_img']

        # ----- 1) First window: 2D trajectory on image -----
        fig1, ax = plt.subplots(1, 1, figsize=(10, 8))
        ax.imshow(start_img)
        # Clip waypoints into the image so out-of-frame points stay visible.
        trajectory_2d_clipped = np.array([
            [np.clip(trajectory_2d[t, 0], 0, W - 1), np.clip(trajectory_2d[t, 1], 0, H - 1)]
            for t in range(N_WINDOW)
        ])
        colors = plt.cm.viridis(np.linspace(0, 1, N_WINDOW))
        for t in range(N_WINDOW - 1):
            ax.plot(
                [trajectory_2d_clipped[t, 0], trajectory_2d_clipped[t + 1, 0]],
                [trajectory_2d_clipped[t, 1], trajectory_2d_clipped[t + 1, 1]],
                '-', color=colors[t], linewidth=2, alpha=0.7
            )
        for t in range(N_WINDOW):
            ax.plot(
                trajectory_2d_clipped[t, 0], trajectory_2d_clipped[t, 1],
                'o', color=colors[t], markersize=6, markeredgecolor='white', markeredgewidth=1.5, alpha=0.9
            )
        ax.set_title(f"{data['episode_id']} — GT 2D trajectory ({N_WINDOW} waypoints)")
        ax.axis('off')
        plt.tight_layout()
        plt.show()  # Block until user closes this window

        # ----- 2) Second window: 3D volume (pixel×height unproject) + target in red -----
        # (A leftover debug `continue` previously made this whole section dead code.)
        z_min = trajectory_3d[:, 2].min() - args.height_margin
        z_max = trajectory_3d[:, 2].max() + args.height_margin
        height_bucket_centers = np.linspace(z_min, z_max, N_HEIGHT_BINS)

        print('Building volume 3D points (pixel × height unproject)...')
        volume_pts = build_volume_3d_points(
            H, W, data['camera_pose'], data['cam_K'],
            height_bucket_centers,
            pixel_step=args.pixel_step
        )
        print(f'  Volume points: {len(volume_pts)}')

        fig2 = plt.figure(figsize=(10, 8))
        ax3d = fig2.add_subplot(111, projection='3d')

        # Non-target volume points: semi-transparent blue
        if len(volume_pts) > 0:
            ax3d.scatter(
                volume_pts[:, 0], volume_pts[:, 1], volume_pts[:, 2],
                c='blue', alpha=0.4, s=1, edgecolors='none'
            )

        # Target 3D trajectory: opaque red
        ax3d.scatter(
            trajectory_3d[:, 0], trajectory_3d[:, 1], trajectory_3d[:, 2],
            c='red', alpha=1.0, s=80, edgecolors='darkred', linewidths=1.5, label='GT trajectory'
        )

        ax3d.set_xlabel('X (m)')
        ax3d.set_ylabel('Y (m)')
        ax3d.set_zlabel('Z (m)')
        # Title now matches the actual point color (blue, not white).
        ax3d.set_title(f"{data['episode_id']} — Volume (blue = pixel×height unproject, red = target 3D)")
        ax3d.legend()
        plt.tight_layout()
        plt.show()


# Run the visualization only when executed as a script (not on import).
if __name__ == '__main__':
    main()