"""Test 3D lifting: Load GT groundplane trajectory and height, visualize in 2D, then verify 3D lifting."""
import sys
import os
from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../"))
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX
from utils import project_3d_to_2d, rescale_coords, recover_3d_from_keypoint_and_height
from model import MIN_HEIGHT, MAX_HEIGHT

# Constants
RES_LOW = 256   # side length (px) of the square low-res images used for plotting
H_orig = 1080   # original camera image height (px) — intrinsics are calibrated for this
W_orig = 1920   # original camera image width (px)

# ---------------------------------------------------------------------------
# Command-line interface and episode directory resolution.
# ---------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser(description="Test 3D lifting with GT groundplane and height trajectories")
parser.add_argument("--dataset_dir", "-d", default="scratch/parsed_judy_train", type=str, help="Dataset directory")
parser.add_argument("--episode_id", "-e", default=1, type=int, help="Episode ID")
parser.add_argument("--start_frame", "--sf", default=0, type=int, help="Start frame index")
args = parser.parse_args()

dataset_dir = Path(args.dataset_dir)
episode_dir = dataset_dir / f"episode_{args.episode_id:03d}"
episode_id = episode_dir.name

# Collect the episode's frame images (numeric stems only) and validate the
# requested start index before doing any heavy work.
frame_files = sorted(f for f in episode_dir.glob("*.png") if f.stem.isdigit())
if not frame_files:
    print(f"No frames found in {episode_dir}")
    exit(1)

start_idx = args.start_frame
if start_idx >= len(frame_files):
    print(f"Start frame {start_idx} out of range (max: {len(frame_files)-1})")
    exit(1)

# Zero-padded frame identifier shared by the sibling .npy files.
start_frame_file = frame_files[start_idx]
start_frame_str = f"{int(start_frame_file.stem):06d}"

# Load the start frame's RGB image. cv2.imread returns None on a missing or
# unreadable file, which would otherwise surface as an opaque cvtColor error,
# so fail with a clear message instead.
bgr = cv2.imread(str(start_frame_file))
if bgr is None:
    print(f"Failed to read image {start_frame_file}")
    exit(1)
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
# Some pipelines store float images in [0, 1]; rescale those to uint8.
if rgb.max() <= 1.0:
    rgb = (rgb * 255).astype(np.uint8)
H_loaded, W_loaded = rgb.shape[:2]

