"""Manual pick and place testing: click pick/place locations, use median heights, do IK and render."""
import sys
import os
from pathlib import Path
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import argparse
import xml.etree.ElementTree as ET
import mink

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "."))

from robot_models.so100_controller import Arm
import mujoco
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes, render_from_camera_pose, get_link_poses_from_robot
from data import MIN_HEIGHT, MAX_HEIGHT
from utils import project_3d_to_2d, recover_3d_from_direct_keypoint_and_height, ik_to_keypoint_and_rotation

# Hardcoded median rotation (like live_cam_inference.py).
# This 3x3 matrix is the target gripper rotation handed to the IK solver for
# every waypoint, so the whole trajectory uses one fixed orientation.
# NOTE(review): presumably the median gripper rotation over the dataset —
# confirm provenance before reusing with a different dataset.
median_dataset_rotation = np.array([[-0.99912433, -0.03007201, -0.02909046],
                                    [-0.04176828,  0.67620482,  0.73552869],
                                    [-0.00244771,  0.73609967, -0.67686874]])

# Parse command-line arguments
parser = argparse.ArgumentParser(description='Manual pick and place testing')
parser.add_argument('--episode_idx', '-e', default=0, type=int, help='Episode index')
parser.add_argument('--start_frame', '-sf', default=0, type=int, help='Start frame')
parser.add_argument('--dataset_dir', '-d', type=str, required=True, help='Dataset directory')
parser.add_argument('--waypoint_density', type=int, default=5, help='Number of waypoints per segment (default: 5)')
parser.add_argument('--arc_height', type=float, default=0.05, help='Height offset for arc trajectory in meters (default: 0.05m)')
parser.add_argument('--use_saved_pts', action='store_true', help='Use hard-coded pick/place points instead of clicking')
args = parser.parse_args()

# Each segment's waypoint count includes both of its endpoints. With fewer
# than two waypoints a segment degenerates to its start point only, so the
# trajectory would never reach the pick or place location.
if args.waypoint_density < 2:
    parser.error("--waypoint_density must be at least 2 (each segment includes its start and end waypoint)")

# Load dataset: every sub-directory named "episode_*" is one episode
dataset_dir = Path(args.dataset_dir)
episode_dirs = sorted([d for d in dataset_dir.iterdir() if d.is_dir() and d.name.startswith("episode_")])
if not episode_dirs:
    print(f"No episodes found in {dataset_dir}")
    sys.exit(1)

# Select one episode. Reject out-of-range indices explicitly: a plain list
# lookup would raise IndexError for large values and silently wrap around
# for negative ones.
if not 0 <= args.episode_idx < len(episode_dirs):
    print(f"Episode index {args.episode_idx} out of range (max: {len(episode_dirs)-1})")
    sys.exit(1)
episode_dir = episode_dirs[args.episode_idx]
episode_id = episode_dir.name
print(f"Using episode: {episode_id}")

# Frames are the PNG files whose stem is purely numeric. Sort by the numeric
# value so the order is temporal even if the stems are not zero-padded.
frame_files = sorted(
    (f for f in episode_dir.glob("*.png") if f.stem.isdigit()),
    key=lambda f: int(f.stem),
)
if not frame_files:
    print(f"No frames found in {episode_id}")
    sys.exit(1)

# Select start frame (an index into the sorted frame list, not a frame number)
start_idx = args.start_frame
if not 0 <= start_idx < len(frame_files):
    print(f"Start frame {start_idx} out of range (max: {len(frame_files)-1})")
    sys.exit(1)

start_frame_file = frame_files[start_idx]
frame_idx = int(start_frame_file.stem)
# Per-frame side files are looked up with a six-digit, zero-padded frame number
start_frame_str = f"{frame_idx:06d}"

# Load start frame RGB
print(f"Loading image from {start_frame_file}...")
start_frame_bgr = cv2.imread(str(start_frame_file))
# cv2.imread signals an unreadable/corrupt file by returning None instead of
# raising, so check before handing the result to cvtColor.
if start_frame_bgr is None:
    print(f"Failed to read image {start_frame_file}")
    sys.exit(1)
