"""Live pick and place trajectory planning from camera stream and robot state."""
import sys
import os
from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation

import xml.etree.ElementTree as ET
import argparse
import pickle
import mujoco
import mink

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from robot_models.so100_controller import Arm
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import position_exoskeleton_meshes, render_from_camera_pose, get_link_poses_from_robot
from utils import project_3d_to_2d, recover_3d_from_direct_keypoint_and_height, ik_to_keypoint_and_rotation
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX

# Parse arguments
parser = argparse.ArgumentParser(description='Live pick and place trajectory planning')
parser.add_argument('--camera', '-c', type=int, default=0, help='Camera device ID (default: 0)')
parser.add_argument('--waypoint_density', type=int, default=5, help='Number of waypoints per segment (default: 5)')
parser.add_argument('--arc_height', type=float, default=0.05, help='Height offset for arc trajectory in meters (default: 0.05m)')
parser.add_argument('--use_saved_pts', action='store_true', help='Use hard-coded pick/place points instead of clicking')
parser.add_argument('--use_camera_pose', action='store_true', help='Use camera pose from exoskeleton detection')
parser.add_argument('--write_robot_state', action='store_true', help='Write robot state to file')
parser.add_argument('--render', action='store_true', help='Render robot poses')
parser.add_argument('--dataset_dir', '-d', type=str, default='scratch/parsed_pp_silverb', help='Dataset directory for median height computation')
args = parser.parse_args()

# Validate argument combinations up front, before any camera or robot I/O happens.
# Each segment is built from `waypoint_density` samples of t in [0, 1]. With a single sample
# only t=0 (the segment start) is emitted, so the pick and place targets would never be visited.
if args.waypoint_density < 2:
    parser.error("--waypoint_density must be at least 2 so the pick and place points are part of the trajectory")
# The live arm connection (`arm`) is only opened when --use_camera_pose is set, and the
# write-back at the end of the script needs it.
if args.write_robot_state and not args.use_camera_pose:
    parser.error("--write_robot_state requires --use_camera_pose (the arm connection is only opened in that mode)")


# Initialize camera
# `time` is imported unconditionally: it is needed here and by the robot write-back at the end.
import time

# Live mode: grab one frame from the camera and cache it; otherwise reuse the cached frame.
if args.use_camera_pose:
    print(f"Initializing camera device {args.camera}...")
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open camera device {args.camera}")

    # Capture current frame; always release the device, even if the read fails.
    try:
        print("Capturing frame from camera...")
        time.sleep(1)  # give the camera a moment before reading
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        raise RuntimeError("Failed to capture frame from camera")

    os.makedirs("pickplace_testing/tmpstorage", exist_ok=True)
    np.save("pickplace_testing/tmpstorage/rgb.npy", frame)
else:
    frame = np.load("pickplace_testing/tmpstorage/rgb.npy")


# Convert the captured BGR frame to RGB. A frame whose values all lie in [0, 1] is
# rescaled to 8-bit.
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
rgb = (rgb * 255).astype(np.uint8) if rgb.max() <= 1.0 else rgb

# Size of the frame actually loaded (may be downsampled).
H_loaded, W_loaded = frame.shape[:2]
# Hard-coded original sensor resolution (before downsampling). cam_K is calibrated
# for this resolution, so pixel coordinates are rescaled between the two sizes.
W_orig, H_orig = 1920, 1080

# Setup robot for exoskeleton detection
robot_config = SO100AdhesiveConfig()
mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
mj_data = mujoco.MjData(mj_model)

# Load the robot joint state. With --use_camera_pose it is read live from the arm motors
# and cached to disk; otherwise the cached state from a previous live run is reused.
if args.use_camera_pose:
    # Initialize robot connection
    print("Connecting to robot...")
    # The calibration offsets are a local pickle; only ever load trusted files here.
    with open("robot_models/arm_offsets/rescrew2_fromimg.pkl", 'rb') as offsets_file:
        arm = Arm(pickle.load(offsets_file))
    print("✓ Connected to robot for direct joint state reading")
    print("\nReading current robot state from motors...")
    current_joint_state = arm.get_pos()
    os.makedirs("pickplace_testing/tmpstorage", exist_ok=True)
    np.save("pickplace_testing/tmpstorage/current_joint_state.npy", current_joint_state)
    joint_state_source = "arm motors"
