"""Geometric utility functions for 3D transformations."""
import numpy as np


def procrustes_alignment(X, Y, *, with_scale=True):
    """
    Compute the similarity transformation (s, R, t) that aligns Y to X.

    The rotation is estimated with the Kabsch/SVD method and is always a
    proper rotation (det(R) = +1). Each source point y is mapped onto the
    target as ``x ≈ s * R @ y + t``.

    The scale is the ratio of the RMS distances of each cloud from its own
    centroid (Horn's symmetric scale). It is *not* the least-squares
    (Umeyama) scale ``trace(R @ H) / sum(Y_centered**2)``, and it does not
    depend on the estimated rotation.

    Args:
        X: Target point cloud, array-like of shape (N, D) (typically D = 3).
        Y: Source point cloud, array-like of shape (N, D); row i of Y is
            assumed to correspond to row i of X.
        with_scale: If False, the scale is fixed to 1.0 and a purely rigid
            transformation is returned.

    Returns:
        T: (D+1)x(D+1) homogeneous matrix (4x4 for 3D) that transforms Y to X.
        scale: Scale factor (1.0 if Y has zero spread or with_scale is False).
        R: DxD rotation matrix.
        t: Translation vector of shape (D,).

    Raises:
        ValueError: If X and Y are not non-empty 2-D arrays of equal shape.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if X.ndim != 2 or X.shape != Y.shape or X.shape[0] == 0:
        raise ValueError(
            "X and Y must be non-empty (N, D) arrays of equal shape; "
            f"got {X.shape} and {Y.shape}"
        )
    dim = X.shape[1]

    # Center both clouds so the rotation is estimated independently of
    # the translation.
    x_mean = X.mean(axis=0)
    y_mean = Y.mean(axis=0)
    X_centered = X - x_mean
    Y_centered = Y - y_mean

    # Kabsch: with H = Y_c^T X_c = U S V^T, the best orthogonal R is V U^T.
    H = Y_centered.T @ X_centered
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # A negative determinant means the best orthogonal fit is a reflection.
    # Flipping the last row of Vt (the direction of the smallest singular
    # value, since NumPy returns them in descending order) yields the best
    # proper rotation instead.
    if np.linalg.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt.T @ U.T

    scale = 1.0
    if with_scale:
        # RMS distance of each cloud from its centroid; degenerate (zero
        # spread) sources fall back to unit scale to avoid division by zero.
        x_rms = np.sqrt(np.sum(X_centered**2) / len(X_centered))
        y_rms = np.sqrt(np.sum(Y_centered**2) / len(Y_centered))
        if y_rms > 0:
            scale = x_rms / y_rms

    # The translation maps the scaled, rotated source centroid onto the
    # target centroid.
    t = x_mean - scale * R @ y_mean

    # Assemble the homogeneous transformation matrix.
    T = np.eye(dim + 1)
    T[:dim, :dim] = scale * R
    T[:dim, dim] = t

    return T, scale, R, t

