"""Test trajectory predictions with 3D lifting and IK visualization."""
import sys
import os
from pathlib import Path
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
import mujoco
import mink
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.dirname(__file__))

from data import N_WINDOW, project_3d_to_2d
from model import TrajectoryHeatmapPredictor
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs import EXOSKELETON_CONFIGS
from exo_utils import (
    get_link_poses_from_robot, 
    position_exoskeleton_meshes, 
    render_from_camera_pose,
    detect_and_set_link_poses,
    estimate_robot_state
)

# Configuration
# Side length (px) of the square model input; predicted heatmap pixel
# coordinates are lifted to 3D with intrinsics rescaled to this size.
IMAGE_SIZE = 448
# Default for --checkpoint (a relative path, resolved against the working directory).
CHECKPOINT_PATH = "real_dino_tracks/checkpoints/real_tracks/best.pth"

# Candidate virtual-gripper keypoints in the local "fixed_gripper" link frame
# (from dataset generation). The literals appear to be millimetres (per the
# _M_ in the name); the / 1000.0 converts them to metres.
KEYPOINTS_LOCAL_M_ALL = np.array([[13.25, -91.42, 15.9], [10.77, -99.6, 0], [13.25, -91.42, -15.9], 
                                   [17.96, -83.96, 0], [22.86, -70.46, 0]]) / 1000.0
KP_INDEX = 3  # Row of KEYPOINTS_LOCAL_M_ALL used as the tracked keypoint
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]  # (3,) tracked keypoint in the gripper frame, metres

# Hard-coded grasp parameters from make_datasets.py
# Not used by the active code in this script.
cube_to_gripper_offset = np.array([-0.04433927, -0.00871351, 0.01834697])
# Target orientation (quaternion, w-x-y-z order) for the keypoint body in the IK solve.
initial_kp_quat_wxyz = np.array([-0.01718116, -0.00338855, 0.76436812, 0.64454224])
# Reference joint configuration; not used by the active code in this script.
start_pos = np.array([0, -1.57, 1.57, 1.57, 1.57, 1])


def unproject_2d_to_ray(point_2d, camera_pose, cam_K):
    """Back-project a pixel into a world-space viewing ray.

    ``camera_pose`` is treated as the world-to-camera transform, so its
    inverse (camera-to-world) supplies both the camera centre and the
    rotation that carries camera-frame directions into the world frame.

    Args:
        point_2d: (2,) pixel coordinates ``(u, v)``.
        camera_pose: (4, 4) world-to-camera homogeneous transform.
        cam_K: (3, 3) pinhole intrinsics for the same pixel grid as ``point_2d``.

    Returns:
        cam_pos: (3,) camera centre in world coordinates.
        ray_direction: (3,) unit-length ray direction in world coordinates.
    """
    cam_to_world = np.linalg.inv(camera_pose)
    origin = cam_to_world[:3, 3]

    # Homogeneous pixel -> direction in the camera frame (z component is 1).
    pixel_h = np.array([point_2d[0], point_2d[1], 1.0])
    direction_cam = np.linalg.inv(cam_K) @ pixel_h

    # Rotate into the world frame and normalise.
    direction_world = cam_to_world[:3, :3] @ direction_cam
    return origin, direction_world / np.linalg.norm(direction_world)


