"""Geometric transformation utilities for coordinate transformations."""
import numpy as np

def project_3d_to_2d(point_3d, camera_pose, cam_K):
    """Map a 3D point to pixel coordinates with a pinhole camera model.

    Args:
        point_3d: 3-element point in the frame that ``camera_pose`` maps from.
        camera_pose: Homogeneous transform applied to ``[x, y, z, 1]``; only
            the first three components of the result are used.
        cam_K: 3x3 camera intrinsics matrix.

    Returns:
        Length-2 array ``(u, v)``, or None when the point's camera-frame
        depth is zero or negative (it is not in front of the camera).
    """
    # Homogenise so rotation and translation are applied in a single product.
    cam_xyz = (camera_pose @ np.append(point_3d, 1.0))[:3]
    depth = cam_xyz[2]
    if depth <= 0:
        return None
    pixel_h = cam_K @ cam_xyz
    # Perspective divide by the homogeneous pixel component.
    return pixel_h[:2] / pixel_h[2]

def rescale_coords(coords, H_orig, W_orig, H_new, W_new):
    """Scale 2D pixel coordinates from one image resolution to another.

    x (column 0) is scaled by ``W_new / W_orig`` and y (column 1) by
    ``H_new / H_orig``.

    Args:
        coords: None, an empty sequence, a single point of shape (2,), or
            points of shape (..., 2).
        H_orig, W_orig: Height and width of the source image.
        H_new, W_new: Height and width of the target image.

    Returns:
        None if ``coords`` is None; otherwise a float32 array with the same
        shape as the input (an empty input is returned as an empty float32
        array).
    """
    if coords is None:
        return None
    arr = np.asarray(coords, dtype=np.float32)
    if arr.size == 0:
        return arr
    single_point = arr.ndim == 1
    # Promote a lone point to a one-row batch so a single code path handles both.
    batch = arr[np.newaxis, :] if single_point else arr
    sx = W_new / W_orig
    sy = H_new / H_orig
    scaled = np.stack((batch[..., 0] * sx, batch[..., 1] * sy), axis=-1)
    return scaled[0] if single_point else scaled

def unproject_2d_to_ray(point_2d, camera_pose, cam_K):
    """Unproject a 2D pixel to a viewing ray in the robot frame.

    The ray direction is ``K^-1 @ [u, v, 1]``. It is computed with a linear
    solve against the full intrinsics matrix, so a nonzero skew term
    ``cam_K[0, 1]`` is honoured. This makes the function the exact inverse
    of ``cam_K @ point`` as used for projection.

    Args:
        point_2d: 2D point ``(u, v)`` in image coordinates.
        camera_pose: 4x4 transformation matrix from robot frame to camera frame.
        cam_K: 3x3 camera intrinsics.

    Returns:
        cam_pos_robot: Camera position in robot frame.
        ray_robot: Unit-length ray direction in robot frame.

    Raises:
        numpy.linalg.LinAlgError: if ``camera_pose`` or ``cam_K`` is singular.
    """
    # Inverting robot->camera gives camera->robot: its translation is the
    # camera centre and its rotation block maps camera-frame directions.
    cam_pose_inv = np.linalg.inv(camera_pose)
    cam_pos_robot = cam_pose_inv[:3, 3]
    cam_rot_c2r = cam_pose_inv[:3, :3]
    pixel_h = np.array([point_2d[0], point_2d[1], 1.0], dtype=float)
    # Solve K @ ray = pixel_h rather than hand-inverting fx/fy/cx/cy, which
    # would silently drop the skew term.
    ray_cam = np.linalg.solve(cam_K, pixel_h)
    ray_cam = ray_cam / np.linalg.norm(ray_cam)
    ray_robot = cam_rot_c2r @ ray_cam
    return cam_pos_robot, ray_robot

def recover_3d_from_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Recover a 3D keypoint from its pixel location and a known height.

    Back-projects ``kp_2d_image`` into a unit ray with ``unproject_2d_to_ray``
    and intersects that ray with the plane ``z == height`` in the robot frame.

    Args:
        kp_2d_image: Pixel coordinates ``(u, v)`` of the keypoint, or None.
        height: z coordinate of the keypoint in the robot frame, or None.
        camera_pose: 4x4 transformation matrix from robot frame to camera frame.
        cam_K: 3x3 camera intrinsics.

    Returns:
        The 3D point in the robot frame, or None when any of these hold:
        either input is None; the ray is (nearly) parallel to the plane; or
        the intersection lies at or behind the camera, where the keypoint
        could not have been imaged.
    """
    if kp_2d_image is None or height is None:
        return None
    cam_pos, ray_robot = unproject_2d_to_ray(kp_2d_image, camera_pose, cam_K)
    # A ray with ~zero z component never meets the horizontal plane.
    if abs(ray_robot[2]) < 1e-6:
        return None
    t = (height - cam_pos[2]) / ray_robot[2]
    # t is the signed distance along the (unit) viewing ray. t <= 0 means the
    # plane is reached at or behind the camera centre, so the intersection
    # is not the observed point.
    if t <= 0:
        return None
    return cam_pos + t * ray_robot