rgb = cv2.cvtColor(start_frame_bgr, cv2.COLOR_BGR2RGB)
# Rescale images whose values lie in [0, 1] to the 0-255 uint8 range
if rgb.max() <= 1.0:
    rgb = (rgb * 255).astype(np.uint8)
H_loaded, W_loaded = rgb.shape[:2]
print(f"Image resolution: {W_loaded}x{H_loaded}")

# Hardcode original image resolution (before downsampling).
# Pixel coordinates are rescaled between the loaded size and this size
# whenever points are lifted to 3D or projected back to 2D.
H_orig = 1080
W_orig = 1920

# Load camera pose and intrinsics for the start frame.
# Several file-naming conventions exist (same as vis_eval_2d.py); they are
# tried in priority order and the first (pose, intrinsics) pair for which
# BOTH files exist is used. The last candidate falls back to the files of the
# episode's first frame.
first_frame_str = f"{int(frame_files[0].stem):06d}"
camera_file_candidates = (
    (f"{start_frame_str}_camera_pose.npy", f"{start_frame_str}_cam_K.npy"),
    (f"robot_camera_pose_{start_frame_str}.npy", f"cam_K_{start_frame_str}.npy"),
    (f"{first_frame_str}_camera_pose.npy", f"{first_frame_str}_cam_K.npy"),
)

for pose_file_name, intrinsics_file_name in camera_file_candidates:
    camera_pose_path = episode_dir / pose_file_name
    cam_K_path = episode_dir / intrinsics_file_name
    if camera_pose_path.exists() and cam_K_path.exists():
        break
else:
    # No convention matched: nothing to lift or render with
    print(f"Camera data not found for {episode_id} at frame {start_frame_str}")
    exit(1)

camera_pose = np.load(camera_pose_path)
cam_K = np.load(cam_K_path)
print(f"✓ Loaded camera pose from {camera_pose_path}")
print(f"✓ Loaded camera intrinsics from {cam_K_path}")

# Compute median pick and place heights from dataset
print(f"\nComputing median pick and place heights from dataset: {dataset_dir}...")
episode_dirs = sorted([d for d in dataset_dir.iterdir() if d.is_dir() and d.name.startswith("episode_")])

# Pin `frame_str` to the selected start frame and leave `episode_dir`
# untouched. The dataset-wide scan below uses its own loop variable names so
# it cannot silently redirect later per-frame lookups to whichever episode
# and frame happened to be scanned last.
frame_str = start_frame_str

# Gripper-value thresholds used to detect the grasp (close) and ungrasp
# (open) events in each episode.
GRASP_CLOSE_THRESHOLD = 0.1
GRASP_OPEN_THRESHOLD = 0.4
# Frames 0-4 are never accepted as the grasp frame; the search starts here.
# NOTE(review): presumably skips an initially-closed gripper — confirm.
GRASP_SEARCH_START = 5

pick_heights = []  # Heights at the first grasp of each episode
place_heights = []  # Heights at the first ungrasp following that grasp

for scan_episode_dir in episode_dirs:
    # Numeric sort keeps frames in temporal order even without zero-padding
    scan_frame_files = sorted(
        (f for f in scan_episode_dir.glob("*.png") if f.stem.isdigit()),
        key=lambda f: int(f.stem),
    )

    # Collect the gripper position and gripper value of every frame that has
    # a saved gripper pose; both lists stay index-aligned.
    trajectory_gt_3d = []
    gripper_values = []
    for scan_frame_file in scan_frame_files:
        scan_frame_str = f"{int(scan_frame_file.stem):06d}"
        pose_path = scan_episode_dir / f"{scan_frame_str}_gripper_pose.npy"
        if not pose_path.exists():
            continue

        gripper_pose = np.load(pose_path)
        # The keypoint is the gripper origin, so its world position is simply
        # the translation column of the pose matrix.
        trajectory_gt_3d.append(gripper_pose[:3, 3])

        # The gripper value is the last entry of the frame's joint state;
        # frames without a joint-state file count as fully open (1.0).
        scan_joint_state_path = scan_episode_dir / f"{scan_frame_str}.npy"
        if scan_joint_state_path.exists():
            gripper_values.append(float(np.load(scan_joint_state_path)[-1]))
        else:
            gripper_values.append(1.0)

    if len(trajectory_gt_3d) < 2:
        continue

    # Grasp: first frame (from GRASP_SEARCH_START on) with a closed gripper
    grasp_idx = next(
        (i for i in range(GRASP_SEARCH_START, len(gripper_values))
         if gripper_values[i] < GRASP_CLOSE_THRESHOLD),
        None,
    )
    if grasp_idx is None:
        continue
    # Height is read from index 2 of the gripper position
    pick_heights.append(trajectory_gt_3d[grasp_idx][2])

    # Ungrasp: first frame after the grasp where the gripper is open again
    ungrasp_idx = next(
        (i for i in range(grasp_idx + 1, len(gripper_values))
         if gripper_values[i] > GRASP_OPEN_THRESHOLD),
        None,
    )
    if ungrasp_idx is not None:
        place_heights.append(trajectory_gt_3d[ungrasp_idx][2])

