"""Lightweight geometry utils used by PARA training."""

import numpy as np


def _unproject_2d_to_ray(kp_2d_image, camera_pose, cam_k):
    """Back-project a pixel into a world-space ray.

    Args:
        kp_2d_image: Pixel coordinates ``(u, v)``; only the first two entries
            are read. ``None`` short-circuits.
        camera_pose: Matrix whose ``[:3, :3]`` block rotates camera-frame
            vectors into the world frame and whose ``[:3, 3]`` column is the
            camera origin (presumably a 4x4 camera-to-world transform --
            confirm against callers).
        cam_k: 3x3 camera intrinsic matrix.

    Returns:
        ``(origin, direction)`` as float64 arrays, with ``direction``
        normalized to unit length, or ``(None, None)`` when ``kp_2d_image``
        is ``None``.
    """
    if kp_2d_image is None:
        return None, None

    eps = 1e-12
    u, v = kp_2d_image[0], kp_2d_image[1]
    pixel_h = np.array([u, v, 1.0], dtype=np.float64)

    # Camera-frame direction through the pixel, normalized (eps avoids /0).
    direction_cam = np.linalg.inv(cam_k) @ pixel_h
    direction_cam = direction_cam / max(np.linalg.norm(direction_cam), eps)

    origin = np.asarray(camera_pose[:3, 3], dtype=np.float64)
    rotation = np.asarray(camera_pose[:3, :3], dtype=np.float64)
    direction_world = rotation @ direction_cam
    # Re-normalize in case the rotation block is not perfectly orthonormal.
    direction_world = direction_world / max(np.linalg.norm(direction_world), eps)
    return origin, direction_world


def recover_3d_from_direct_keypoint_and_height(kp_2d_image, height, camera_pose, cam_k):
    """Recover a world-space 3D point from a pixel and its world-z height.

    The pixel is back-projected into a world-space ray with
    ``_unproject_2d_to_ray``. The ray is then intersected with the horizontal
    plane ``z == height``.

    Args:
        kp_2d_image: Pixel coordinates ``(u, v)``, or ``None``.
        height: Target world-z coordinate (anything ``float()`` accepts), or
            ``None``.
        camera_pose: Pose matrix forwarded to ``_unproject_2d_to_ray``
            (presumably a 4x4 camera-to-world transform -- confirm against
            callers).
        cam_k: 3x3 camera intrinsic matrix.

    Returns:
        The intersection point as a length-3 float64 array. Returns ``None``
        when:

        * an input is ``None``;
        * the ray is (nearly) parallel to the plane;
        * the intersection distance is not finite (e.g. NaN/inf height or
          pose);
        * the intersection lies behind the camera.
    """
    if kp_2d_image is None or height is None:
        return None
    cam_pos, ray_direction = _unproject_2d_to_ray(kp_2d_image, camera_pose, cam_k)
    # A ray with ~zero z component never meets the plane (or lies within it),
    # so there is no unique intersection.
    if cam_pos is None or ray_direction is None or abs(ray_direction[2]) < 1e-6:
        return None
    t = (float(height) - cam_pos[2]) / ray_direction[2]
    # NaN compares False against everything, so without the isfinite check a
    # NaN t would slip past `t < 0` and yield a NaN point. Negative t means
    # the plane is hit behind the camera.
    if not np.isfinite(t) or t < 0:
        return None
    return cam_pos + t * ray_direction