else:
    current_joint_state = np.load("pickplace_testing/tmpstorage/current_joint_state.npy")
    joint_state_source = "pickplace_testing/tmpstorage/current_joint_state.npy"
mj_data.qpos[:] = current_joint_state
mujoco.mj_forward(mj_model, mj_data)
print(f"✓ Loaded robot joint state from {joint_state_source}")

# Detect exoskeleton and infer camera pose
print("Detecting exoskeleton to infer camera calibration...")
from exo_utils import detect_and_set_link_poses, estimate_robot_state

# Estimate per-link poses plus the camera extrinsics (camera_pose) and intrinsics (cam_K)
# from the exoskeleton visible in `rgb`. cam_K=None is passed in; presumably the helper
# then determines the intrinsics itself (TODO confirm). The returned cam_K is the one
# every later projection in this script uses.
(link_poses, camera_pose, cam_K,
 corners_cache, corners_vis, obj_img_pts) = detect_and_set_link_poses(
    rgb, mj_model, mj_data, robot_config, cam_K=None)
print("✓ Detected camera pose from exoskeleton")

# Compute median pick and place heights from dataset
print(f"\nComputing median pick and place heights from dataset: {args.dataset_dir}...")
dataset_dir = Path(args.dataset_dir)


def _find_grasp_ungrasp(gripper_values, threshold=0.5):
    """Locate a grasp and the release that follows it in a gripper trace.

    Args:
        gripper_values: Sequence of gripper readings; a value below `threshold` counts as closed.
        threshold: Open/closed decision boundary.

    Returns:
        (grasp_idx, ungrasp_idx): grasp_idx is the index of an open->closed transition
        (the latest one seen before a release is found). ungrasp_idx is the index of the
        first closed->open transition after a grasp. Either is None if no such transition
        exists.
    """
    grasp_idx = None
    for i in range(1, len(gripper_values)):
        prev, cur = gripper_values[i - 1], gripper_values[i]
        if cur < threshold <= prev:
            grasp_idx = i
        elif grasp_idx is not None and prev < threshold <= cur:
            return grasp_idx, i
    return grasp_idx, None


kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
pick_heights = []
place_heights = []

if dataset_dir.is_dir():
    episode_dirs = sorted(d for d in dataset_dir.iterdir() if d.is_dir() and d.name.startswith("episode_"))
else:
    # The default heights below exist for exactly this "no data" case, so fall back instead of crashing.
    print(f"⚠ Warning: dataset directory '{dataset_dir}' not found; falling back to default heights")
    episode_dirs = []

for episode_dir in episode_dirs[:20]:  # Sample first 20 episodes
    # Load the keypoint trajectory and gripper readings for every numbered frame that has
    # both a gripper pose and a joint state saved next to it.
    traj_3d = []
    gripper_values = []
    for frame_file in sorted(episode_dir.glob("*.png")):
        try:
            frame_num = int(frame_file.stem)
        except ValueError:
            continue  # skip PNGs whose name is not a frame number
        frame_str = f"{frame_num:06d}"
        gripper_pose_path = episode_dir / f"{frame_str}_gripper_pose.npy"
        joint_state_path = episode_dir / f"{frame_str}.npy"

        if not gripper_pose_path.exists() or not joint_state_path.exists():
            continue

        gripper_pose = np.load(gripper_pose_path)
        joint_state = np.load(joint_state_path)

        # Map the local gripper keypoint through the saved pose
        # (assumes gripper_pose is a homogeneous transform with rotation [:3, :3] and translation [:3, 3]).
        kp_3d = gripper_pose[:3, :3] @ kp_local + gripper_pose[:3, 3]
        traj_3d.append(kp_3d)
        # The last joint value is used as the gripper reading.
        gripper_values.append(float(joint_state[-1]))

    if len(traj_3d) < 3:
        continue

    traj_3d = np.array(traj_3d)
    grasp_idx, ungrasp_idx = _find_grasp_ungrasp(gripper_values)

    if grasp_idx is not None:
        pick_heights.append(traj_3d[grasp_idx][2])
    if ungrasp_idx is not None:
        place_heights.append(traj_3d[ungrasp_idx][2])

if pick_heights:
    median_pick_height = float(np.median(pick_heights))
else:
    median_pick_height = 0.12
    print("⚠ Warning: no grasp events found in dataset; using default pick height")