if not pick_heights or not place_heights:
    print(f"⚠ Warning: Found {len(pick_heights)} pick heights and {len(place_heights)} place heights")
    print("Using default heights")
    median_pick_height = 0.0
    median_place_height = 0.1
else:
    median_pick_height = np.median(pick_heights)
    median_place_height = np.median(place_heights)
    print(f"✓ Computed median pick height: {median_pick_height:.4f}m")
    print(f"✓ Computed median place height: {median_place_height:.4f}m")

# Obtain the pick and place pixel locations, expressed in the coordinates of
# the loaded image: either interactively via two mouse clicks, or from the
# stored points of an earlier run.
if not args.use_saved_pts:
    for instruction in (
        "\nClick on the image to select:",
        "  1. Pick location (first click)",
        "  2. Place location (second click)",
        "Close the window when done.",
    ):
        print(instruction)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    ax.imshow(rgb)
    ax.set_title("Click: 1) Pick location, 2) Place location")
    ax.axis('off')

    # timeout=0 waits for the clicks without a time limit
    clicks = plt.ginput(2, timeout=0)
    plt.close()

    if len(clicks) < 2:
        print("Error: Need 2 clicks (pick and place)")
        exit(1)

    # Each click is an (x, y) pair; the first is the pick, the second the place
    pick_2d, place_2d = (np.array([click_x, click_y]) for click_x, click_y in clicks[:2])

    print(f"Pick location (image coords): ({pick_2d[0]:.1f}, {pick_2d[1]:.1f})")
    print(f"Place location (image coords): ({place_2d[0]:.1f}, {place_2d[1]:.1f})")

    # Nearly coincident clicks usually mean an accidental double-click
    click_separation = np.linalg.norm(pick_2d - place_2d)
    if click_separation < 5.0:  # Less than 5 pixels apart
        print(f"\n⚠ Warning: Pick and place locations are very close ({click_separation:.1f} pixels apart)")
        print("This might be an accidental double-click. Consider re-running.")
else:
    # Hard-coded points from an earlier successful run, already at the loaded
    # image resolution (the original points at 2x resolution were
    # Pick=(1229.5, 663.9), Place=(1696.0, 433.7)).
    pick_2d = np.array([614.75, 331.95])
    place_2d = np.array([848.0, 216.85])
    print("\nUsing saved pick and place locations:")
    print(f"  Pick location (image coords): ({pick_2d[0]:.1f}, {pick_2d[1]:.1f})")
    print(f"  Place location (image coords): ({place_2d[0]:.1f}, {place_2d[1]:.1f})")

# Lift the 2D pick/place locations to 3D using the median heights.
# Pixel coordinates are first rescaled from the loaded image size to the
# original resolution.
loaded_to_orig_scale = np.array([W_orig / W_loaded, H_orig / H_loaded])
pick_2d_orig = pick_2d * loaded_to_orig_scale
place_2d_orig = place_2d * loaded_to_orig_scale

