"""Render groundplane coordinate grid overlaid on RGB images."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))

import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import argparse

from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX
from utils import project_3d_to_2d

# Parse command-line arguments.
parser = argparse.ArgumentParser(description='Render groundplane coordinate grids on images')
parser.add_argument('--dataset_dir', '-d', default="scratch/parsed_rgb_joints_capture_desktop_train", type=str, help="Dataset directory")
parser.add_argument('--num_episodes', '-n', default=5, type=int, help="Number of episodes to visualize")
parser.add_argument('--grid_spacing', default=0.1, type=float, help="Grid spacing in meters")
parser.add_argument('--grid_range', default=1.0, type=float, help="Grid range in meters (from -range to +range)")
args = parser.parse_args()

# Validate up front instead of failing later with an obscure error:
# a non-positive spacing cannot produce a grid, a negative range is
# meaningless, and fewer than one episode would request a figure with zero
# (or negative) subplot rows. parser.error() prints usage and exits.
if args.grid_spacing <= 0:
    parser.error(f"--grid_spacing must be positive, got {args.grid_spacing}")
if args.grid_range < 0:
    parser.error(f"--grid_range must be non-negative, got {args.grid_range}")
if args.num_episodes < 1:
    parser.error(f"--num_episodes must be at least 1, got {args.num_episodes}")

# Discover episode directories ("episode_*") under the dataset root.
dataset_dir = Path(args.dataset_dir)
if not dataset_dir.is_dir():
    # iterdir() below would otherwise fail with a raw traceback.
    print(f"Dataset directory not found: {dataset_dir}", file=sys.stderr)
    sys.exit(1)

# Sorted by name for a deterministic order. NOTE(review): this is a
# lexicographic sort, so "episode_10" precedes "episode_2" unless the
# episode indices are zero-padded.
episode_dirs = sorted(
    d for d in dataset_dir.iterdir() if d.is_dir() and d.name.startswith("episode_")
)

if not episode_dirs:
    print(f"No episodes found in {dataset_dir}", file=sys.stderr)
    # sys.exit rather than the site-provided exit(), which is unavailable
    # when Python runs with -S or in some embedded interpreters.
    sys.exit(1)

# Visualize the first N episodes in sorted order (fewer if the dataset is smaller).
num_episodes = min(args.num_episodes, len(episode_dirs))
selected_episodes = episode_dirs[:num_episodes]

# Groundplane grid: every point has its third coordinate (index 2) set to 0,
# so each row of grid_points_3d is [x, z, 0]. This script treats index 2 as
# the height (Y) axis. NOTE(review): confirm this axis convention against the
# dataset's world frame.
#
# Samples along each axis are -grid_range + k * grid_spacing for k = 0..n,
# with n chosen so the last sample does not exceed +grid_range (it lands on it
# exactly when the range is a multiple of the spacing). np.arange with a float
# step can include or drop the endpoint depending on rounding, so the integer
# sample count is computed explicitly; the small epsilon absorbs rounding in
# the division (e.g. 2.4 / 0.1 evaluates to slightly less than 24).
num_grid_steps = int(np.floor(2 * args.grid_range / args.grid_spacing + 1e-9))
grid_x = -args.grid_range + args.grid_spacing * np.arange(num_grid_steps + 1)
grid_z = grid_x.copy()
grid_xx, grid_zz = np.meshgrid(grid_x, grid_z)
# meshgrid yields (len(grid_z), len(grid_x)) arrays; ravel is row-major, so
# row i of grid_points_3d corresponds to (z_idx, x_idx) = divmod(i, len(grid_x)).
grid_points_3d = np.stack([grid_xx.ravel(), grid_zz.ravel(), np.zeros(grid_xx.size)], axis=1)

# Figure layout: one row per selected episode, three frame panels per row.
# squeeze=False always returns a 2-D (num_episodes, 3) array of Axes, so a
# single episode can still be indexed as axes[row, col].
fig, axes = plt.subplots(
    nrows=num_episodes,
    ncols=3,
    figsize=(18, 6 * num_episodes),
    squeeze=False,
)
fig.suptitle("Groundplane Coordinate Grid Overlay", fontsize=16, fontweight='bold')

# Resolution at which cam_K was calibrated (the original capture resolution,
# before any downsampling of the saved PNGs).
ORIG_HEIGHT = 1080
ORIG_WIDTH = 1920


def _project_points(points_3d, camera_pose, cam_K):
    """Project world points to pixel coordinates.

    Args:
        points_3d: (N, 3) array of world points.
        camera_pose: Matrix applied as ``camera_pose @ [x, y, z, 1]``; only the
            first three output components are used. NOTE(review): assumed to
            be a world-to-camera transform (3x4 or 4x4) — confirm against the
            dataset.
        cam_K: (3, 3) intrinsics matrix.

    Returns:
        (N, 2) array of pixel coordinates. Rows for points whose camera-frame
        depth (third component) is not positive, i.e. behind the camera, are NaN.
    """
    points_h = np.column_stack([points_3d, np.ones(len(points_3d))])
    points_cam = (camera_pose @ points_h.T).T[:, :3]
    in_front = points_cam[:, 2] > 0
    pixels = np.full((len(points_3d), 2), np.nan)
    projected = (cam_K @ points_cam[in_front].T).T
    pixels[in_front] = projected[:, :2] / projected[:, 2:3]
    return pixels


def _in_image(points_2d, width, height):
    """Return a mask of pixels inside [0, width) x [0, height); NaN rows are False."""
    x, y = points_2d[:, 0], points_2d[:, 1]
    with np.errstate(invalid='ignore'):  # NaN compares False; silence legacy warnings
        return (x >= 0) & (x < width) & (y >= 0) & (y < height)


def _draw_grid_line(ax, line_points, width, height):
    """Draw one grid line through its in-image samples (needs at least two)."""
    inside = _in_image(line_points, width, height)
    if np.count_nonzero(inside) > 1:
        ax.plot(line_points[inside, 0], line_points[inside, 1],
                'b-', linewidth=0.5, alpha=0.6)


def _grid_colors(points_xz, x_limits, z_limits):
    """Map each [x, z] grid coordinate to an RGB color in [0, 1].

    X fades red from 1.0 to 0.5 and contributes green; Z fades blue from 1.0
    to 0.0 and contributes green. The two green contributions are summed and
    clipped. Normalizing with fixed limits (the full grid extent, not just the
    visible points) keeps a world point's color identical across frames.
    """
    x_norm = (points_xz[:, 0] - x_limits[0]) / (x_limits[1] - x_limits[0] + 1e-6)
    z_norm = (points_xz[:, 1] - z_limits[0]) / (z_limits[1] - z_limits[0] + 1e-6)
    return np.stack([
        1.0 - 0.5 * x_norm,              # red: from X
        np.clip(x_norm + z_norm, 0, 1),  # green: from X and Z
        1.0 - z_norm,                    # blue: from Z
    ], axis=1)


# Color limits span the whole grid so panels are directly comparable.
x_limits = (grid_x.min(), grid_x.max())
z_limits = (grid_z.min(), grid_z.max())
legend_drawn = False

for row, episode_dir in enumerate(tqdm(selected_episodes, desc="Processing episodes")):
    episode_id = episode_dir.name

    # Frame images are "<digits>.png"; sort numerically so unpadded names
    # (e.g. "2.png", "10.png") stay in frame order.
    frame_files = sorted(
        (f for f in episode_dir.glob("*.png") if f.stem.isdigit()),
        key=lambda f: int(f.stem),
    )
    if len(frame_files) < 3:
        print(f"⚠ Episode {episode_id} has only {len(frame_files)} frames, need at least 3")
        for ax in axes[row]:
            ax.axis('off')  # leave the row blank instead of empty tick frames
        continue

    # Show the first, middle and last frames of the episode.
    frame_indices = [0, len(frame_files) // 2, len(frame_files) - 1]

    for col, frame_idx in enumerate(frame_indices):
        ax = axes[row, col]
        ax.axis('off')  # set first so panels skipped below stay blank
        frame_file = frame_files[frame_idx]
        frame_str = f"{int(frame_file.stem):06d}"

        # cv2.imread returns None (no exception) when the file can't be read.
        # With default flags it always yields an 8-bit image, so no [0, 1]
        # float rescaling is needed.
        bgr = cv2.imread(str(frame_file))
        if bgr is None:
            print(f"⚠ Episode {episode_id} frame {frame_str} could not be read")
            continue
        rgb_np = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img_height, img_width = rgb_np.shape[:2]

        # Load camera pose and intrinsics
        camera_pose_path = episode_dir / f"{frame_str}_camera_pose.npy"
        cam_K_path = episode_dir / f"{frame_str}_cam_K.npy"
        if not camera_pose_path.exists() or not cam_K_path.exists():
            print(f"⚠ Episode {episode_id} frame {frame_str} missing camera pose or intrinsics")
            continue

        camera_pose = np.load(camera_pose_path)
        cam_K = np.load(cam_K_path)

        # cam_K is calibrated at ORIG_WIDTH x ORIG_HEIGHT. Rescale its pixel
        # rows to the loaded image so the overlay aligns if the PNG was
        # downsampled; this is the identity at the original resolution.
        cam_K = np.diag([img_width / ORIG_WIDTH, img_height / ORIG_HEIGHT, 1.0]) @ cam_K

        # Project all grid points at once; NaN marks points behind the camera.
        grid_2d = _project_points(grid_points_3d, camera_pose, cam_K)
        # grid_points_3d rows are z-major (meshgrid of grid_x, grid_z), so this
        # restores the (num_z, num_x) grid layout.
        grid_2d_reshaped = grid_2d.reshape(len(grid_z), len(grid_x), 2)

        ax.imshow(rgb_np)

        # Lines of constant Z (grid rows), then lines of constant X (grid columns).
        for line_points in grid_2d_reshaped:
            _draw_grid_line(ax, line_points, img_width, img_height)
        for line_points in grid_2d_reshaped.transpose(1, 0, 2):
            _draw_grid_line(ax, line_points, img_width, img_height)

        # Grid points inside the image, colored by their world X/Z coordinates.
        inside = _in_image(grid_2d, img_width, img_height)
        if inside.any():
            ax.scatter(grid_2d[inside, 0], grid_2d[inside, 1],
                       c=_grid_colors(grid_points_3d[inside, :2], x_limits, z_limits),
                       s=2, alpha=0.8, zorder=5)

        # Mark the world origin if it projects into the image.
        # NOTE(review): assumes project_3d_to_2d follows the same pose/intrinsics
        # convention as _project_points — confirm in utils.
        origin_2d = project_3d_to_2d(np.array([0.0, 0.0, 0.0]), camera_pose, cam_K)
        if origin_2d is not None and 0 <= origin_2d[0] < img_width and 0 <= origin_2d[1] < img_height:
            ax.plot(origin_2d[0], origin_2d[1], 'go', markersize=10, markeredgecolor='white',
                    markeredgewidth=2, label='Origin (0,0,0)', zorder=10)
            # One legend for the figure, on the first panel that shows the
            # marker (legend() on a panel without labeled artists only warns).
            if not legend_drawn:
                ax.legend(loc='upper right', fontsize=8)
                legend_drawn = True

        ax.set_title(f"{episode_id} - Frame {frame_idx} ({frame_str})", fontsize=10)

# Lay out the figure, write it to disk (creating the output directory if
# needed), then display it interactively.
plt.tight_layout()
output_path = Path('groundplane_testing/render_groundplane.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"✓ Saved {output_path}")
plt.show()