# Resize RGB to original resolution for visualization consistency
rgb_resized = cv2.resize(rgb, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)
rgb_lowres = cv2.resize(rgb, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

# Per-frame camera extrinsics (4x4 pose) and intrinsics (3x3 K) live next to
# the image; both must exist for projection and lifting to work.
camera_pose_path = episode_dir / f"{start_frame_str}_camera_pose.npy"
cam_K_path = episode_dir / f"{start_frame_str}_cam_K.npy"
if not (camera_pose_path.exists() and cam_K_path.exists()):
    print(f"Missing camera pose or intrinsics for frame {start_frame_str}")
    exit(1)

camera_pose = np.load(camera_pose_path)
cam_K = np.load(cam_K_path)

# Build the GT 3D keypoint trajectory: transform the fixed local keypoint
# through each frame's gripper pose, for all frames after the start frame.
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
trajectory_gt_3d = []
for frame_file in frame_files[start_idx + 1:]:
    frame_str_file = f"{int(frame_file.stem):06d}"
    pose_path = episode_dir / f"{frame_str_file}_gripper_pose.npy"
    if not pose_path.exists():
        continue
    pose = np.load(pose_path)
    # World-frame keypoint from the 4x4 pose: R @ p_local + t.
    trajectory_gt_3d.append(pose[:3, :3] @ kp_local + pose[:3, 3])

if not trajectory_gt_3d:
    print("No valid trajectory points found")
    exit(1)

trajectory_gt_3d = np.array(trajectory_gt_3d)

# Flatten onto the ground plane by zeroing the height axis (index 2 in this
# dataset's convention — see the height normalization below).
trajectory_groundplane_3d = trajectory_gt_3d.copy()
trajectory_groundplane_3d[:, 2] = 0.0

# Project each flattened (groundplane) point into the image; alongside each
# successful projection, record the normalized height of the corresponding
# original 3D keypoint (taken before the groundplane flattening).
trajectory_points_orig = []
heights_gt = []
for kp_gp, kp_orig in zip(trajectory_groundplane_3d, trajectory_gt_3d):
    kp_2d_image = project_3d_to_2d(kp_gp, camera_pose, cam_K)
    if kp_2d_image is None:
        continue
    trajectory_points_orig.append(kp_2d_image)
    heights_gt.append(np.clip((kp_orig[2] - MIN_HEIGHT) / (MAX_HEIGHT - MIN_HEIGHT), 0.0, 1.0))

trajectory_points_orig = np.array(trajectory_points_orig) if trajectory_points_orig else None
heights_gt = np.array(heights_gt) if heights_gt else None

if trajectory_points_orig is None or len(trajectory_points_orig) == 0:
    print("No valid trajectory points found")
    exit(1)

# Denormalize heights back to meters.
heights_gt_denorm = heights_gt * (MAX_HEIGHT - MIN_HEIGHT) + MIN_HEIGHT

# Rescale trajectory to low-res for visualization
trajectory_points_lowres = rescale_coords(trajectory_points_orig, H_orig, W_orig, RES_LOW, RES_LOW)

# Round-trip test: lift each (2D groundplane point, height-in-meters) pair
# back to a 3D world point.
trajectory_lifted_3d = []
for kp_2d, height in zip(trajectory_points_orig, heights_gt_denorm):
    lifted = recover_3d_from_keypoint_and_height(kp_2d, height, camera_pose, cam_K)
    if lifted is not None:
        trajectory_lifted_3d.append(lifted)

trajectory_lifted_3d = np.array(trajectory_lifted_3d) if trajectory_lifted_3d else None

# Report the round-trip error between lifted and GT 3D points.
# NOTE(review): this pairs trajectory_lifted_3d[i] with trajectory_gt_3d[i],
# which is only valid if no points were dropped by project_3d_to_2d or by the
# lifting step; any dropped point misaligns the two arrays — confirm for the
# datasets this script is run on.
if trajectory_lifted_3d is not None and len(trajectory_lifted_3d) > 0:
    lifting_errors = np.linalg.norm(trajectory_lifted_3d - trajectory_gt_3d[:len(trajectory_lifted_3d)], axis=1)
    print(f"3D Lifting Verification:")
    print(f"  Mean error: {np.mean(lifting_errors):.6f}m")
    print(f"  Max error: {np.max(lifting_errors):.6f}m")
    print(f"  Min error: {np.min(lifting_errors):.6f}m")
    # The guard above already ensures at least one entry, so no inner
    # length check is needed before indexing [0].
    print(f"  First point error: {lifting_errors[0]:.6f}m")
    print(f"    GT 3D: {trajectory_gt_3d[0]}")
    print(f"    Lifted 3D: {trajectory_lifted_3d[0]}")
    print(f"    GT 2D groundplane: {trajectory_points_orig[0]}")
    print(f"    GT height: {heights_gt_denorm[0]:.4f}m")

# Dense groundplane coordinate map (mirrors render_dense_groundplane.py):
# cast a ray through every pixel, intersect it with the height-zero plane,
# and color each pixel by where its ray lands on that plane.
print("Computing dense groundplane coordinate map...")
vv, uu = np.mgrid[0:H_orig, 0:W_orig]  # row (v) and column (u) indices, (H, W)

# Homogeneous pixel coordinates, then back-project through K^-1 into
# camera-frame ray directions.
pixels_h = np.stack([uu.flatten(), vv.flatten(), np.ones(H_orig * W_orig)], axis=1).T  # (3, N)
K_inv = np.linalg.inv(cam_K)
rays_cam = (K_inv @ pixels_h).T  # (N, 3)

# Rotate rays into world coordinates; the camera center is the translation of
# the inverted pose (camera_pose is assumed world->camera — TODO confirm
# against project_3d_to_2d's convention).
cam_pose_inv = np.linalg.inv(camera_pose)
rays_world = (cam_pose_inv[:3, :3] @ rays_cam.T).T  # (N, 3) ray directions
cam_pos = cam_pose_inv[:3, 3]  # (3,) camera center in world frame

# Ray/plane intersection with the plane {height axis (index 2) == 0}:
# solve cam_pos[2] + t * dir[2] == 0 for t.
ray_dir_z = rays_world[:, 2]
t = (0.0 - cam_pos[2]) / (ray_dir_z + 1e-10)  # epsilon guards near-zero directions

# Keep only intersections in front of the camera and rays that are not
# (numerically) parallel to the plane.
valid_mask_flat = (t > 0) & (np.abs(ray_dir_z) > 1e-6)

points_3d_flat = cam_pos[None, :] + t[:, None] * rays_world  # (N, 3)

# Back to image layout.
points_3d = points_3d_flat.reshape(H_orig, W_orig, 3)
valid_mask = valid_mask_flat.reshape(H_orig, W_orig)

# The two in-plane coordinates are indices 0 and 1.
x_coords = points_3d[..., 0]
z_coords = points_3d[..., 1]

# Normalize both coordinates to [0, 1] over a fixed +-1 m window.
grid_range = 1.0
x_norm = np.clip((x_coords + grid_range) / (2 * grid_range + 1e-6), 0, 1)
z_norm = np.clip((z_coords + grid_range) / (2 * grid_range + 1e-6), 0, 1)

# Color encoding of the plane position:
#   red fades with x, blue fades with z, green grows with both.
coord_colors = np.stack([
    1.0 - x_norm * 0.5,
    np.clip(x_norm + z_norm, 0, 1),
    1.0 - z_norm,
], axis=2)  # (H_orig, W_orig, 3)

# Pixels whose rays miss the plane render as white.
coord_colors[~valid_mask] = 1.0

# Downsample for the low-res visualization panels.
coord_colors_lowres = cv2.resize(coord_colors, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)

# ---------------------------------------------------------------------------
# Summary figure: three image panels on top, height bar chart on the bottom.
# ---------------------------------------------------------------------------
fig = plt.figure(figsize=(20, 12))
gs = GridSpec(2, 3, figure=fig, hspace=0.3, wspace=0.3, height_ratios=[1, 0.4])

ax_rgb = fig.add_subplot(gs[0, 0])
ax_groundplane = fig.add_subplot(gs[0, 1])
ax_overlay = fig.add_subplot(gs[0, 2])
ax_height = fig.add_subplot(gs[1, :])  # height chart spans the full bottom row


def _draw_trajectory(ax, pts, line_width, line_alpha, marker_size, edge_width):
    """Draw the low-res trajectory: white polyline plus viridis-colored dots."""
    ax.plot(pts[:, 0], pts[:, 1], 'w-', linewidth=line_width, alpha=line_alpha,
            label='GT Groundplane Traj', zorder=10)
    for j, (px, py) in enumerate(pts):
        ax.plot(px, py, 'o', color=plt.cm.viridis(j / len(pts)),
                markersize=marker_size, markeredgecolor='white',
                markeredgewidth=edge_width, zorder=11)


# Panel 1: RGB image with the projected groundplane trajectory.
ax_rgb.imshow(rgb_lowres)
if len(trajectory_points_lowres) > 0:
    _draw_trajectory(ax_rgb, trajectory_points_lowres, 2, 0.8, 6, 1)
ax_rgb.set_title("RGB with Groundplane Trajectory", fontsize=12)
ax_rgb.legend(loc='upper right', fontsize=9)
ax_rgb.axis('off')

# Panel 2: dense groundplane coordinate map.
ax_groundplane.imshow(coord_colors_lowres)
ax_groundplane.set_title("Dense Groundplane Coordinate Map", fontsize=12)
ax_groundplane.axis('off')

# Panel 3: 50/50 blend of RGB and coordinate map, with the trajectory on top.
rgb_normalized = rgb_lowres.astype(np.float32) / 255.0
overlay = 0.5 * rgb_normalized + 0.5 * coord_colors_lowres
ax_overlay.imshow(overlay)
if len(trajectory_points_lowres) > 0:
    _draw_trajectory(ax_overlay, trajectory_points_lowres, 3, 0.9, 8, 2)
ax_overlay.set_title("Overlay: RGB + Groundplane + Trajectory", fontsize=12)
ax_overlay.legend(loc='upper right', fontsize=9)
ax_overlay.axis('off')

# Bottom row: per-timestep GT height (meters) as a bar chart.
timesteps = np.arange(1, len(heights_gt_denorm) + 1)
ax_height.bar(timesteps, heights_gt_denorm, width=0.6, alpha=0.7, color='green', label='GT Height')
ax_height.set_xlabel('Timestep', fontsize=10)
ax_height.set_ylabel('Height (m)', fontsize=10)
ax_height.set_title('GT Height Trajectory', fontsize=12)
ax_height.legend(fontsize=9)
ax_height.grid(alpha=0.3)

# Footer: mean lifting error, when the round-trip produced points.
if trajectory_lifted_3d is not None and len(trajectory_lifted_3d) > 0:
    summary_text = f"3D Lifting: Mean Error = {np.mean(lifting_errors):.6f}m"
    fig.text(0.5, 0.02, summary_text, ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()
output_path = Path(f'keypoint_testing2/test_scripts/test_ik_lifting_{episode_id}_frame{start_idx}.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"✓ Saved 2D visualization to {output_path}")
plt.show()

print("✓ 2D Visualization complete!")

# MuJoCo visualization (from test_ik.py)
# NOTE: these imports are deferred to here so the 2D-only portion of the
# script can run without the MuJoCo/mink stack installed.
print("\nSetting up MuJoCo visualization...")
import mujoco
import xml.etree.ElementTree as ET
from exo_utils import render_from_camera_pose
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes, get_link_poses_from_robot
import mink

# Add GT and lifted trajectory spheres to XML for comparison
# Parse the robot's MJCF so trajectory marker sites can be injected before
# the model is compiled.
robot_config = SO100AdhesiveConfig()
xml_root = ET.fromstring(robot_config.xml)
worldbody = xml_root.find('worldbody')

# GT trajectory spheres (green -> blue) are intentionally disabled: the old
# loop only computed per-point colors and its ET.SubElement call was commented
# out, so it had no effect and has been removed. To re-enable, add for each
# point i of trajectory_gt_3d a site like:
#   ET.SubElement(worldbody, 'site', {
#       'name': f'gt_kp_{i}', 'type': 'sphere', 'size': '0.015',
#       'pos': f'{p[0]} {p[1]} {p[2]}', 'rgba': f'0 {green} {blue} 0.8'})

# Lifted trajectory spheres (red to yellow) - from 2D groundplane + height
if trajectory_lifted_3d is not None and len(trajectory_lifted_3d) > 0:
    for i, kp_pos in enumerate(trajectory_lifted_3d):
        # Color ramps from red toward yellow along the trajectory.
        red = 1.0 - (i / max(len(trajectory_lifted_3d) - 1, 1)) * 0.5
        green = i / max(len(trajectory_lifted_3d) - 1, 1) * 0.5
        ET.SubElement(worldbody, 'site', {
            'name': f'lifted_kp_{i}', 'type': 'sphere', 'size': '0.015',
            'pos': f'{kp_pos[0]} {kp_pos[1]} {kp_pos[2]}', 'rgba': f'{red} {green} 0 0.8'
        })

# Compile the augmented MJCF (robot plus injected trajectory sites).
mj_model = mujoco.MjModel.from_xml_string(ET.tostring(xml_root, encoding='unicode'))
mj_data = mujoco.MjData(mj_model)

# Estimate robot state from image
print("Estimating robot state from image...")
link_poses, camera_pose_world, cam_K_estimated, _, _, _ = detect_and_set_link_poses(rgb_resized, mj_model, mj_data, robot_config)
configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=55)
mj_data.qpos[:] = configuration.q
mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
mujoco.mj_forward(mj_model, mj_data)
position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
mujoco.mj_forward(mj_model, mj_data)

# Use saved camera pose and intrinsics (more accurate than estimated)
# NOTE: this deliberately discards the estimates returned above
# (camera_pose_world / cam_K_estimated) in favor of the on-disk calibration.
camera_pose_world = camera_pose
cam_K_for_render = cam_K

# Setup IK configuration
# Seed the mink IK configuration from the estimated joint state.
ik_configuration = mink.Configuration(mj_model)
ik_configuration.update(mj_data.qpos)

# IK function to follow keypoint trajectory
def ik_to_keypoint(target_pos, configuration, robot_config, mj_model, mj_data):
    """Drive the robot so the virtual gripper keypoint reaches ``target_pos``.

    Runs a fixed number of mink IK iterations. The keypoint's current
    orientation is used as the orientation target on every iteration, so only
    the position is meaningfully constrained (orientation cost is low).
    Mutates ``mj_data`` (qpos/ctrl, stepped forward) and ``configuration``
    in place; returns None.

    Args:
        target_pos: (3,) world-frame target position for the keypoint.
        configuration: mink.Configuration tracking the robot state.
        robot_config: project robot config used to reposition exoskeleton meshes.
        mj_model / mj_data: MuJoCo model and data, updated in place.
    """
    # Loop invariants hoisted out of the 50-iteration loop: the body id, the
    # task objects (only their targets change per iteration), and the limits.
    kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
    kp_task = mink.FrameTask("virtual_gripper_keypoint", "body", position_cost=1.0, orientation_cost=0.1)
    posture_task = mink.PostureTask(mj_model, cost=1e-3)
    limits = [mink.ConfigurationLimit(model=mj_model)]
    for _ in range(50):
        link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
        position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
        mujoco.mj_forward(mj_model, mj_data)
        configuration.update(mj_data.qpos)
        # mj_data.xquat already stores (w, x, y, z); the previous
        # quat -> matrix -> quat round-trip through scipy was a no-op (up to
        # quaternion sign, which denotes the same rotation), so use it directly.
        kp_quat_wxyz = mj_data.xquat[kp_body_id]
        kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([kp_quat_wxyz, target_pos])))
        posture_task.set_target(mj_data.qpos)
        vel = mink.solve_ik(configuration, [kp_task, posture_task], 0.01, "daqp", limits=limits)
        configuration.integrate_inplace(vel, 0.01)
        mj_data.qpos[:] = configuration.q
        mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
        mujoco.mj_step(mj_model, mj_data)

