"""Render groundplane (XZ plane) trajectories for episodes in a dataset."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import argparse

from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX

# ---------------------------------------------------------------------------
# Command-line handling and episode selection.
#
# Produces the script-level names used further down:
#   dataset_dir        - Path to the dataset root given on the command line
#   episode_dirs       - every sub-directory whose name starts with "episode_"
#   selected_episodes  - the first `num_episodes` entries of episode_dirs
#   num_episodes       - how many episodes were selected
#   kp_local           - keypoint entry picked from the project keypoint table
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Render groundplane trajectories')
parser.add_argument('--dataset_dir', '-d', default="scratch/parsed_rgb_joints_capture_desktop_train", type=str, help="Dataset directory")
parser.add_argument('--num_episodes', '-n', default=None, type=int, help="Number of episodes to visualize (None = all)")
args = parser.parse_args()

# A negative count would silently slice episodes off the *end* of the list
# (episode_dirs[:-k]) and later make np.linspace raise when the color table
# is built, so reject it up front with a normal usage error.
if args.num_episodes is not None and args.num_episodes < 0:
    parser.error("--num_episodes must be non-negative")

dataset_dir = Path(args.dataset_dir)

# Path.iterdir() raises on a missing directory; report that the same way as
# an empty dataset instead of surfacing a traceback.
if not dataset_dir.is_dir():
    print(f"Dataset directory not found: {dataset_dir}")
    sys.exit(1)

# Sorted so that "first N episodes" is deterministic across runs.
episode_dirs = sorted(
    d for d in dataset_dir.iterdir()
    if d.is_dir() and d.name.startswith("episode_")
)

if not episode_dirs:
    print(f"No episodes found in {dataset_dir}")
    # sys.exit rather than the bare exit() helper, which is only injected by
    # the `site` module and is not guaranteed to exist.
    sys.exit(1)

# Select episodes to visualize (capped at the number actually available)
if args.num_episodes is not None:
    num_episodes = min(args.num_episodes, len(episode_dirs))
    selected_episodes = episode_dirs[:num_episodes]
else:
    selected_episodes = episode_dirs
    num_episodes = len(selected_episodes)

# Keypoint that each gripper pose is applied to further down.
# NOTE(review): presumably a 3-vector in the gripper's local frame - confirm
# against the `data` module.
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

# Create a single square figure; equal aspect keeps distances undistorted
# between the two plotted axes.
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.set_aspect('equal')

# One RGBA row per selected episode, sampled evenly across the tab20 colormap
colors = plt.cm.tab20(np.linspace(0, 1, num_episodes))

for idx, episode_dir in enumerate(tqdm(selected_episodes, desc="Processing episodes")):
    episode_id = episode_dir.name

    # Frames are the PNG files whose stem is purely numeric. Order them by
    # numeric value (not by path string) so the trajectory is drawn in frame
    # order even when stems are not zero-padded; the file name breaks ties to
    # keep the order deterministic.
    frame_files = sorted(
        (f for f in episode_dir.glob("*.png") if f.stem.isdigit()),
        key=lambda f: (int(f.stem), f.name),
    )
    if not frame_files:
        print(f"⚠ Episode {episode_id} has no frames, skipping")
        continue

    # Build the ground-truth trajectory: for every frame that has a saved
    # gripper pose, transform the local keypoint by that pose.
    trajectory_gt_3d = []
    for frame_file in frame_files:
        frame_idx = int(frame_file.stem)
        pose_path = episode_dir / f"{frame_idx:06d}_gripper_pose.npy"
        if not pose_path.exists():
            # Frames without a pose file are skipped rather than treated as errors.
            continue
        pose = np.load(pose_path)
        rot = pose[:3, :3]  # upper-left 3x3 block of the pose matrix
        pos = pose[:3, 3]   # first three entries of the fourth column
        trajectory_gt_3d.append(rot @ kp_local + pos)

    if not trajectory_gt_3d:
        print(f"⚠ Episode {episode_id} no valid trajectory")
        continue

    trajectory_gt_3d = np.array(trajectory_gt_3d)

    # Groundplane projection: only columns 0 and 1 are plotted, so the
    # remaining column (index 2) is simply ignored.
    # NOTE(review): the plot labels these two columns X and Z, which implies
    # the stored component order is not plain [x, y, z] - confirm against the
    # pose / keypoint convention used by the dataset.
    x_coords = trajectory_gt_3d[:, 0]
    z_coords = trajectory_gt_3d[:, 1]

    # Plot the path, with a circle at the first sample and a square at the last
    color = colors[idx]
    ax.plot(x_coords, z_coords, '-', linewidth=2, alpha=0.7, color=color, label=episode_id)
    ax.scatter(x_coords[0], z_coords[0], s=100, color=color, marker='o', edgecolors='black', linewidths=2, zorder=5, label=f'{episode_id} start')
    ax.scatter(x_coords[-1], z_coords[-1], s=100, color=color, marker='s', edgecolors='black', linewidths=2, zorder=5, label=f'{episode_id} end')

# Axis decoration for the finished plot
ax.set_xlabel('X (meters)', fontsize=12)
ax.set_ylabel('Z (meters)', fontsize=12)
ax.set_title('Groundplane (XZ) Trajectories', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)

# Only build a legend when at least one labeled artist was drawn; calling
# ax.legend() on an empty axes (every episode skipped) just emits a warning.
legend_handles, _ = ax.get_legend_handles_labels()
if legend_handles:
    # Anchored outside the axes on the right so it does not cover the data
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8, ncol=2)

plt.tight_layout()

# Output goes to a fixed path relative to the current working directory;
# the parent directory is created on demand.
output_path = Path('groundplane_testing/render_groundplane.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
# bbox_inches='tight' makes the saved image include the outside legend
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"✓ Saved {output_path}")
plt.show()