def recover_3d_from_direct_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Lift a pixel to 3D by intersecting its viewing ray with the plane z = height.

    Args:
        kp_2d_image: (2,) pixel coordinates, in the resolution ``cam_K`` describes.
        height: world-frame z value of the horizontal plane to intersect.
        camera_pose: (4, 4) camera pose, forwarded to ``unproject_2d_to_ray``.
        cam_K: (3, 3) pinhole intrinsics.

    Returns:
        (3,) world-frame intersection point, or ``None`` when the ray is
        (numerically) parallel to the plane or meets it behind the camera.
    """
    origin, direction = unproject_2d_to_ray(kp_2d_image, camera_pose, cam_K)
    dz = direction[2]
    if abs(dz) < 1e-6:
        return None  # ray (nearly) parallel to the height plane

    # Ray parameter s at which origin_z + s * dz == height.
    s = (height - origin[2]) / dz
    if s < 0:
        return None  # intersection lies behind the camera
    return origin + s * direction


def ik_to_cube_grasp(model, data, configuration, target_pos, target_kp_quat_wxyz, num_iterations=50):
    """Drive the ``virtual_gripper_keypoint`` body to a target pose with mink IK.

    Despite the name, nothing here is cube-specific: the keypoint body is
    steered to ``target_pos`` with orientation ``target_kp_quat_wxyz``. A
    low-cost posture task, re-anchored at the current configuration on every
    iteration, regularises each step.

    Args:
        model: MuJoCo ``MjModel`` containing a body named
            ``virtual_gripper_keypoint``.
        data: MuJoCo ``MjData``; ``data.qpos`` is the starting configuration
            and is overwritten in place with the IK result.
        configuration: ``mink.Configuration`` built for ``model``.
        target_pos: (3,) target keypoint position in the world frame.
        target_kp_quat_wxyz: (4,) target keypoint orientation, w-x-y-z order.
        num_iterations: number of IK integration steps.

    Returns:
        optimized_qpos: (nq,) copy of the final joint positions.
        final_error: Euclidean position error of the keypoint body
            (same units as ``target_pos``, metres in this script).
    """
    dt = 0.01  # integration timestep per IK iteration

    # The keypoint target and joint limits do not change between iterations,
    # so the task/limit objects are built once instead of on every step.
    kp_task = mink.FrameTask("virtual_gripper_keypoint", "body", position_cost=1.0, orientation_cost=0.2)
    kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([target_kp_quat_wxyz, target_pos])))
    posture_task = mink.PostureTask(model, cost=1e-3)
    tasks = [kp_task, posture_task]
    limits = [mink.ConfigurationLimit(model=model)]

    # Ensure data.xpos is valid even when num_iterations == 0.
    mujoco.mj_forward(model, data)
    for _ in range(num_iterations):
        configuration.update(data.qpos)
        # Posture target = current configuration, so the regulariser
        # penalises per-step motion rather than distance from the start pose.
        posture_task.set_target(data.qpos)

        vel = mink.solve_ik(configuration, tasks, dt, "daqp", limits=limits)
        configuration.integrate_inplace(vel, dt)
        data.qpos[:] = configuration.q
        # Refresh kinematics so the next iteration / final error sees the new pose.
        mujoco.mj_forward(model, data)

    kp_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, 'virtual_gripper_keypoint')
    error_norm = np.linalg.norm(target_pos - data.xpos[kp_body_id])

    return data.qpos.copy(), error_norm

def preprocess_image(rgb, image_size=IMAGE_SIZE):
    """Resize an RGB image and normalise it for the DINOv2-based model.

    Args:
        rgb: (H, W, 3) RGB image, either integer-valued on a 0-255 scale or
            floating point in [0, 1]. Floating-point input with values above
            1.0 is treated as 0-255 and rescaled.
        image_size: side length of the square model input.

    Returns:
        rgb_tensor: (3, image_size, image_size) float tensor normalised with
            the ImageNet mean/std used by DINOv2.
        rgb_vis: (image_size, image_size, 3) resized image in [0, 1], for plotting.
        H_orig: height of the input image before resizing.
        W_orig: width of the input image before resizing.
    """
    # Decide the value range from the dtype first: an integer image whose
    # brightest pixel is <= 1 (e.g. almost black) is still on a 0-255 scale,
    # which a max()-only check would miss.
    if np.issubdtype(rgb.dtype, np.integer) or rgb.max() > 1.0:
        rgb = rgb.astype(np.float32) / 255.0

    H_orig, W_orig = rgb.shape[:2]

    # Anisotropic resize to the square model input if needed.
    if (H_orig, W_orig) != (image_size, image_size):
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # HWC -> CHW float tensor, normalised with ImageNet statistics for DINOv2.
    rgb_tensor = torch.from_numpy(rgb).permute(2, 0, 1).float()
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb_tensor = (rgb_tensor - mean) / std

    # Un-normalised copy for plotting; copied so it never aliases the caller's
    # array (which is the case when no rescale/resize was needed).
    rgb_vis = rgb.copy()

    return rgb_tensor, rgb_vis, H_orig, W_orig


def _build_robot_config(exo_name):
    """Return the robot/exoskeleton config selected by ``--exo``.

    The default ``so100_adhesive`` uses ``SO100AdhesiveConfig`` directly (the
    previous hard-coded behaviour); other names come from ``EXOSKELETON_CONFIGS``.
    """
    if exo_name == "so100_adhesive":
        return SO100AdhesiveConfig()
    entry = EXOSKELETON_CONFIGS[exo_name]
    # NOTE(review): assumes registry values are config classes (instantiated
    # here) or ready-made instances -- confirm against ExoConfigs. Also note
    # kp_local / initial_kp_quat_wxyz are SO100-gripper specific.
    return entry() if isinstance(entry, type) else entry


def _load_model(checkpoint_path, device):
    """Build TrajectoryHeatmapPredictor, load checkpoint weights, and return it in eval mode."""
    print(f"\nLoading model from {checkpoint_path}...")
    net = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net = net.to(device)
    net.eval()
    # Load MIN_HEIGHT and MAX_HEIGHT from checkpoint if available.
    # NOTE(review): these module globals are set after construction; this only
    # matters if the model module reads them at call time -- confirm.
    if 'min_height' in checkpoint and 'max_height' in checkpoint:
        import model as model_module
        model_module.MIN_HEIGHT = checkpoint['min_height']
        model_module.MAX_HEIGHT = checkpoint['max_height']
        print(f"✓ Loaded height range from checkpoint: [{checkpoint['min_height']:.6f}, {checkpoint['max_height']:.6f}] m")
        print(f"  ({checkpoint['min_height']*1000:.2f}mm to {checkpoint['max_height']*1000:.2f}mm)")
    print(f"✓ Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
    return net


def _load_rgb(image_path):
    """Read an image with ``plt.imread`` and return it as an (H, W, 3) RGB array.

    Grayscale (2-D) images are replicated to three channels; an alpha channel
    is dropped. The dtype is whatever ``plt.imread`` produced.
    """
    img = plt.imread(image_path)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    return img[..., :3]


def _set_joint_state(mj_model, mj_data, joint_state):
    """Write ``joint_state`` into the first len(ctrl) qpos entries and ctrl, then run forward kinematics."""
    # Assumes the first len(ctrl) qpos entries are the actuated joints.
    n_ctrl = len(mj_data.ctrl)
    mj_data.qpos[:n_ctrl] = joint_state[:n_ctrl]
    mj_data.ctrl[:] = joint_state[:n_ctrl]
    mujoco.mj_forward(mj_model, mj_data)


def _scale_intrinsics(cam_K_norm, width, height):
    """Denormalise resolution-independent intrinsics to a ``width`` x ``height`` image.

    Row 0 of K (fx, skew, cx) scales with width, row 1 (fy, cy) with height.
    Returns a new float64 array; ``cam_K_norm`` is not modified.
    """
    K = np.array(cam_K_norm, dtype=np.float64)
    K[0] *= width
    K[1] *= height
    return K


def _load_or_detect_camera(rgb_raw, image_path, camera_pose_path, cam_K_path,
                           robot_config, mj_model, mj_data):
    """Get the camera pose and intrinsics, from saved files if present, else by detection.

    Saved files default to ``<stem>_camera_pose.npy`` / ``<stem>_cam_K.npy``
    next to the image. Exoskeleton meshes are positioned in both cases.

    Returns:
        camera_pose_world: (4, 4) camera pose.
        cam_K: (3, 3) intrinsics at the original image resolution.
        cam_K_norm: (3, 3) intrinsics normalised by image width/height.
    """
    H, W = rgb_raw.shape[:2]
    if camera_pose_path is None:
        camera_pose_path = image_path.parent / f"{image_path.stem}_camera_pose.npy"
    if cam_K_path is None:
        cam_K_path = image_path.parent / f"{image_path.stem}_cam_K.npy"

    if Path(camera_pose_path).exists() and Path(cam_K_path).exists():
        print("\nLoading saved camera pose and intrinsics from dataset files...")
        camera_pose_world = np.load(camera_pose_path)
        cam_K_norm = np.load(cam_K_path)
        cam_K = _scale_intrinsics(cam_K_norm, W, H)
        link_poses = get_link_poses_from_robot(robot_config, mj_model, mj_data)
        position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
        print("✓ Loaded saved camera pose and intrinsics (matching test_model_ik.py)")
        print("  Using saved cam_K_norm from dataset")
        return camera_pose_world, cam_K, cam_K_norm

    print("\nDetecting camera pose and intrinsics from image...")
    try:
        link_poses, camera_pose_world, cam_K, *_ = detect_and_set_link_poses(
            rgb_raw, mj_model, mj_data, robot_config, cam_K=None
        )
    except Exception as e:
        print(f"✗ Error detecting camera pose: {e}")
        raise
    if camera_pose_world is None or cam_K is None:
        raise RuntimeError("Failed to detect camera pose or intrinsics")
    position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
    print("✓ Detected camera pose and intrinsics")

    # Normalise detected intrinsics so they can be rescaled to any resolution.
    cam_K_norm = np.array(cam_K, dtype=np.float64)
    cam_K_norm[0] /= W
    cam_K_norm[1] /= H
    return camera_pose_world, cam_K, cam_K_norm


def _heatmap_argmax_pixels(pred_logits):
    """Return the (x, y) argmax pixel of each timestep heatmap of batch item 0.

    Softmax is monotonic, so the argmax of the logits equals the argmax of the
    probabilities. Assumes ``pred_logits`` is (B, T, H, W) -- TODO confirm --
    and that heatmap pixels live at model input resolution.

    Returns:
        (N_WINDOW, 2) float32 array of (x, y) pixel coordinates.
    """
    heatmaps = pred_logits[0, :N_WINDOW]
    n, h, w = heatmaps.shape
    flat_idx = heatmaps.reshape(n, -1).argmax(dim=1).cpu().numpy()
    ys, xs = np.unravel_index(flat_idx, (h, w))
    return np.stack([xs, ys], axis=1).astype(np.float32)


def _lift_trajectory(traj_2d, heights, camera_pose_world, cam_K_model):
    """Lift each (pixel, height) waypoint to 3D, skipping those that cannot be lifted.

    Returns:
        (M, 3) array of lifted points in waypoint order, M <= len(traj_2d).

    Raises:
        RuntimeError: if no waypoint could be lifted.
    """
    lifted = []
    for t, (uv, h) in enumerate(zip(traj_2d, heights)):
        point = recover_3d_from_direct_keypoint_and_height(uv, h, camera_pose_world, cam_K_model)
        if point is None:
            print(f"✗ Failed to lift waypoint {t} to 3D (point behind camera or ray parallel to height plane)")
            continue
        lifted.append(point)
    if not lifted:
        raise RuntimeError("Could not lift any waypoints to 3D")
    return np.array(lifted)


def _to_unit_float(rgb):
    """Return ``rgb`` as float32 in [0, 1], rescaling integer or 0-255 float images."""
    out = rgb.astype(np.float32)
    if np.issubdtype(rgb.dtype, np.integer) or out.max() > 1.0:
        out /= 255.0
    return out


def _plot_results(rgb_model, traj_2d, heights, overlay, ik_error, image_name):
    """Show one panel per predicted waypoint plus a final IK-overlay panel.

    Args:
        rgb_model: image at model input resolution (same pixel grid as ``traj_2d``).
        traj_2d: (N, 2) predicted (x, y) waypoints.
        heights: (N,) predicted heights in metres (displayed in mm).
        overlay: original-resolution blend of the photo and the IK rendering.
        ik_error: final IK position error in metres (displayed in mm).
        image_name: label appended to the figure title.
    """
    n = len(traj_2d)
    fig, axes = plt.subplots(1, n + 1, figsize=(4 * (n + 1), 5), squeeze=False)
    axes = axes[0]
    colors = plt.cm.viridis(np.linspace(0, 1, n))

    for t in range(n):
        ax = axes[t]
        ax.imshow(rgb_model)
        # Path so far: segments 0->1, ..., (t-1)->t.
        for (x0, y0), (x1, y1) in zip(traj_2d[:t], traj_2d[1:t + 1]):
            ax.plot([x0, x1], [y0, y1], '-', color='lime', linewidth=2, alpha=0.5)
        # color= (not c=) so a single RGBA row is unambiguously one colour.
        ax.scatter(traj_2d[t, 0], traj_2d[t, 1], color=colors[t], s=100,
                   marker='x', linewidths=3, label='Pred', zorder=10)
        ax.set_title(f"t={t}\nH: {heights[t] * 1000:.2f}mm", fontsize=8)
        ax.axis('off')
        if t == 0:
            ax.legend(loc='upper right', fontsize=6)

    ax_ik = axes[n]
    ax_ik.imshow(overlay)
    ax_ik.set_title(f"Pred 3D IK\nIK Err: {ik_error*1000:.1f}mm", fontsize=10)
    ax_ik.axis('off')

    fig.suptitle(f"Trajectory Prediction → 3D Lifting → IK | {image_name}", fontsize=14, fontweight='bold')
    fig.tight_layout()
    plt.show()


def main():
    """Predict a trajectory from one image, lift it to 3D, solve IK to its end, and plot."""
    parser = argparse.ArgumentParser(description='Test trajectory predictions with IK from raw image and joint state')
    parser.add_argument('--checkpoint', type=str, default=CHECKPOINT_PATH,
                       help='Path to model checkpoint')
    parser.add_argument('--image_path', type=str, required=True,
                       help='Path to raw high-resolution RGB image')
    parser.add_argument('--joint_state_path', type=str, required=True,
                       help='Path to joint state .npy file')
    parser.add_argument('--camera_pose_path', type=str, default=None,
                       help='Path to saved camera pose .npy file (if None, will detect from image)')
    parser.add_argument('--cam_K_path', type=str, default=None,
                       help='Path to saved normalized intrinsics .npy file (if None, will detect from image)')
    parser.add_argument('--num_ik_iters', type=int, default=50,
                       help='Number of IK iterations')
    parser.add_argument('--exo', type=str, default="so100_adhesive",
                       choices=list(EXOSKELETON_CONFIGS.keys()),
                       help="Exoskeleton configuration to use")
    args = parser.parse_args()

    device = torch.device("mps" if torch.backends.mps.is_available() else
                          "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = _load_model(args.checkpoint, device)

    # Load raw high-resolution image
    print(f"\nLoading raw image from {args.image_path}...")
    image_path = Path(args.image_path)
    rgb_raw = _load_rgb(args.image_path)
    H_orig, W_orig = rgb_raw.shape[:2]
    print(f"✓ Loaded image: {W_orig}x{H_orig}")

    print(f"\nLoading joint state from {args.joint_state_path}...")
    joint_state = np.load(args.joint_state_path)
    print(f"✓ Loaded joint state: {joint_state}")

    print("\nSetting up MuJoCo...")
    robot_config = _build_robot_config(args.exo)
    mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
    mj_data = mujoco.MjData(mj_model)
    _set_joint_state(mj_model, mj_data, joint_state)

    camera_pose_world, cam_K, cam_K_norm = _load_or_detect_camera(
        rgb_raw, image_path, args.camera_pose_path, args.cam_K_path,
        robot_config, mj_model, mj_data,
    )
    print(f"  Camera intrinsics (at {W_orig}x{H_orig}): fx={cam_K[0,0]:.2f}, fy={cam_K[1,1]:.2f}, cx={cam_K[0,2]:.2f}, cy={cam_K[1,2]:.2f}")

    rgb_tensor, rgb_model_res, H_proc, W_proc = preprocess_image(rgb_raw.copy(), IMAGE_SIZE)
    if (H_proc, W_proc) != (H_orig, W_orig):
        raise RuntimeError("Image dimensions mismatch")

    # Starting keypoint: local keypoint offset transformed by the fixed-gripper pose.
    fixed_gripper_pose = get_link_poses_from_robot(robot_config, mj_model, mj_data)["fixed_gripper"]
    kp_3d_start = fixed_gripper_pose[:3, :3] @ kp_local + fixed_gripper_pose[:3, 3]
    kp_2d_start = project_3d_to_2d(kp_3d_start, camera_pose_world, cam_K)
    if kp_2d_start is None:
        raise RuntimeError("Starting keypoint is behind camera")

    # The model input is an anisotropic resize to IMAGE_SIZE x IMAGE_SIZE.
    start_keypoint_2d = kp_2d_start * np.array([IMAGE_SIZE / W_orig, IMAGE_SIZE / H_orig])
    print(f"✓ Starting keypoint 2D (original): [{kp_2d_start[0]:.1f}, {kp_2d_start[1]:.1f}] px")
    print(f"✓ Starting keypoint 2D (model res): [{start_keypoint_2d[0]:.1f}, {start_keypoint_2d[1]:.1f}] px")

    current_height = kp_3d_start[2]  # world-frame z, metres
    print(f"✓ Current height: {current_height*1000:.2f} mm")

    print("\nRunning model inference...")
    with torch.no_grad():
        rgb_batch = rgb_tensor.unsqueeze(0).to(device)
        # NOTE(review): keypoint/height are passed without a batch dim, as before;
        # presumably the model broadcasts them -- confirm.
        start_keypoint_tensor = torch.from_numpy(start_keypoint_2d).float().to(device)
        current_height_tensor = torch.tensor(current_height, dtype=torch.float32).to(device)
        pred_logits, pred_height = model(
            rgb_batch,
            gt_target_heatmap=None,
            training=False,
            start_keypoint_2d=start_keypoint_tensor,
            current_height=current_height_tensor
        )

    pred_trajectory_2d = _heatmap_argmax_pixels(pred_logits)  # (N_WINDOW, 2)
    pred_heights = pred_height[0].cpu().numpy()  # (N_WINDOW,)
    print("✓ Predicted trajectory (first and last):")
    print(f"  Start: [{pred_trajectory_2d[0, 0]:.1f}, {pred_trajectory_2d[0, 1]:.1f}] px, H: {pred_heights[0]*1000:.2f} mm")
    print(f"  Final: [{pred_trajectory_2d[-1, 0]:.1f}, {pred_trajectory_2d[-1, 1]:.1f}] px, H: {pred_heights[-1]*1000:.2f} mm")

    # Intrinsics at model resolution, matching the pixel grid of the predictions.
    cam_K_model = _scale_intrinsics(cam_K_norm, IMAGE_SIZE, IMAGE_SIZE)

    print("\nLifting trajectory to 3D...")
    pred_trajectory_3d = _lift_trajectory(pred_trajectory_2d, pred_heights, camera_pose_world, cam_K_model)
    if len(pred_trajectory_3d) < len(pred_trajectory_2d):
        print("  Note: IK target is the last waypoint that could be lifted")
    pred_3d = pred_trajectory_3d[-1]
    print(f"✓ Predicted 3D (final): [{pred_3d[0]:.4f}, {pred_3d[1]:.4f}, {pred_3d[2]:.4f}] m")

    print("\nSetting up IK...")
    # Reset to the measured joint state before IK.
    _set_joint_state(mj_model, mj_data, joint_state)
    configuration = mink.Configuration(mj_model)
    configuration.update(mj_data.qpos)
    print("✓ MuJoCo initialized with joint state")

    print(f"\nPerforming IK to predicted 3D position ({args.num_ik_iters} iterations)...")
    optimized_qpos, ik_error = ik_to_cube_grasp(
        mj_model,
        mj_data,
        configuration,
        pred_3d,
        initial_kp_quat_wxyz,
        num_iterations=args.num_ik_iters
    )
    print(f"✓ IK complete - Final error: {ik_error*1000:.2f} mm")
    print(f"  Optimized qpos: {optimized_qpos}")

    print("\nRendering result...")
    mj_data.qpos[:6] = optimized_qpos[:6]
    mujoco.mj_forward(mj_model, mj_data)

    # Render at half resolution (original aspect ratio), then upsample.
    render_h, render_w = H_orig // 2, W_orig // 2
    rendered_img = render_from_camera_pose(
        mj_model, mj_data, camera_pose_world,
        _scale_intrinsics(cam_K_norm, render_w, render_h),
        render_h, render_w, segmentation=False
    ) / 255
    rendered_img = cv2.resize(rendered_img, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)

    # Blend with the full-resolution photo (not the 448px model input upsampled back).
    rendered_overlay = _to_unit_float(rgb_raw) * 0.5 + rendered_img * 0.5

    _plot_results(rgb_model_res, pred_trajectory_2d, pred_heights, rendered_overlay,
                  ik_error, image_path.stem)

    print("\n✓ Test complete!")


if __name__ == "__main__":
    # Run the end-to-end demo only when executed as a script, not on import.
    main()