print(f"\nLifting 2D to 3D:")
print(f"Pick: 2D=({pick_2d_orig[0]:.1f}, {pick_2d_orig[1]:.1f}), height={median_pick_height:.4f}m (median from grasp locations)")
print(f"Place: 2D=({place_2d_orig[0]:.1f}, {place_2d_orig[1]:.1f}), height={median_place_height:.4f}m (median from ungrasp locations)")

pick_3d = recover_3d_from_direct_keypoint_and_height(pick_2d_orig, median_pick_height, camera_pose, cam_K)
place_3d = recover_3d_from_direct_keypoint_and_height(place_2d_orig, median_place_height, camera_pose, cam_K)

# The lifting helper reports failure by returning None
if any(point is None for point in (pick_3d, place_3d)):
    print("Error: Failed to recover 3D locations")
    exit(1)

print(f"Pick location (3D): ({pick_3d[0]:.4f}, {pick_3d[1]:.4f}, {pick_3d[2]:.4f}) [using median_pick_height={median_pick_height:.4f}m]")
print(f"Place location (3D): ({place_3d[0]:.4f}, {place_3d[1]:.4f}, {place_3d[2]:.4f}) [using median_place_height={median_place_height:.4f}m]")

# Sanity check: project the 3D points back into the image and compare them
# with the selected pixels (a round trip through lift + project).
pick_2d_reproj = project_3d_to_2d(pick_3d, camera_pose, cam_K)
place_2d_reproj = project_3d_to_2d(place_3d, camera_pose, cam_K)

# These stay None when the reprojection fails; the plotting code checks them
pick_2d_reproj_loaded = None
place_2d_reproj_loaded = None

if pick_2d_reproj is None or place_2d_reproj is None:
    print("⚠ Warning: Failed to reproject 3D coordinates back to 2D")
else:
    # Back from the original resolution to the loaded image resolution
    orig_to_loaded_scale = np.array([W_loaded / W_orig, H_loaded / H_orig])
    pick_2d_reproj_loaded = pick_2d_reproj * orig_to_loaded_scale
    place_2d_reproj_loaded = place_2d_reproj * orig_to_loaded_scale

    pick_reproj_error = np.linalg.norm(pick_2d - pick_2d_reproj_loaded)
    place_reproj_error = np.linalg.norm(place_2d - place_2d_reproj_loaded)

    print(f"\nReprojection check:")
    print(f"Pick - Original: ({pick_2d[0]:.1f}, {pick_2d[1]:.1f}), Reprojected: ({pick_2d_reproj_loaded[0]:.1f}, {pick_2d_reproj_loaded[1]:.1f}), Error: {pick_reproj_error:.2f} pixels")
    print(f"Place - Original: ({place_2d[0]:.1f}, {place_2d[1]:.1f}), Reprojected: ({place_2d_reproj_loaded[0]:.1f}, {place_2d_reproj_loaded[1]:.1f}), Error: {place_reproj_error:.2f} pixels")

    # More than 10 px of round-trip error points at inconsistent inputs
    if pick_reproj_error > 10 or place_reproj_error > 10:
        print("⚠ Warning: Large reprojection error - check camera pose, intrinsics, or lifting function")

# Setup robot: build the MuJoCo model and its data from the config's XML
robot_config = SO100AdhesiveConfig()
mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
mj_data = mujoco.MjData(mj_model)

# Load actual robot state from dataset at the specified frame.
# The path is rebuilt from the command-line selection (episode index + start
# frame) rather than from generic names such as `episode_dir` / `frame_str`,
# which dataset-wide loops earlier in the script may have rebound to a
# different episode and frame.
selected_episode_dir = episode_dirs[args.episode_idx]
joint_state_path = selected_episode_dir / f"{start_frame_str}.npy"
if joint_state_path.exists():
    joint_state = np.load(joint_state_path)
    mj_data.qpos[:] = joint_state
    print(f"\n✓ Loaded robot joint state from {joint_state_path.name}")
else:
    # Fallback to default state if not found
    mj_data.qpos[:] = 0.0
    print(f"\n⚠ No joint state found at {joint_state_path}, using default state")

# Propagate the joint positions to body poses
mujoco.mj_forward(mj_model, mj_data)

