"""Visualize filtered 2D gripper trajectories showing key extrema points."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))

import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import argparse

from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX
from utils import project_3d_to_2d, rescale_coords

# ---------------------------------------------------------------------------
# Command-line arguments
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Visualize filtered 2D gripper trajectories (close, open, end)')
parser.add_argument('--dataset_dir', '-d', default="scratch/parsed_rgb_joints_capture_desktop_train", type=str, help="Dataset directory")
parser.add_argument('--num_episodes', '-n', default=5, type=int, help="Number of episodes to visualize")
parser.add_argument('--close_threshold', type=float, default=0.1, help="Gripper close threshold")
parser.add_argument('--open_threshold', type=float, default=0.4, help="Gripper open threshold")
args = parser.parse_args()

# plt.subplots() rejects a non-positive row count with an opaque error, so
# validate up front and report it as a normal CLI usage error.
if args.num_episodes < 1:
    parser.error(f"--num_episodes must be >= 1 (got {args.num_episodes})")

# ---------------------------------------------------------------------------
# Episode discovery
# ---------------------------------------------------------------------------
dataset_dir = Path(args.dataset_dir)
if not dataset_dir.is_dir():
    # Path.iterdir() would otherwise raise a bare FileNotFoundError/NotADirectoryError.
    print(f"Dataset directory not found: {dataset_dir}")
    sys.exit(1)

# Episodes are sub-directories named "episode_*", visited in lexicographic order.
episode_dirs = sorted(d for d in dataset_dir.iterdir() if d.is_dir() and d.name.startswith("episode_"))

if not episode_dirs:
    print(f"No episodes found in {dataset_dir}")
    sys.exit(1)

# Visualize at most --num_episodes episodes, taken from the front of the sorted list.
num_episodes = min(args.num_episodes, len(episode_dirs))
selected_episodes = episode_dirs[:num_episodes]

# Local-frame keypoint that is transformed by each gripper pose below.
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

# ---------------------------------------------------------------------------
# Figure: one row per episode, two panels per row.
# squeeze=False guarantees `axes` is always 2-D (num_episodes, 2), so rows can
# be indexed uniformly even when only one episode is shown.
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(num_episodes, 2, figsize=(20, 6 * num_episodes), squeeze=False)

fig.suptitle(f"Filtered 2D Gripper Trajectories (Close<{args.close_threshold}, Open>{args.open_threshold}, End)", fontsize=16, fontweight='bold')

# Resolution at which the stored camera intrinsics are assumed to be expressed
# (the original capture resolution, before any downsampling).
_CALIB_H = 1080
_CALIB_W = 1920


def _skip_episode(row_axes, message):
    """Report a skipped episode and blank out its row of subplots.

    Without this, a skipped row renders as two empty axes with default ticks,
    which looks like a plotting failure rather than a deliberate skip.
    """
    print(f"⚠ {message}")
    for ax in row_axes:
        ax.set_title(message, fontsize=11)
        ax.axis('off')


def _project_clipped(kp_3d, camera_pose, cam_K, width, height):
    """Project a 3D point to pixel coordinates, clamped to the image bounds.

    Returns None when ``project_3d_to_2d`` returns None (projection failed).
    Clamping keeps off-image points drawn on the image border.
    """
    kp_2d = project_3d_to_2d(kp_3d, camera_pose, cam_K)
    if kp_2d is None:
        return None
    return np.array([
        np.clip(kp_2d[0], 0, width - 1),
        np.clip(kp_2d[1], 0, height - 1),
    ])


def _load_episode_trajectory(episode_dir, frame_files, kp_local):
    """Load the transformed keypoint and gripper value for every frame.

    Frames without a ``<frame>_gripper_pose.npy`` file are skipped entirely, so
    the two returned arrays stay index-aligned. A missing ``<frame>.npy`` joint
    state falls back to a gripper value of 1.0 (treated as fully open).

    Returns:
        (trajectory_3d, gripper_values): array of ``rot @ kp_local + pos``
        results (one per kept frame) and a float array of the same length.
    """
    trajectory_3d = []
    gripper_values = []
    for frame_file in frame_files:
        frame_str = f"{int(frame_file.stem):06d}"
        pose_path = episode_dir / f"{frame_str}_gripper_pose.npy"
        if not pose_path.exists():
            continue

        pose = np.load(pose_path)
        # pose[:3, :3] is used as the rotation and pose[:3, 3] as the translation.
        trajectory_3d.append(pose[:3, :3] @ kp_local + pose[:3, 3])

        joint_state_path = episode_dir / f"{frame_str}.npy"
        if joint_state_path.exists():
            # The last joint-state entry is the gripper open/close value.
            gripper_values.append(float(np.load(joint_state_path)[-1]))
        else:
            gripper_values.append(1.0)
    return np.array(trajectory_3d), np.array(gripper_values)


def _find_extrema(gripper_values, close_threshold, open_threshold):
    """Select the key frames of a trajectory.

    Returns ``(close_idx, open_idx, extrema_indices)`` where
      * ``close_idx`` is the first frame with gripper < close_threshold (or None),
      * ``open_idx`` is the first frame after ``close_idx`` with
        gripper > open_threshold (or None),
      * ``extrema_indices`` is the sorted, de-duplicated set of the close/open
        indices plus the last frame (empty only for an empty input).
    """
    num = len(gripper_values)
    if num == 0:
        return None, None, []
    close_idx = next((i for i, gv in enumerate(gripper_values) if gv < close_threshold), None)
    open_idx = None
    if close_idx is not None:
        open_idx = next(
            (i for i in range(close_idx + 1, num) if gripper_values[i] > open_threshold),
            None,
        )
    extrema = sorted({i for i in (close_idx, open_idx, num - 1) if i is not None})
    return close_idx, open_idx, extrema


def _point_style(orig_idx, close_idx, open_idx, last_idx):
    """Return ``(color, marker, name, is_key)`` used to draw an extremum.

    Checks run close -> open -> end, so a frame that is both the close point
    and the last frame is styled as 'Close'.
    """
    if orig_idx == close_idx:
        return 'red', 's', 'Close', True
    if orig_idx == open_idx:
        return 'blue', 's', 'Open', True
    if orig_idx == last_idx:
        return 'magenta', 's', 'End', True
    return 'yellow', 'o', f'Pt{orig_idx}', False


def _legend_if_labeled(ax):
    """Draw a legend only when the axes has labelled artists.

    Calling ``ax.legend()`` with nothing labelled makes matplotlib emit a
    "No artists with labels found" warning.
    """
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='upper right', fontsize=8)


for row, episode_dir in enumerate(tqdm(selected_episodes, desc="Processing episodes")):
    episode_id = episode_dir.name
    row_axes = axes[row]

    # Frame images are "<digits>.png"; sort numerically so temporal order holds
    # even if file names are not zero-padded.
    frame_files = sorted(
        (f for f in episode_dir.glob("*.png") if f.stem.isdigit()),
        key=lambda f: int(f.stem),
    )
    if not frame_files:
        _skip_episode(row_axes, f"Episode {episode_id} has no frames, skipping")
        continue

    # The first frame provides the background image and the camera calibration.
    start_frame_file = frame_files[0]
    start_frame_str = f"{int(start_frame_file.stem):06d}"

    camera_pose_path = episode_dir / f"{start_frame_str}_camera_pose.npy"
    cam_K_path = episode_dir / f"{start_frame_str}_cam_K.npy"
    if not camera_pose_path.exists() or not cam_K_path.exists():
        _skip_episode(row_axes, f"Episode {episode_id} frame {start_frame_str} missing camera pose or intrinsics")
        continue
    camera_pose = np.load(camera_pose_path)
    cam_K = np.load(cam_K_path)

    # cv2.imread returns None (instead of raising) on an unreadable file.
    bgr = cv2.imread(str(start_frame_file))
    if bgr is None:
        _skip_episode(row_axes, f"Episode {episode_id} could not read image {start_frame_file.name}")
        continue
    rgb_np = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Only a float image can be normalized to [0, 1]; rescaling a uint8 frame
    # whose max happens to be 0 or 1 would binarize it.
    if np.issubdtype(rgb_np.dtype, np.floating) and rgb_np.max() <= 1.0:
        rgb_np = (rgb_np * 255).astype(np.uint8)

    # Rescale the intrinsics from the calibration resolution to the resolution
    # actually loaded. astype() copies and avoids silent truncation if cam_K
    # was stored with an integer dtype.
    H_img, W_img = rgb_np.shape[:2]
    if (H_img, W_img) != (_CALIB_H, _CALIB_W):
        scale_x = W_img / _CALIB_W
        scale_y = H_img / _CALIB_H
        cam_K = cam_K.astype(np.float64)
        cam_K[0, 0] *= scale_x  # fx
        cam_K[1, 1] *= scale_y  # fy
        cam_K[0, 2] *= scale_x  # cx
        cam_K[1, 2] *= scale_y  # cy

    trajectory_gt_3d, gripper_values = _load_episode_trajectory(episode_dir, frame_files, kp_local)
    if len(trajectory_gt_3d) == 0:
        _skip_episode(row_axes, f"Episode {episode_id} no valid trajectory")
        continue

    last_idx = len(trajectory_gt_3d) - 1
    close_idx, open_idx, extrema_indices = _find_extrema(
        gripper_values, args.close_threshold, args.open_threshold
    )

    # Project the extrema. `kept_indices` stays parallel to `trajectory_2d` so
    # every drawn point is labelled with its own frame index and gripper value,
    # even when some leading projections fail and are dropped.
    kept_indices = []
    trajectory_2d = []
    for idx in extrema_indices:
        point = _project_clipped(trajectory_gt_3d[idx], camera_pose, cam_K, W_img, H_img)
        if point is None:
            if not trajectory_2d:
                continue  # no earlier valid point to fall back on: drop this extremum
            point = trajectory_2d[-1]  # reuse the previous valid projection
        kept_indices.append(idx)
        trajectory_2d.append(point)

    if not trajectory_2d:
        _skip_episode(row_axes, f"Episode {episode_id} no valid 2D projections for extrema")
        continue
    trajectory_2d = np.array(trajectory_2d)
    gripper_extrema = gripper_values[kept_indices]
    styles = [_point_style(i, close_idx, open_idx, last_idx) for i in kept_indices]

    ax1, ax2 = row_axes

    # --- Panel 1: RGB image with only the filtered extrema trajectory ---
    ax1.imshow(rgb_np)
    if len(trajectory_2d) > 1:
        ax1.plot(trajectory_2d[:, 0], trajectory_2d[:, 1],
                 'g-', linewidth=3, alpha=0.8, label='Filtered Trajectory', zorder=10)

    point_labels = []
    for (x, y), gv, (color, marker, name, is_key) in zip(trajectory_2d, gripper_extrema, styles):
        ax1.plot(x, y, marker, color=color, markersize=14 if is_key else 10,
                 markeredgecolor='white', markeredgewidth=2, zorder=11, alpha=0.9)
        ax1.text(x, y - 20, name, color='white', fontsize=9,
                 ha='center', va='top', weight='bold',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor=color, alpha=0.7),
                 zorder=13)
        point_labels.append(f"{name} (g={gv:.2f})")

    ax1.set_title(f"Episode {episode_id} - Filtered Extrema\n({len(trajectory_2d)} points: {', '.join(point_labels)})", fontsize=11)
    _legend_if_labeled(ax1)
    ax1.axis('off')

    # --- Panel 2: full trajectory (faint) with the filtered extrema on top ---
    ax2.imshow(rgb_np)
    full_points = [
        p for p in (_project_clipped(kp, camera_pose, cam_K, W_img, H_img) for kp in trajectory_gt_3d)
        if p is not None
    ]
    if len(full_points) > 1:
        trajectory_full_2d = np.array(full_points)
        ax2.plot(trajectory_full_2d[:, 0], trajectory_full_2d[:, 1],
                 color='gray', linewidth=1, alpha=0.3, label='Full Trajectory', zorder=5)

    if len(trajectory_2d) > 1:
        ax2.plot(trajectory_2d[:, 0], trajectory_2d[:, 1],
                 'g-', linewidth=3, alpha=0.8, label='Filtered Trajectory', zorder=10)

    for (x, y), gv, (color, marker, name, _is_key) in zip(trajectory_2d, gripper_extrema, styles):
        ax2.plot(x, y, marker, color=color, markersize=12,
                 markeredgecolor='white', markeredgewidth=2, zorder=11, alpha=0.9)
        ax2.text(x, y - 25, f"{name}\n(g={gv:.2f})", color='white', fontsize=8,
                 ha='center', va='top', weight='bold',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor=color, alpha=0.7),
                 zorder=13)

    ax2.set_title(f"Episode {episode_id} - Full vs Filtered\n({len(trajectory_gt_3d)} total, {len(trajectory_2d)} extrema)", fontsize=11)
    _legend_if_labeled(ax2)
    ax2.axis('off')

# Lay out all episode rows, write the figure to disk, then display it.
plt.tight_layout()
# Plain string: the original f-prefix had no placeholders (ruff F541).
output_path = Path('keypoint_testing2/vis_extrema_trajectories.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"✓ Saved {output_path}")
plt.show()
