"""Render groundplane coordinate grid overlaid on RGB images."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))

import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import argparse

from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX
from utils import project_3d_to_2d

# Axis-aligned region of interest, expressed in meters.
# Only pixels whose camera ray crosses this volume are color-coded; it is
# meant to cover an area in front of the robot.
#
# Axis naming used throughout this script (MuJoCo-style X, Z, Y with Y up):
#   X - left/right
#   Z - forward/backward; negative Z lies in front of the robot, positive Z
#       lies behind it
#   Y - up/down (height)

# Left / right limits along X.
GROUNDPLANE_X_MIN, GROUNDPLANE_X_MAX = -0.2, 0.2

# Forward limit (in front of the robot, more negative Z) and back limit
# (at the robot position) along Z.
GROUNDPLANE_Z_MIN, GROUNDPLANE_Z_MAX = -0.5, -0.1

# Lowest / highest height (Y) included in the volume.
MIN_HEIGHT, MAX_HEIGHT = -0.0, 0.2

# ---------------------------------------------------------------------------
# Rendering configuration
# ---------------------------------------------------------------------------

# Resolution (pixels) at which the overlay is computed. Loaded frames are
# resized to this size; the intrinsics files are used unscaled, i.e. they are
# expected to be calibrated for this resolution.
ORIG_HEIGHT = 1080
ORIG_WIDTH = 1920

# Episodes with fewer PNG frames than this are skipped.
MIN_FRAMES_PER_EPISODE = 3

# Figure area in inches (width, height) reserved for each subplot panel.
PANEL_SIZE_IN = (9, 6)

# Where the rendered figure is written (relative to the working directory).
OUTPUT_PATH = 'groundplane_testing/render_dense_groundplane.png'


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Render groundplane coordinate grids on images')
    parser.add_argument('--dataset_dir', '-d', default="scratch/parsed_rgb_joints_capture_desktop_train", type=str, help="Dataset directory")
    parser.add_argument('--num_episodes', '-n', default=5, type=int, help="Number of episodes to visualize")
    # The two grid options are still accepted so existing command lines keep
    # working, but the dense overlay rendered here does not use them.
    parser.add_argument('--grid_spacing', default=0.1, type=float, help="Grid spacing in meters")
    parser.add_argument('--grid_range', default=1.0, type=float, help="Grid range in meters (from -range to +range)")
    return parser.parse_args(argv)


def find_episode_dirs(dataset_dir):
    """Return the sorted ``episode_*`` sub-directories of ``dataset_dir``."""
    return sorted(
        d for d in dataset_dir.iterdir()
        if d.is_dir() and d.name.startswith("episode_")
    )


def select_frame_indices(num_frames):
    """Return the indices (into the sorted frame list) that get a panel.

    Only the first frame is rendered at the moment. A start/middle/end
    selection would be ``[0, num_frames // 2, num_frames - 1]``; the number of
    subplot columns follows the length of the returned list automatically.
    """
    return [0]


def pixel_grid_homogeneous(width, height):
    """Return all pixel coordinates of a ``width`` x ``height`` image.

    Returns:
        Float array of shape ``(3, width * height)`` whose columns are
        ``[u, v, 1]`` in row-major pixel order (u varies fastest).
    """
    u, v = np.meshgrid(np.arange(width), np.arange(height))
    return np.stack([u.ravel(), v.ravel(), np.ones(u.size)], axis=0)


def unproject_pixels_to_world(cam_K, camera_pose, pixels_h):
    """Turn homogeneous pixels into world-frame viewing rays.

    ``camera_pose`` is inverted to obtain the camera position and the rotation
    applied to the rays, i.e. it is treated as a 4x4 world-to-camera transform.

    Args:
        cam_K: 3x3 intrinsics matrix.
        camera_pose: 4x4 transform as described above.
        pixels_h: ``(3, N)`` homogeneous pixel coordinates.

    Returns:
        ``(cam_pos, rays_world)`` where ``cam_pos`` has shape ``(3,)`` and
        ``rays_world`` has shape ``(N, 3)``; rays are not normalised.
    """
    cam_to_world = np.linalg.inv(camera_pose)
    rays_cam = np.linalg.inv(cam_K) @ pixels_h  # (3, N)
    rays_world = (cam_to_world[:3, :3] @ rays_cam).T  # (N, 3)
    return cam_to_world[:3, 3], rays_world


def intersect_rays_with_box(origin, directions, box_min, box_max, eps=1e-10):
    """Intersect rays ``origin + t * direction`` with an axis-aligned box.

    Uses the slab method: per axis the ray is inside the slab for an interval
    of ``t``; the ray crosses the box when the three intervals overlap.

    Args:
        origin: ``(3,)`` common ray origin.
        directions: ``(N, 3)`` ray directions.
        box_min: ``(3,)`` lower box corner, same axis order as ``origin``.
        box_max: ``(3,)`` upper box corner.
        eps: Small offset added to the directions to avoid dividing by an
            exact zero for axis-parallel rays.

    Returns:
        ``(hit, t_entry)``: ``hit`` is an ``(N,)`` boolean mask of rays that
        cross the box in front of the origin (``t > 0``); ``t_entry`` is the
        ``(N,)`` ray parameter of the first visible point inside the box.
        ``t_entry`` is only meaningful where ``hit`` is true.
    """
    denom = directions + eps
    t_lo = (box_min - origin) / denom  # (N, 3)
    t_hi = (box_max - origin) / denom  # (N, 3)
    t_enter = np.minimum(t_lo, t_hi).max(axis=1)
    t_exit = np.maximum(t_lo, t_hi).min(axis=1)
    hit = (t_enter < t_exit) & (t_exit > 0)
    # When the origin lies inside the box the geometric entry is behind it
    # (t_enter < 0); the first visible point is then the origin itself.
    return hit, np.maximum(t_enter, 0.0)


def colorize_points(points, lower, upper):
    """Map 3D points to RGB colors according to their position in a box.

    Each coordinate is normalised to [0, 1] between ``lower`` and ``upper``
    (values outside are clipped). With axis order ``[x, z, y]``:

    * x shifts the color from red towards yellow (less red, more green),
    * z shifts it from blue towards green,
    * y (height) adds a purple tint (extra red and blue).

    Args:
        points: ``(M, 3)`` points.
        lower: ``(3,)`` coordinates mapped to 0.
        upper: ``(3,)`` coordinates mapped to (just under) 1.

    Returns:
        ``(M, 3)`` float RGB colors in [0, 1].
    """
    norm = np.clip((points - lower) / (upper - lower + 1e-6), 0, 1)
    x_n, z_n, y_n = norm[:, 0], norm[:, 1], norm[:, 2]
    red = np.clip(1.0 - x_n * 0.5 + y_n * 0.3, 0, 1)
    green = np.clip(x_n + z_n, 0, 1)
    blue = np.clip(1.0 - z_n + y_n * 0.5, 0, 1)
    return np.stack([red, green, blue], axis=1)


def build_overlay(rgb, cam_K, camera_pose, pixels_h, box_min, box_max):
    """Blend an RGB image 50/50 with the color-coded volume map.

    Pixels whose ray misses the box are blended with white.

    Args:
        rgb: ``(H, W, 3)`` uint8 image matching the layout of ``pixels_h``.
        cam_K: 3x3 intrinsics matrix.
        camera_pose: 4x4 camera transform (see ``unproject_pixels_to_world``).
        pixels_h: ``(3, H * W)`` homogeneous pixel grid.
        box_min: ``(3,)`` lower corner of the volume, ordered ``[x, z, y]``.
        box_max: ``(3,)`` upper corner of the volume, ordered ``[x, z, y]``.

    Returns:
        ``(H, W, 3)`` float image in [0, 1].
    """
    height, width = rgb.shape[:2]
    cam_pos, rays_world = unproject_pixels_to_world(cam_K, camera_pose, pixels_h)
    hit, t_entry = intersect_rays_with_box(cam_pos, rays_world, box_min, box_max)

    # Start from white and color only the pixels that see the volume.
    colors = np.ones((rays_world.shape[0], 3))
    entry_points = cam_pos[None, :] + t_entry[hit, None] * rays_world[hit]
    colors[hit] = colorize_points(entry_points, box_min, box_max)
    colors = colors.reshape(height, width, 3)

    rgb_normalized = rgb.astype(np.float32) / 255.0
    return 0.5 * rgb_normalized + 0.5 * colors


def load_image_rgb(frame_file):
    """Load a frame as RGB at ``ORIG_WIDTH`` x ``ORIG_HEIGHT``.

    Returns:
        ``(ORIG_HEIGHT, ORIG_WIDTH, 3)`` array, or ``None`` when OpenCV cannot
        read the file.
    """
    bgr = cv2.imread(str(frame_file))
    if bgr is None:
        return None
    # imread with default flags yields 8-bit data, so no value rescaling is
    # required before normalising by 255 later on.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if rgb.shape[:2] != (ORIG_HEIGHT, ORIG_WIDTH):
        rgb = cv2.resize(rgb, (ORIG_WIDTH, ORIG_HEIGHT), interpolation=cv2.INTER_LINEAR)
    return rgb


def load_gt_trajectory(episode_dir, frame_files):
    """Collect the world-frame keypoint position for every frame of an episode.

    For each frame with a ``<frame>_gripper_pose.npy`` file, the selected local
    keypoint (``KEYPOINTS_LOCAL_M_ALL[KP_INDEX]``) is transformed by that pose.

    Returns:
        Array with one row per frame that has a pose file, or ``None`` when no
        pose file exists.
    """
    kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
    points = []
    for frame_file in frame_files:
        pose_path = episode_dir / f"{int(frame_file.stem):06d}_gripper_pose.npy"
        if not pose_path.exists():
            continue
        pose = np.load(pose_path)
        points.append(pose[:3, :3] @ kp_local + pose[:3, 3])
    return np.array(points) if points else None


def project_to_groundplane_2d(trajectory_3d, camera_pose, cam_K):
    """Drop a 3D trajectory onto the groundplane and project it into the image.

    The height component (index 2) is set to zero before projecting with
    ``project_3d_to_2d``; points for which that helper returns ``None`` are
    discarded.

    Returns:
        Array of projected 2D points, or ``None`` when nothing could be
        projected.
    """
    on_ground = trajectory_3d.copy()
    on_ground[:, 2] = 0.0
    projected = [project_3d_to_2d(p, camera_pose, cam_K) for p in on_ground]
    projected = [p for p in projected if p is not None]
    return np.array(projected) if projected else None


def draw_trajectory(ax, traj_2d, width, height):
    """Draw the in-image part of a 2D trajectory as a line with dots.

    Dots are colored along the viridis colormap in trajectory order.
    """
    in_bounds = (
        (traj_2d[:, 0] >= 0) & (traj_2d[:, 0] < width)
        & (traj_2d[:, 1] >= 0) & (traj_2d[:, 1] < height)
    )
    if not np.any(in_bounds):
        return
    visible = traj_2d[in_bounds]
    ax.plot(visible[:, 0], visible[:, 1],
            'w-', linewidth=2, alpha=0.8, label='Groundplane Traj', zorder=10)
    for i, (x, y) in enumerate(visible):
        color = plt.cm.viridis(i / len(visible))
        ax.plot(x, y, 'o', color=color, markersize=4, markeredgecolor='white',
                markeredgewidth=1, zorder=11)


def main():
    """Render the overlay for the first episodes of a dataset and save it."""
    args = parse_args()

    dataset_dir = Path(args.dataset_dir)
    if not dataset_dir.is_dir():
        print(f"Dataset directory not found: {dataset_dir}")
        sys.exit(1)
    if args.num_episodes < 1:
        print(f"--num_episodes must be at least 1, got {args.num_episodes}")
        sys.exit(1)

    episode_dirs = find_episode_dirs(dataset_dir)
    if not episode_dirs:
        print(f"No episodes found in {dataset_dir}")
        sys.exit(1)

    # Select episodes to visualize
    num_episodes = min(args.num_episodes, len(episode_dirs))
    selected_episodes = episode_dirs[:num_episodes]

    # One row per episode, one column per rendered frame.
    num_cols = len(select_frame_indices(MIN_FRAMES_PER_EPISODE))
    fig, axes = plt.subplots(
        num_episodes, num_cols,
        figsize=(PANEL_SIZE_IN[0] * num_cols, PANEL_SIZE_IN[1] * num_episodes),
        squeeze=False,  # always a 2D array, even for a single row/column
    )
    # Hide every panel up front so skipped episodes/frames stay blank instead
    # of showing empty axes with ticks.
    for ax in axes.flat:
        ax.axis('off')

    fig.suptitle("Dense Groundplane Coordinate Map", fontsize=16, fontweight='bold')

    # Frame-independent inputs, computed once. Box corners use the same
    # [x, z, y] axis order as the world points (height is index 2).
    pixels_h = pixel_grid_homogeneous(ORIG_WIDTH, ORIG_HEIGHT)
    box_min = np.array([GROUNDPLANE_X_MIN, GROUNDPLANE_Z_MIN, MIN_HEIGHT])
    box_max = np.array([GROUNDPLANE_X_MAX, GROUNDPLANE_Z_MAX, MAX_HEIGHT])

    for row, episode_dir in enumerate(tqdm(selected_episodes, desc="Processing episodes")):
        episode_id = episode_dir.name

        # Frames are PNGs named by their integer index; order them numerically
        # so unpadded names sort correctly as well.
        frame_files = sorted(
            (f for f in episode_dir.glob("*.png") if f.stem.isdigit()),
            key=lambda f: (int(f.stem), f.name),
        )
        if len(frame_files) < MIN_FRAMES_PER_EPISODE:
            print(f"⚠ Episode {episode_id} has only {len(frame_files)} frames, need at least {MIN_FRAMES_PER_EPISODE}")
            continue

        # The 3D trajectory is the same for every panel of this episode.
        trajectory_3d = load_gt_trajectory(episode_dir, frame_files)

        for col, frame_idx in enumerate(select_frame_indices(len(frame_files))):
            frame_file = frame_files[frame_idx]
            frame_str = f"{int(frame_file.stem):06d}"

            # Check the camera files first; without them the image is useless.
            camera_pose_path = episode_dir / f"{frame_str}_camera_pose.npy"
            cam_K_path = episode_dir / f"{frame_str}_cam_K.npy"
            if not camera_pose_path.exists() or not cam_K_path.exists():
                print(f"⚠ Episode {episode_id} frame {frame_str} missing camera pose or intrinsics")
                continue

            rgb_np = load_image_rgb(frame_file)
            if rgb_np is None:
                print(f"⚠ Episode {episode_id} frame {frame_str} image could not be read")
                continue

            camera_pose = np.load(camera_pose_path)
            cam_K = np.load(cam_K_path)

            overlay = build_overlay(rgb_np, cam_K, camera_pose, pixels_h, box_min, box_max)

            ax = axes[row, col]
            ax.imshow(overlay)

            # Draw groundplane trajectory on top
            if trajectory_3d is not None:
                trajectory_2d = project_to_groundplane_2d(trajectory_3d, camera_pose, cam_K)
                if trajectory_2d is not None:
                    draw_trajectory(ax, trajectory_2d, ORIG_WIDTH, ORIG_HEIGHT)

            frame_label = f"Frame {frame_idx} ({frame_str})"
            ax.set_title(f"{episode_id} - {frame_label}", fontsize=10)

    plt.tight_layout()
    output_path = Path(OUTPUT_PATH)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved {output_path}")
    plt.show()


if __name__ == "__main__":
    main()