if place_heights:
    median_place_height = float(np.median(place_heights))
else:
    median_place_height = 0.16
    print("⚠ Warning: no release events found in dataset; using default place height")

print(f"✓ Computed median pick height: {median_pick_height:.4f}m")
print(f"✓ Computed median place height: {median_place_height:.4f}m")

# Get pick and place locations from user or use saved points
if args.use_saved_pts:
    # Hard-coded points from an earlier successful run. Like the clicks below, they are in
    # loaded-image pixel coordinates and are rescaled to the calibrated resolution before
    # being lifted to 3D. They only make sense if the frame size matches that earlier run.
    pick_2d = np.array([1024.5, 530.2])
    place_2d = np.array([1201.0, 304.1])

    print("\nUsing saved pick and place locations (loaded-image coordinates):")
    print(f"  Pick location (image coords): ({pick_2d[0]:.1f}, {pick_2d[1]:.1f})")
    print(f"  Place location (image coords): ({place_2d[0]:.1f}, {place_2d[1]:.1f})")
    for point_name, point in (("pick", pick_2d), ("place", place_2d)):
        if not (0 <= point[0] < W_loaded and 0 <= point[1] < H_loaded):
            print(f"⚠ Warning: saved {point_name} point lies outside the {W_loaded}x{H_loaded} frame; "
                  "it was probably recorded at a different resolution.")
else:
    print("\nClick on the image to select:")
    print("  1. Pick location (first click)")
    print("  2. Place location (second click)")
    print("Close the window when done.")

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    ax.imshow(rgb)
    ax.set_title("Click: 1) Pick location, 2) Place location")
    ax.axis('off')

    # Get clicks (timeout=0 waits indefinitely; closing the window ends input early)
    clicks = plt.ginput(2, timeout=0)
    plt.close()

    if len(clicks) < 2:
        print("Error: Need 2 clicks (pick and place)")
        sys.exit(1)

    pick_2d = np.array([clicks[0][0], clicks[0][1]])
    place_2d = np.array([clicks[1][0], clicks[1][1]])

    print(f"Pick location (image coords): ({pick_2d[0]:.1f}, {pick_2d[1]:.1f})")
    print(f"Place location (image coords): ({place_2d[0]:.1f}, {place_2d[1]:.1f})")

    # Check if pick and place are too close
    distance = np.linalg.norm(pick_2d - place_2d)
    if distance < 5.0:
        print(f"\n⚠ Warning: Pick and place locations are very close ({distance:.1f} pixels apart)")
        print("This might be an accidental double-click. Consider re-running.")

# Convert 2D locations to 3D using median heights
# Rescale from loaded-image pixels to the resolution cam_K is calibrated for
loaded_to_orig = np.array([W_orig / W_loaded, H_orig / H_loaded])
pick_2d_orig = pick_2d * loaded_to_orig
place_2d_orig = place_2d * loaded_to_orig

print(f"\nLifting 2D to 3D:")
print(f"Pick: 2D=({pick_2d_orig[0]:.1f}, {pick_2d_orig[1]:.1f}), height={median_pick_height:.4f}m (median from grasp locations)")
print(f"Place: 2D=({place_2d_orig[0]:.1f}, {place_2d_orig[1]:.1f}), height={median_place_height:.4f}m (median from ungrasp locations)")

# Lift each pixel to 3D at the given height using the estimated camera pose and intrinsics.
# The helper returns None on failure.
pick_3d = recover_3d_from_direct_keypoint_and_height(pick_2d_orig, median_pick_height, camera_pose, cam_K)
place_3d = recover_3d_from_direct_keypoint_and_height(place_2d_orig, median_place_height, camera_pose, cam_K)

if pick_3d is None or place_3d is None:
    print("Error: Failed to recover 3D locations")
    # sys.exit, not the interactive-only `exit` builtin injected by the site module.
    sys.exit(1)

print(f"Pick location (3D): ({pick_3d[0]:.4f}, {pick_3d[1]:.4f}, {pick_3d[2]:.4f}) [using median_pick_height={median_pick_height:.4f}m]")
print(f"Place location (3D): ({place_3d[0]:.4f}, {place_3d[1]:.4f}, {place_3d[2]:.4f}) [using median_place_height={median_place_height:.4f}m]")