# Setup IK configuration
ik_configuration = mink.Configuration(mj_model)

# Update link poses, re-run forward kinematics and sync the IK configuration
# with the resulting joint positions
link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)
ik_configuration.update(mj_data.qpos)

# Get current EEF position (world position of the virtual_gripper_keypoint body)
kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
current_eef_3d = mj_data.xpos[kp_body_id].copy()
print(f"Current EEF position (3D): ({current_eef_3d[0]:.4f}, {current_eef_3d[1]:.4f}, {current_eef_3d[2]:.4f})")

# Generate waypoint trajectory with arc motion
def generate_arc_waypoints(start_3d, end_3d, num_waypoints, arc_height_offset):
    """Generate waypoints between start and end with an upward arc in the middle.

    The waypoints are evenly spaced on the straight line from ``start_3d`` to
    ``end_3d``; a parabolic offset is then added to the height coordinate
    (index 2). The offset is zero at both endpoints and reaches
    ``arc_height_offset`` at the midpoint.

    Args:
        start_3d: (3,) starting 3D position (array, list or tuple)
        end_3d: (3,) ending 3D position (array, list or tuple)
        num_waypoints: Number of waypoints (including start and end)
        arc_height_offset: Height offset to add at the midpoint (meters)

    Returns:
        waypoints: (num_waypoints, 3) float array of 3D waypoints. A single
        waypoint is just the start position; zero (or fewer) waypoints give
        an empty (0, 3) array. The inputs are never modified.
    """
    # Accept any array-like input (plain lists cannot be subtracted/scaled)
    start = np.asarray(start_3d, dtype=float)
    end = np.asarray(end_3d, dtype=float)

    if num_waypoints <= 0:
        # Keep the documented (N, 3) layout even when there is nothing to return
        return np.empty((0, 3), dtype=float)

    if num_waypoints == 1:
        # A lone waypoint sits at the start (avoids dividing by zero below)
        t = np.zeros(1)
    else:
        # Interpolation parameter running from 0 (start) to 1 (end)
        t = np.arange(num_waypoints) / (num_waypoints - 1)

    # Linear interpolation, one row per waypoint
    waypoints = start + t[:, None] * (end - start)

    # Add arc: parabola that is 0 at t=0 and t=1 and peaks at t=0.5
    waypoints[:, 2] += 4 * arc_height_offset * t * (1 - t)

    return waypoints

# Build the full waypoint trajectory in two arc segments:
# current EEF position → pick, then pick → place.
print(f"\nGenerating waypoint trajectory with {args.waypoint_density} waypoints per segment, arc_height={args.arc_height}m...")
waypoints_to_pick = generate_arc_waypoints(
    current_eef_3d, pick_3d, args.waypoint_density, args.arc_height
)
waypoints_to_place = generate_arc_waypoints(
    pick_3d, place_3d, args.waypoint_density, args.arc_height
)

# The second segment starts at the pick point, which already ends the first
# segment, so its first row is dropped when joining the two.
full_waypoint_trajectory = np.concatenate(
    (waypoints_to_pick, waypoints_to_place[1:]), axis=0
)
print(f"✓ Generated {len(full_waypoint_trajectory)} total waypoints")

# Gripper state per waypoint: 1.0 = open (default), 0.0 = closed.
# The gripper is open while approaching and closed from the pick waypoint
# (last waypoint of the first segment) for `waypoint_density` waypoints.
# NOTE(review): that closed range extends through the final (place) waypoint,
# so no waypoint after the pick is left open — confirm whether the place
# waypoint was meant to be open.
pick_idx = args.waypoint_density - 1
closed_range = slice(pick_idx, pick_idx + args.waypoint_density)
gripper_states = np.ones(len(full_waypoint_trajectory))
gripper_states[closed_range] = 0.0