# Drive the robot along the lifted trajectory and render each reached pose.
if trajectory_lifted_3d is None or len(trajectory_lifted_3d) == 0:
    print("⚠ No lifted trajectory available for MuJoCo rendering")
else:
    print(f"Rendering robot through LIFTED trajectory ({len(trajectory_lifted_3d)} points from 2D groundplane + height)...")
    for step, goal in enumerate(trajectory_lifted_3d):
        ik_to_keypoint(goal, ik_configuration, robot_config, mj_model, mj_data)

        # Measure how close the IK solution got to the requested keypoint.
        kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "virtual_gripper_keypoint")
        achieved = mj_data.xpos[kp_body_id].copy()
        err = np.linalg.norm(achieved - goal)
        if step < 3:  # log only the first few steps for debugging
            print(f"  t={step+1}: Target=[{goal[0]:.4f}, {goal[1]:.4f}, {goal[2]:.4f}], "
                  f"Achieved=[{achieved[0]:.4f}, {achieved[1]:.4f}, {achieved[2]:.4f}], "
                  f"IK Error={err:.4f}m")

        # Render at the resolution the intrinsics were calibrated for, then
        # resize everything to the loaded image size for display.
        frame_render = render_from_camera_pose(mj_model, mj_data, camera_pose_world, cam_K_for_render, H_orig, W_orig)
        frame_render = cv2.resize(frame_render, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)
        if rgb_resized.shape[:2] != (H_loaded, W_loaded):
            rgb_display = cv2.resize(rgb_resized, (W_loaded, H_loaded), interpolation=cv2.INTER_LINEAR)
        else:
            rgb_display = rgb_resized

        # Side-by-side panels: original, render, 50/50 blend.
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [rgb_display, frame_render, (rgb_display * 0.5 + frame_render * 0.5).astype(np.uint8)]
        for ax, img in zip(axes, panels):
            ax.imshow(img)
            ax.axis('off')
        axes[0].set_title('Original RGB')
        axes[1].set_title(f'Rendered Robot (Lifted t={step+1}/{len(trajectory_lifted_3d)})')
        axes[2].set_title('Overlay')
        plt.tight_layout()
        plt.show()

print("✓ MuJoCo visualization complete!")