# Reprojection check: project the lifted 3D points back into the image and compare them with
# the 2D inputs, in loaded-image pixels. A failed projection (None) is reported instead of crashing.
orig_to_loaded = np.array([W_loaded / W_orig, H_loaded / H_orig])
pick_2d_reproj = project_3d_to_2d(pick_3d, camera_pose, cam_K)
place_2d_reproj = project_3d_to_2d(place_3d, camera_pose, cam_K)

pick_2d_reproj_loaded = None
place_2d_reproj_loaded = None
pick_reproj_error = 0.0
place_reproj_error = 0.0

if pick_2d_reproj is not None:
    pick_2d_reproj_loaded = pick_2d_reproj * orig_to_loaded
    pick_reproj_error = np.linalg.norm(pick_2d - pick_2d_reproj_loaded)

if place_2d_reproj is not None:
    place_2d_reproj_loaded = place_2d_reproj * orig_to_loaded
    place_reproj_error = np.linalg.norm(place_2d - place_2d_reproj_loaded)


def _reproj_summary(label, original_2d, reproj_2d, error):
    """Format one line of the reprojection report; reproj_2d may be None (projection failed)."""
    prefix = f"{label} - Original: ({original_2d[0]:.1f}, {original_2d[1]:.1f}), "
    if reproj_2d is None:
        return prefix + "Reprojected: failed (project_3d_to_2d returned None)"
    return prefix + f"Reprojected: ({reproj_2d[0]:.1f}, {reproj_2d[1]:.1f}), Error: {error:.2f} pixels"


print("\nReprojection check:")
print(_reproj_summary("Pick", pick_2d, pick_2d_reproj_loaded, pick_reproj_error))
print(_reproj_summary("Place", place_2d, place_2d_reproj_loaded, place_reproj_error))

if pick_reproj_error > 10.0 or place_reproj_error > 10.0:
    print("⚠ Warning: Large reprojection error detected. Check camera calibration or height estimates.")
if pick_2d_reproj_loaded is None or place_2d_reproj_loaded is None:
    print("⚠ Warning: Reprojection failed for at least one point. Check the estimated camera pose.")

# Robot and camera pose already initialized above during exoskeleton detection
# Just need to update forward kinematics after setting joint state
mujoco.mj_forward(mj_model, mj_data)

# Fixed target gripper orientation (3x3 rotation matrix) used as the IK rotation constraint.
# As the name says, it is a median orientation taken from the dataset.
median_dataset_rotation = np.array([[-0.99912433, -0.03007201, -0.02909046],
                                    [-0.04176828,  0.67620482,  0.73552869],
                                    [-0.00244771,  0.73609967, -0.67686874]])

# Separate MuJoCo model/data pair, built from the same robot XML, used for IK and rendering.
xml_root = ET.fromstring(robot_config.xml)
mj_model_viz = mujoco.MjModel.from_xml_string(ET.tostring(xml_root, encoding='unicode'))
mj_data_viz = mujoco.MjData(mj_model_viz)
ik_configuration = mink.Configuration(mj_model_viz)

# Update link poses
link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)
ik_configuration.update(mj_data.qpos)

# Get current EEF position (virtual_gripper_keypoint)
kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
# mj_name2id returns -1 for an unknown name, and xpos[-1] would silently read the last body.
if kp_body_id < 0:
    raise RuntimeError("Body 'virtual_gripper_keypoint' not found in the robot model")
current_eef_3d = mj_data.xpos[kp_body_id].copy()
print(f"Current EEF position (3D): ({current_eef_3d[0]:.4f}, {current_eef_3d[1]:.4f}, {current_eef_3d[2]:.4f})")

# Generate waypoint trajectory with arc motion
def generate_arc_waypoints(start_3d, end_3d, num_waypoints, arc_height_offset):
    """Sample a straight start->end path lifted by a parabolic arc along z (index 2).

    Args:
        start_3d: Start point (array-like supporting arithmetic with end_3d).
        end_3d: End point.
        num_waypoints: Number of samples. t runs evenly over [0, 1], or just t=0 when
            num_waypoints is 1.
        arc_height_offset: Extra height added at the midpoint (t=0.5). The lift is
            4 * arc_height_offset * t * (1 - t), so it is zero at both ends.

    Returns:
        np.ndarray with one row per waypoint (empty array when num_waypoints <= 0).
    """
    def _point_at(idx):
        frac = idx / (num_waypoints - 1) if num_waypoints > 1 else 0.0
        point = start_3d + frac * (end_3d - start_3d)
        point[2] += 4 * arc_height_offset * frac * (1 - frac)
        return point

    return np.array([_point_at(idx) for idx in range(num_waypoints)])