print("\nProjecting waypoints to 2D for visualization...")
# Project every 3D waypoint into the image, rescale the result from the
# original resolution to the loaded image resolution, and drop waypoints
# whose projection failed (the projection helper returns None for those).
waypoint_pixel_scale = np.array([W_loaded / W_orig, H_loaded / H_orig])
projected_waypoints = (
    project_3d_to_2d(waypoint_3d, camera_pose, cam_K)
    for waypoint_3d in full_waypoint_trajectory
)
waypoints_2d = np.array([
    projected * waypoint_pixel_scale
    for projected in projected_waypoints
    if projected is not None
])
print(f"✓ Projected {len(waypoints_2d)} waypoints to 2D")

# Run IK for all waypoints, warm-starting each solve from the previous result
ik_joint_positions = []

# Loop invariants: every waypoint uses the same target rotation and the same
# keypoint body, so look them up once instead of per iteration.
target_gripper_rot = median_dataset_rotation
kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")

# Indices of the waypoints that are always reported: the pick waypoint is the
# last one of the first segment, the place waypoint the last one overall.
ik_pick_idx = args.waypoint_density - 1
ik_last_idx = len(full_waypoint_trajectory) - 1

print("\nRunning IK on waypoints...")
for i, target_pos in enumerate(full_waypoint_trajectory):
    # Update IK configuration from current state
    link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
    position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
    mujoco.mj_forward(mj_model, mj_data)
    ik_configuration.update(mj_data.qpos)

    # Solve IK with median rotation constraint
    ik_to_keypoint_and_rotation(target_pos, target_gripper_rot, ik_configuration, robot_config, mj_model, mj_data, max_iterations=40)

    # Set gripper state: the last qpos/ctrl entry is written with the
    # open/closed value planned for this waypoint
    gripper_val = gripper_states[i]
    if len(mj_data.qpos) > 0:
        mj_data.qpos[-1] = gripper_val
    if len(mj_data.ctrl) > 0:
        mj_data.ctrl[-1] = gripper_val

    # Update link poses and forward kinematics after IK
    link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
    position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
    mujoco.mj_forward(mj_model, mj_data)

    # Save joint positions (copy: mj_data.qpos is overwritten every iteration)
    ik_joint_positions.append(mj_data.qpos.copy())

    # Verify IK result (print only key waypoints). Besides every
    # `waypoint_density`-th waypoint, always report the pick and the final
    # waypoint, which the modulo test alone would skip.
    if i % args.waypoint_density == 0 or i in (ik_pick_idx, ik_last_idx):
        achieved_kp_pos = mj_data.xpos[kp_body_id].copy()
        ik_error = np.linalg.norm(achieved_kp_pos - target_pos)
        waypoint_type = "Current→Pick" if i < args.waypoint_density else "Pick→Place"
        print(f"  Waypoint {i}/{len(full_waypoint_trajectory)-1} ({waypoint_type}): "
              f"Target=[{target_pos[0]:.4f}, {target_pos[1]:.4f}, {target_pos[2]:.4f}], "
              f"Achieved=[{achieved_kp_pos[0]:.4f}, {achieved_kp_pos[1]:.4f}, {achieved_kp_pos[2]:.4f}], "
              f"IK Error={ik_error:.4f}m")

print(f"✓ Completed IK for {len(ik_joint_positions)} waypoints")

# Render the robot at every solved waypoint from the dataset camera
print("\nRendering robot poses for all waypoints...")
cam_K_for_render = cam_K.copy()

pick_idx = args.waypoint_density - 1
place_idx = len(ik_joint_positions) - 1

# Named waypoints. When indices coincide, later entries overwrite earlier
# ones, which gives the priority Start > Pick > Place.
special_waypoint_labels = {place_idx: "Place", pick_idx: "Pick", 0: "Start"}

rendered_images = []
waypoint_labels = []

for i, joint_pos in enumerate(ik_joint_positions):
    # Restore the solved joint configuration (controls take the leading entries)
    mj_data.qpos[:] = joint_pos
    mj_data.ctrl[:] = joint_pos[:len(mj_data.ctrl)]
    mujoco.mj_forward(mj_model, mj_data)

    # Re-position the exoskeleton meshes for this configuration
    link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
    position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
    mujoco.mj_forward(mj_model, mj_data)

    # Render at the original resolution, then shrink to the loaded image size
    # so the render can be blended with the RGB frame
    rendered = render_from_camera_pose(mj_model, mj_data, camera_pose, cam_K_for_render, H_orig, W_orig)
    rendered_images.append(
        cv2.resize(rendered, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)
    )

    # Unnamed waypoints are labelled with their index
    waypoint_labels.append(special_waypoint_labels.get(i, f"t={i}"))