# Generate full waypoint trajectory: current → pick → place
print(f"\nGenerating waypoint trajectory with {args.waypoint_density} waypoints per segment, arc_height={args.arc_height}m...")
segment_settings = (args.waypoint_density, args.arc_height)
waypoints_to_pick = generate_arc_waypoints(current_eef_3d, pick_3d, *segment_settings)
waypoints_to_place = generate_arc_waypoints(pick_3d, place_3d, *segment_settings)

# The second segment starts at pick_3d, which already ends the first segment; drop that duplicate.
full_waypoint_trajectory = np.concatenate((waypoints_to_pick, waypoints_to_place[1:]))
print(f"✓ Generated {len(full_waypoint_trajectory)} total waypoints")

# Per-waypoint gripper state: 1.0 = open, 0.0 = closed. The waypoint_density entries starting
# at the pick waypoint (index waypoint_density - 1) are closed. With the 2*waypoint_density - 1
# waypoints built above, that span runs through the final (place) waypoint.
# NOTE(review): an earlier comment said "open at place", but the place waypoint is marked
# closed here. This array is not read anywhere else in this script.
pick_idx = args.waypoint_density - 1
gripper_states = np.ones(len(full_waypoint_trajectory))
gripper_states[pick_idx:pick_idx + args.waypoint_density] = 0.0

print("\nProjecting waypoints to 2D for visualization...")
# Project each 3D waypoint with the estimated camera, then map calibrated-resolution pixels
# to loaded-image pixels. Waypoints the projection rejects (None) are dropped.
orig_to_loaded_scale = np.array([W_loaded / W_orig, H_loaded / H_orig])
projected_uv = [project_3d_to_2d(point, camera_pose, cam_K) for point in full_waypoint_trajectory]
waypoints_2d = np.array([uv * orig_to_loaded_scale for uv in projected_uv if uv is not None])
print(f"✓ Projected {len(waypoints_2d)} waypoints to 2D")

# Run IK for all waypoints
ik_joint_positions = []

print("\nRunning IK on waypoints...")
# Update IK configuration from current state
link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)
ik_configuration.update(mj_data.qpos)

# Seed the IK/visualization data with the robot's current joint state. Each iteration below
# reloads ik_configuration from mj_data_viz.qpos before solving. Without this seed, the first
# solve would start from the freshly created MjData's default qpos instead of the arm's actual
# pose. Later solves warm-start from the previous result left in mj_data_viz.qpos.
mj_data_viz.qpos[:] = current_joint_state
mujoco.mj_forward(mj_model_viz, mj_data_viz)

# Loop-invariant lookup. mj_name2id returns -1 for an unknown name (xpos[-1] would read the wrong body).
kp_body_id = mujoco.mj_name2id(mj_model_viz, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
if kp_body_id < 0:
    raise RuntimeError("Body 'virtual_gripper_keypoint' not found in the IK model")

for i, target_pos in enumerate(full_waypoint_trajectory):

    target_gripper_rot = median_dataset_rotation

    # Update IK configuration from current state before solving
    link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
    position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
    mujoco.mj_forward(mj_model_viz, mj_data_viz)
    ik_configuration.update(mj_data_viz.qpos)

    # Solve IK with median rotation constraint
    ik_to_keypoint_and_rotation(target_pos, target_gripper_rot, ik_configuration, robot_config, mj_model_viz, mj_data_viz, max_iterations=30)

    # Update link poses and forward kinematics after IK
    link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
    position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
    mujoco.mj_forward(mj_model_viz, mj_data_viz)

    # Save joint positions after IK (including gripper)
    ik_joint_positions.append(mj_data_viz.qpos.copy())

    # Verify IK result against the target keypoint position
    achieved_kp_pos = mj_data_viz.xpos[kp_body_id].copy()
    ik_error = np.linalg.norm(achieved_kp_pos - target_pos)
    waypoint_type = "Current→Pick" if i < args.waypoint_density else "Pick→Place"
    print(f"  Waypoint {i}/{len(full_waypoint_trajectory)-1} ({waypoint_type}): "
            f"Target=[{target_pos[0]:.4f}, {target_pos[1]:.4f}, {target_pos[2]:.4f}], "
            f"Achieved=[{achieved_kp_pos[0]:.4f}, {achieved_kp_pos[1]:.4f}, {achieved_kp_pos[2]:.4f}], "
            f"IK Error={ik_error:.4f}m")

# Append the starting joint state as an extra final pose after the place waypoint.
ik_joint_positions.append(current_joint_state)

print(f"✓ Completed IK for {len(full_waypoint_trajectory)} waypoints")

# Render robot poses for all waypoints
print("\nRendering robot poses for all waypoints...")
# Render at 1/10 of the calibrated resolution for speed. The first two rows of the
# intrinsics are scaled to match, and each render is upsampled back to the loaded frame size.
cam_K_for_render = cam_K.copy()
cam_K_for_render[:2] /= 10

rendered_images = []
waypoint_labels = []

pick_idx = args.waypoint_density - 1
# NOTE: place_idx indexes the LAST entry of ik_joint_positions, which is the appended starting
# joint state, not the place waypoint. The robot write-back section uses place_idx with exactly
# this value, so it is left unchanged; labels use place_waypoint_idx / return_idx instead.
place_idx = len(ik_joint_positions) - 1
# The place waypoint is the last IK-solved waypoint of the trajectory.
place_waypoint_idx = len(full_waypoint_trajectory) - 1
return_idx = len(ik_joint_positions) - 1

if args.render:
    for i, joint_pos in enumerate(ik_joint_positions):
        # Set joint positions
        mj_data.qpos[:] = joint_pos
        mj_data.ctrl[:] = joint_pos[:len(mj_data.ctrl)]
        mujoco.mj_forward(mj_model, mj_data)

        # Update link poses
        link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
        position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
        mujoco.mj_forward(mj_model, mj_data)

        # Render from camera pose
        rendered = render_from_camera_pose(mj_model, mj_data, camera_pose, cam_K_for_render, H_orig//10, W_orig//10)
        # Resize to match loaded image
        rendered_resized = cv2.resize(rendered, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)
        rendered_images.append(rendered_resized)

        # Generate label
        if i == 0:
            label = "Start"
        elif i == pick_idx:
            label = "Pick"
        elif i == place_waypoint_idx:
            label = "Place"
        elif i == return_idx:
            label = "Return to start"
        else:
            label = f"t={i}"
        waypoint_labels.append(label)

    print(f"✓ Rendered {len(rendered_images)} waypoint poses")

    # Display results: row 0 shows the trajectory overview, later rows show one overlay per pose.
    num_waypoints = len(rendered_images)
    cols = min(5, num_waypoints)
    rows = 1 + (num_waypoints + cols - 1) // cols

    # squeeze=False keeps `axes` a 2-D array for every grid shape (including a single column).
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)

    # 1. Original image with waypoint trajectory (first image in first row)
    axes[0, 0].imshow(rgb)
    # Plot waypoint trajectory
    if len(waypoints_2d) > 0:
        axes[0, 0].plot(waypoints_2d[:, 0], waypoints_2d[:, 1], 'g-', linewidth=2, alpha=0.7, label='Waypoint trajectory')
        # Mark current, pick, place with different colors
        current_2d_proj = project_3d_to_2d(current_eef_3d, camera_pose, cam_K)
        if current_2d_proj is not None:
            current_2d_loaded = current_2d_proj * np.array([W_loaded / W_orig, H_loaded / H_orig])
            axes[0, 0].plot(current_2d_loaded[0], current_2d_loaded[1], 'go', markersize=12, label='Start', markeredgecolor='white', markeredgewidth=2)
    axes[0, 0].plot(pick_2d[0], pick_2d[1], 'ro', markersize=15, label='Pick', markeredgecolor='white', markeredgewidth=2)
    axes[0, 0].plot(place_2d[0], place_2d[1], 'bo', markersize=15, label='Place', markeredgecolor='white', markeredgewidth=2)
    if pick_2d_reproj_loaded is not None:
        axes[0, 0].plot(pick_2d_reproj_loaded[0], pick_2d_reproj_loaded[1], 'r+', markersize=20, markeredgewidth=3, alpha=0.7)
    if place_2d_reproj_loaded is not None:
        axes[0, 0].plot(place_2d_reproj_loaded[0], place_2d_reproj_loaded[1], 'b+', markersize=20, markeredgewidth=3, alpha=0.7)
    axes[0, 0].set_title(f'Waypoint Trajectory ({len(waypoints_2d)} waypoints)\nGreen line = 3D waypoint path projected to 2D', fontsize=10)
    axes[0, 0].legend(fontsize=8)
    axes[0, 0].axis('off')

    # Hide remaining cells in first row
    for col_idx in range(1, cols):
        axes[0, col_idx].axis('off')

    # 2. Render all waypoint overlays in subsequent rows
    for i, (rendered, label) in enumerate(zip(rendered_images, waypoint_labels)):
        row_idx = 1 + i // cols
        col_idx = i % cols

        # Create overlay
        overlay = (rgb * 0.5 + rendered * 0.5).astype(np.uint8)
        axes[row_idx, col_idx].imshow(overlay)
        axes[row_idx, col_idx].set_title(f'Overlay: {label}', fontsize=10)
        axes[row_idx, col_idx].axis('off')

    # Hide the unused cells in the last row (none when the last row is full)
    remainder = num_waypoints % cols
    if remainder:
        for col_idx in range(remainder, cols):
            axes[rows - 1, col_idx].axis('off')

    plt.tight_layout()
    plt.savefig('live_test_pickplace_manual_result.png', dpi=150, bbox_inches='tight')
    print("\n✓ Saved result to live_test_pickplace_manual_result.png")
    plt.show()

# Stream the IK joint targets to the physical arm, waiting for each move to settle.
# Requires the live `arm` connection, which is only created when --use_camera_pose is passed;
# without it this section fails with a NameError.
# The gripper value is the last joint entry: 1 is commanded while approaching pick and at the
# end, -.2 between pick and place (presumably open / closed respectively — TODO confirm).
# NOTE(review): the entries of ik_joint_positions are modified in place (targ_pos[-1] = ...).
if args.write_robot_state:
    targ_gripper_pos=1  # NOTE(review): unused
    # [:-1] skips the final entry of ik_joint_positions (the appended starting joint state),
    # so the arm is never commanded back to its starting pose here.
    for i, targ_pos in enumerate(ik_joint_positions[:-1]):
        print(targ_pos, )
        if i<=pick_idx: targ_pos[-1]=1  # up to and including the pick waypoint
        # NOTE(review): place_idx == len(ik_joint_positions) - 1 (set in the rendering section) is
        # larger than any i in this loop, so this branch covers every waypoint after pick,
        # including the place waypoint, and the `else` below is unreachable.
        elif i<place_idx: targ_pos[-1]=-.2
        else: targ_pos[-1]=1
        last_pos=arm.get_pos()
        arm.write_pos(targ_pos,slow=False)
        # The command is written once. This loop polls until two consecutive readings differ by
        # < 0.02 (max abs over joints), i.e. the arm appears to have stopped moving.
        # NOTE(review): if the arm has not started moving by the first read, this can exit
        # immediately — verify on hardware.
        while True:
            curr_pos=arm.get_pos()
            curr_delta=np.max(np.abs(curr_pos-last_pos))
            if curr_delta<0.02: break
            last_pos=curr_pos
        # After reaching the pick waypoint, change the gripper value to -.2 in place and wait again
        # (tighter 0.01 settle threshold). The i==place_idx case never occurs in this loop (see above).
        if i==pick_idx or i==place_idx:
            targ_pos[-1]=-.2 if i==pick_idx else 1
            last_pos=arm.get_pos()
            arm.write_pos(targ_pos,slow=False)
            while True: # poll until consecutive readings differ by < 0.01
                curr_pos=arm.get_pos()
                curr_delta=np.max(np.abs(curr_pos-last_pos))
                if curr_delta<0.01: break
                last_pos=curr_pos
    # At the final (place) waypoint, command the gripper value back to 1 and give it a second.
    # This is where the object is actually released, since the loop never does it.
    # NOTE(review): `targ_pos` is undefined here if the loop above ran zero times.
    targ_pos[-1]=1
    arm.write_pos(targ_pos,slow=False)
    time.sleep(1)


print("\n✓ Done!")