print(f"✓ Rendered {len(rendered_images)} waypoint poses")

# Display results
num_waypoints = len(rendered_images)
# Grid layout: the first row holds the trajectory visualization, the rows
# below hold one overlay per waypoint, up to 5 per row. At least one column is
# kept so the trajectory panel can be drawn even without any overlays.
cols = max(1, min(5, num_waypoints))
rows = 1 + (num_waypoints + cols - 1) // cols  # 1 row for trajectory + rows for overlays

# squeeze=False makes `axes` a 2-D array for every grid shape; by default a
# single-column (or single-row) grid comes back 1-D and `axes[r, c]` fails.
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)

# 1. Original image with waypoint trajectory (first image in first row)
trajectory_ax = axes[0, 0]
trajectory_ax.imshow(rgb)
# Plot waypoint trajectory
if len(waypoints_2d) > 0:
    trajectory_ax.plot(waypoints_2d[:, 0], waypoints_2d[:, 1], 'g-', linewidth=2, alpha=0.7, label='Waypoint trajectory')
    # Mark the starting end-effector position (projected into the loaded image)
    current_2d_proj = project_3d_to_2d(current_eef_3d, camera_pose, cam_K)
    if current_2d_proj is not None:
        current_2d_loaded = current_2d_proj * np.array([W_loaded / W_orig, H_loaded / H_orig])
        trajectory_ax.plot(current_2d_loaded[0], current_2d_loaded[1], 'go', markersize=12, label='Start', markeredgecolor='white', markeredgewidth=2)
# Selected pick (red) and place (blue) pixels
trajectory_ax.plot(pick_2d[0], pick_2d[1], 'ro', markersize=15, label='Pick', markeredgecolor='white', markeredgewidth=2)
trajectory_ax.plot(place_2d[0], place_2d[1], 'bo', markersize=15, label='Place', markeredgecolor='white', markeredgewidth=2)
# Crosses mark where the lifted 3D points reproject to (when available)
if pick_2d_reproj_loaded is not None:
    trajectory_ax.plot(pick_2d_reproj_loaded[0], pick_2d_reproj_loaded[1], 'r+', markersize=20, markeredgewidth=3, alpha=0.7)
if place_2d_reproj_loaded is not None:
    trajectory_ax.plot(place_2d_reproj_loaded[0], place_2d_reproj_loaded[1], 'b+', markersize=20, markeredgewidth=3, alpha=0.7)
trajectory_ax.set_title(f'Waypoint Trajectory ({len(waypoints_2d)} waypoints)\nGreen line = 3D waypoint path projected to 2D', fontsize=10)
trajectory_ax.legend(fontsize=8)
trajectory_ax.axis('off')

# Hide remaining cells in first row
for col_idx in range(1, cols):
    axes[0, col_idx].axis('off')

# 2. Render all waypoint overlays in subsequent rows
for i, (rendered, label) in enumerate(zip(rendered_images, waypoint_labels)):
    overlay_ax = axes[1 + i // cols, i % cols]

    # 50/50 blend of the camera image and the rendered robot
    overlay = (rgb * 0.5 + rendered * 0.5).astype(np.uint8)
    overlay_ax.imshow(overlay)
    overlay_ax.set_title(f'Overlay: {label}', fontsize=10)
    overlay_ax.axis('off')

# Hide any unused cells in the last row (only present when it is partly filled)
cells_in_last_row = num_waypoints % cols
if cells_in_last_row:
    for col_idx in range(cells_in_last_row, cols):
        axes[rows - 1, col_idx].axis('off')


plt.tight_layout()
plt.savefig('test_pickplace_manual_result.png', dpi=150, bbox_inches='tight')
print("\n✓ Saved result to test_pickplace_manual_result.png")
plt.show()
