"""Multi-view geometry & proejction code.."""
import torch, torchvision
from einops import rearrange, repeat
from torch.nn import functional as F
import numpy as np

from math import ceil, log2
import torch
import torch.nn as nn
from einops import einsum
from jaxtyping import Float
from torch import Tensor

def procrustes(S1, S2, weights=None):
    """Rigidly align ``S1`` to ``S2`` (weighted orthogonal Procrustes / Kabsch).

    Finds a rotation ``R`` (det = +1) and translation ``t`` such that
    ``R @ p + t`` maps each point ``p`` of ``S1`` as closely as possible (in the
    weighted least-squares sense) onto the corresponding point of ``S2``.
    The scale is fixed to 1.

    Args:
        S1: source points of shape (..., N, 3); any number of leading batch dims.
        S2: target points, same shape as ``S1``.
        weights: optional per-point weights of shape (..., N, 1). Values are
            clipped to at least 1e-6. Defaults to uniform weights.

    Returns:
        ``(loss, transf)``. ``loss`` is the mean squared residual between the
        aligned ``S1`` and ``S2`` over all batch elements. ``transf`` has shape
        (..., 4, 4) and holds ``[[R, t], [0, 1]]``.
    """
    # Any number of leading batch dims: flatten them to one, solve, restore.
    if S1.dim() > 3:
        lead_shape = S1.shape[:-2]
        flat_weights = weights.flatten(0, -3) if weights is not None else None
        loss, transf = procrustes(S1.flatten(0, -3), S2.flatten(0, -3), flat_weights)
        return loss, transf.unflatten(0, lead_shape)

    with torch.autocast(device_type='cuda', dtype=torch.float32):
        # Work in (B, 3, N) layout: points are columns.
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
        if weights is None:
            weights = torch.ones_like(S1[:, :1])
        else:
            weights = weights.permute(0, 2, 1)

        eps = 1e-6
        weights = weights.clip(min=eps)

        # 1. Remove the weighted centroids.
        weights_norm = weights / (weights.sum(-1, keepdim=True) + eps)
        mu1 = (S1 * weights_norm).sum(2, keepdim=True)
        mu2 = (S2 * weights_norm).sum(2, keepdim=True)
        X1 = S1 - mu1
        X2 = S2 - mu2

        # 2. Weighted cross-covariance K = X1 diag(w) X2^T, computed by scaling the
        #    columns of X1 instead of materialising an (N, N) diagonal matrix.
        K = (X1 * weights).bmm(X2.permute(0, 2, 1))

        # 3. trace(R K) is maximised by R = V U^T, where K = U S V^T.
        U, s, V = torch.svd(K)

        # Flip the last singular direction when needed so that det(R) = +1.
        Z = torch.eye(U.shape[1], dtype=U.dtype, device=U.device).repeat(U.shape[0], 1, 1)
        Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))
        R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

        # 4. Translation that maps the rotated source centroid onto the target one.
        t = mu2 - R.bmm(mu1)

        # 5. Aligned source points, used for the residual.
        S1_hat = R.bmm(S1) + t

        # 6. Pack [R | t] into a homogeneous 4x4 matrix. `repeat` (not `expand`)
        #    gives each batch element its own storage, so in-place writes are legal;
        #    with `expand`, `.to(S1)` is a no-op for CPU float32 and writes would fail.
        transf = torch.eye(4, dtype=S1.dtype, device=S1.device).repeat(S1.size(0), 1, 1)
        transf[:, :3, :3] = R
        transf[:, :3, 3] = t.squeeze(-1)

        return (S1_hat - S2).square().mean(), transf


def homogenize_points(points: torch.Tensor):
    """Return ``points`` with a trailing coordinate equal to 1.

    Args:
        points: tensor of shape (..., DIM).

    Returns:
        Tensor of shape (..., DIM + 1) with the same dtype and device, whose
        last entry is 1 (homogeneous point).
    """
    pad = points.new_ones((*points.shape[:-1], 1))
    return torch.cat([points, pad], dim=-1)


def homogenize_vecs(vectors: torch.Tensor):
    """Return ``vectors`` with a trailing coordinate equal to 0.

    A 0 in the last slot makes a 4x4 rigid transform apply only its rotation
    part (no translation) to the vector.

    Args:
        vectors: tensor of shape (..., DIM).

    Returns:
        Tensor of shape (..., DIM + 1) with the same dtype and device.
    """
    pad = vectors.new_zeros((*vectors.shape[:-1], 1))
    return torch.cat([vectors, pad], dim=-1)


def unproject(
    xy_pix: torch.Tensor, z: torch.Tensor, intrinsics: torch.Tensor
) -> torch.Tensor:
    """Lift pixel coordinates with known depth to 3D camera-space points.

    Each pixel ``(x, y)`` is homogenised to ``(x, y, 1)``, multiplied by the
    inverse intrinsics and then scaled (in place) by its depth ``z``.

    Args:
        xy_pix: pixel coordinates of shape (..., N, 2).
        z: per-pixel depth (z coordinate); must broadcast in place against
            the (..., N, 3) result, e.g. shape (..., N, 1).
        intrinsics: camera intrinsics of shape (..., 3, 3).

    Returns:
        Points in camera coordinates, shape (..., N, 3).
    """
    pix_h = homogenize_points(xy_pix)
    K_inv = intrinsics.inverse()
    points = torch.einsum("...ij,...kj->...ki", K_inv, pix_h)
    points *= z
    return points


def transform_world2cam(
    xyz_world_hom: torch.Tensor, cam2world: torch.Tensor
) -> torch.Tensor:
    """Map homogeneous world-space points into camera space.

    The world-to-camera transform is obtained by inverting ``cam2world``.

    Args:
        xyz_world_hom: homogeneous 3D points of shape (..., 4).
        cam2world: camera pose (camera-to-world) of shape (..., 4, 4).

    Returns:
        Homogeneous points in camera coordinates.
    """
    world2cam = cam2world.inverse()
    xyz_cam_hom = transform_rigid(xyz_world_hom, world2cam)
    return xyz_cam_hom


def transform_cam2world(
    xyz_cam_hom: torch.Tensor, cam2world: torch.Tensor
) -> torch.Tensor:
    """Map homogeneous camera-space points (or w=0 vectors) into world space.

    Args:
        xyz_cam_hom: homogeneous 3D points/vectors in camera coordinates, shape (..., 4).
        cam2world: camera pose (camera-to-world) of shape (..., 4, 4).

    Returns:
        Homogeneous points/vectors in world coordinates.
    """
    xyz_world_hom = transform_rigid(xyz_cam_hom, cam2world)
    return xyz_world_hom


def transform_rigid(xyz_hom: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Left-multiply every homogeneous point/vector by the 4x4 matrix ``T``.

    Computes ``out[..., n, :] = T[...] @ xyz_hom[..., n, :]``; leading dims of
    ``T`` and ``xyz_hom`` broadcast against each other.

    Args:
        xyz_hom: homogeneous points (w=1) or vectors (w=0), shape (..., N, 4).
        T: rigid-body transform matrices, shape (..., 4, 4).

    Returns:
        Transformed points/vectors of shape (..., N, 4).
    """
    return torch.einsum("...rc,...nc->...nr", T, xyz_hom)


def get_unnormalized_cam_ray_directions(
    xy_pix: torch.Tensor, intrinsics: torch.Tensor
) -> torch.Tensor:
    """Camera-space ray directions through each pixel, scaled so that z = 1.

    Equivalent to unprojecting every pixel at unit depth; the result is not
    normalised to unit length.

    Args:
        xy_pix: pixel coordinates of shape (..., N, 2).
        intrinsics: camera intrinsics of shape (..., 3, 3).

    Returns:
        Directions of shape (..., N, 3).
    """
    unit_depth = torch.ones_like(xy_pix[..., :1])
    return unproject(xy_pix, unit_depth, intrinsics=intrinsics)


def get_world_rays_(
    xy_pix: torch.Tensor,
    intrinsics: torch.Tensor,
    cam2world: torch.Tensor,
) -> torch.Tensor:
    """Build world-space rays (origins and directions) for a batch of pixels.

    Args:
        xy_pix: pixel coordinates, presumably (B, N, 2) -- the identity fallback
            below uses ``xy_pix.size(0)`` as the batch size (TODO confirm).
        intrinsics: camera intrinsics of shape (..., 3, 3).
        cam2world: camera-to-world poses of shape (..., 4, 4), or None to use
            the identity pose for every batch element.

    Returns:
        A tuple ``(origins, directions)``, each of shape (..., N, 3). Origins
        are the translation column of ``cam2world`` repeated per ray.
        Directions are unit length in camera space before being rotated by
        ``cam2world``.
    """
    if cam2world is None:
        cam2world = torch.eye(4)[None].expand(xy_pix.size(0), -1, -1).to(xy_pix)

    # Unit-length viewing directions in camera space.
    dirs_cam = get_unnormalized_cam_ray_directions(xy_pix, intrinsics)
    dirs_cam = dirs_cam / dirs_cam.norm(dim=-1, keepdim=True)

    # w = 0, so only the rotational part of cam2world affects the directions.
    dirs_world = transform_cam2world(homogenize_vecs(dirs_cam), cam2world)[..., :3]

    # Camera centre (translation column), one copy per ray.
    origins = repeat(
        cam2world[..., :3, -1],
        "... ch -> ... num_rays ch",
        num_rays=dirs_cam.size(-2),
    )
    return origins, dirs_world

def get_world_rays(
    xy_pix: torch.Tensor,
    intrinsics: torch.Tensor,
    cam2world: torch.Tensor,
) -> torch.Tensor:
    """World-space rays for pixels, accepting one extra leading batch dimension.

    A 4-D ``xy_pix`` (e.g. batch x views x N x 2) has its first two dims
    merged before delegating to ``get_world_rays_``. The results are then
    split back, and a list ``[origins, directions]`` is returned. Any other
    rank is passed straight through, returning that function's tuple.
    """
    if xy_pix.dim() != 4:
        return get_world_rays_(xy_pix, intrinsics, cam2world)

    lead_shape = xy_pix.shape[:2]
    flat_cam2world = None if cam2world is None else cam2world.flatten(0, 1)
    rays = get_world_rays_(xy_pix.flatten(0, 1), intrinsics.flatten(0, 1), flat_cam2world)
    return [r.unflatten(0, lead_shape) for r in rays]




def get_opencv_pixel_coordinates(
    y_resolution: int,
    x_resolution: int,
    device='cpu'
):
    """Normalised pixel-coordinate grid in the OpenCV convention.

    For an image with ``y_resolution`` rows and ``x_resolution`` columns, return
    coordinates in [0, 1]. The origin (0, 0) is the top-left corner, x points
    right, y points down, and the bottom-right corner is (1, 1).

    Args:
        y_resolution: number of rows.
        x_resolution: number of columns.
        device: device on which to create the grid.

    Returns:
        xy_pix: float32 tensor of shape (y_resolution, x_resolution, 2), where
            ``xy_pix[r, c] == (x_c, y_r)``.
    """
    xs = torch.linspace(0, 1, steps=x_resolution, device=device)
    ys = torch.linspace(0, 1, steps=y_resolution, device=device)
    # "xy" indexing yields (y_resolution, x_resolution) grids with x varying along
    # columns, so no transpose is needed. Passing `indexing` explicitly also avoids
    # torch.meshgrid's "indexing argument will be required" warning.
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")
    return torch.stack([grid_x.float(), grid_y.float()], dim=-1)


def project(xyz_cam_hom: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
    """Project homogeneous camera-space points to pixel coordinates.

    Args:
        xyz_cam_hom: points of shape (..., 4); only the first three entries are used.
        intrinsics: camera intrinsics of shape (..., 3, 3). If it has the same
            rank as ``xyz_cam_hom`` (e.g. (B, 3, 3) vs (B, N, 4)), a singleton
            axis is inserted at dim 1 so it broadcasts over the points.

    Returns:
        ``(xy, z)``: ``xy`` has shape (..., 2) and holds the pixel coordinates
        after the z-divide (with 1e-5 added to the depth). ``z`` has shape
        (..., 1) and is the depth before the divide.
    """
    if intrinsics.dim() == xyz_cam_hom.dim():
        intrinsics = intrinsics.unsqueeze(1)
    pix_h = torch.einsum("...ij,...j->...i", intrinsics, xyz_cam_hom[..., :3])
    depth = pix_h[..., -1:]
    pix = pix_h / (depth + 1e-5)  # z-divide
    return pix[..., :2], depth

def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    # from pytorch3d
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    # from pytorch3d
    """
    Convert rotation matrices to real-first quaternions (r, i, j, k).

    Four candidate quaternions are formed, each divided by (twice) one of
    |r|, |i|, |j|, |k|. The candidate with the largest such denominator is
    kept, which is the best-conditioned one.

    Args:
        matrix: rotation matrices of shape (..., 3, 3).
    Returns:
        Quaternions of shape (..., 4).
    Raises:
        ValueError: if the last two dims are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    lead_shape = matrix.shape[:-2]
    flat = matrix.reshape(lead_shape + (9,))
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = flat.unbind(-1)

    # 4*r^2, 4*i^2, 4*j^2, 4*k^2 from the diagonal; sqrt gives twice |r|,|i|,|j|,|k|.
    traces = torch.stack(
        (
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ),
        dim=-1,
    )
    q_abs = _sqrt_positive_part(traces)

    # Row c is the quaternion scaled by its c-th component (times 4).
    rows = (
        (q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01),
        (m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20),
        (m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21),
        (m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2),
    )
    quat_by_rijk = torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)

    # Floor the denominator at 0.1; tiny denominators are never selected anyway.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(floor))

    # Keep, per matrix, the candidate row with the largest |component|.
    best = q_abs.argmax(dim=-1)
    index = best[..., None, None].expand(lead_shape + (1, 4))
    return quat_candidates.gather(-2, index).squeeze(-2)
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    # from pytorch3d
    """
    Convert real-first quaternions (r, i, j, k) to 3x3 rotation matrices.

    The quaternions need not be unit length. The formula divides by the
    squared norm, so ``q`` and ``c * q`` (c != 0) give the same matrix.

    Args:
        quaternions: tensor of shape (..., 4).
    Returns:
        Rotation matrices of shape (..., 3, 3).
    """
    r, i, j, k = quaternions.unbind(-1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    ii, jj, kk = i * i, j * j, k * k
    ij, ik, jk = i * j, i * k, j * k
    ir, jr, kr = i * r, j * r, k * r

    rows = (
        (1 - scale * (jj + kk), scale * (ij - kr), scale * (ik + jr)),
        (scale * (ij + kr), 1 - scale * (ii + kk), scale * (jk - ir)),
        (scale * (ik - jr), scale * (jk + ir), 1 - scale * (ii + jj)),
    )
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)
def camera_interp(camera1, camera2, t):
    """Interpolate between two camera poses.

    Rotations are interpolated by spherical linear interpolation (slerp) of
    their quaternions, always along the shorter arc. Translations (last
    column) are interpolated linearly.

    Args:
        camera1: pose matrix of shape (4, 4), or a batch of shape (B, 4, 4)
            that is interpolated pairwise with ``camera2``.
        camera2: pose matrix (or batch) with the same shape as ``camera1``.
        t: interpolation weight; ``t = 0`` yields ``camera1`` and ``t = 1``
            yields ``camera2``.

    Returns:
        Interpolated pose(s) of shape (4, 4) (or (B, 4, 4)) in the default
        floating dtype. As before, the result is moved to CUDA when CUDA is
        available; otherwise it stays on ``camera1``'s device.
    """
    if len(camera1.shape) == 3:
        return torch.stack(
            [camera_interp(cam1, cam2, t) for cam1, cam2 in zip(camera1, camera2)]
        )

    q1 = matrix_to_quaternion(camera1[:3, :3])
    q2 = matrix_to_quaternion(camera2[:3, :3])

    # q and -q encode the same rotation. Flip q2 when needed so slerp takes the
    # shorter arc instead of going the long way around.
    cos_angle = (q1 * q2).sum(dim=0)
    if cos_angle < 0:
        q2 = -q2
        cos_angle = -cos_angle

    angle = torch.acos(cos_angle.clamp(-1, 1))
    sin_angle = torch.sin(angle)
    if sin_angle < 1e-6:
        # (Nearly) identical rotations: the slerp weights would be 0/0 = NaN.
        # Plain lerp is accurate here. quaternion_to_matrix divides by the squared
        # norm, so the lerped quaternion need not be renormalised.
        q_interpolated = torch.lerp(q1, q2, t)
    else:
        q_interpolated = (
            q1 * torch.sin((1 - t) * angle) + q2 * torch.sin(t * angle)
        ) / sin_angle
    rotation_interpolated = quaternion_to_matrix(q_interpolated)

    translation_interpolated = torch.lerp(camera1[:3, -1], camera2[:3, -1], t)

    cam_interpolated = torch.eye(4, device=camera1.device)
    cam_interpolated[:3, :3] = rotation_interpolated
    cam_interpolated[:3, -1] = translation_interpolated

    # The original unconditionally called `.cuda()`, which crashes on CPU-only
    # machines; keep that placement whenever CUDA exists.
    return cam_interpolated.cuda() if torch.cuda.is_available() else cam_interpolated

# Positional encoding from David
class NearFarPlanePosEnc(nn.Module):
    """Sinusoidal positional encoding for depths between a near and a far plane.

    Depths are shifted by ``d_near`` and multiplied by a ladder of angular
    frequencies, each half the previous one. The ladder starts at half the
    sampling frequency implied by ``num_samples`` evenly spaced samples over
    ``[d_near, d_far]``. ``forward`` returns, per depth, the interleaved sin/cos
    of every scaled value followed by the raw (unshifted) depth. The last axis
    therefore has ``d_out = 2 * len(frequencies) + 1`` entries.
    """

    d_near: float

    def __init__(
        self,
        d_near: float,
        d_far: float,
        num_samples: int,
    ):
        super().__init__()
        assert d_near < d_far
        self.d_near = d_near

        # Angular frequency whose period equals the spacing between consecutive samples.
        sample_spacing = (d_far - d_near) / num_samples
        sampling_frequency = 2 * torch.pi / sample_spacing

        # Cap the encoding at the Nyquist limit (half the sampling frequency) to
        # avoid aliasing.
        highest_frequency = 0.5 * sampling_frequency

        # ceil(log2(num_samples)) frequencies, each half the previous. For
        # num_samples > 1 the lowest one has a period of at least the full depth
        # range and less than twice it; num_samples == 1 yields no frequencies.
        num_frequencies = int(ceil(log2(num_samples)))
        exponents = -torch.arange(num_frequencies).float()
        self.register_buffer(
            "frequencies", 2**exponents * highest_frequency, persistent=False
        )

    def forward(
        self,
        points: Float[Tensor, "batch point"],
    ) -> Float[Tensor, "batch point embedded_dim"]:
        phases = einsum(points - self.d_near, self.frequencies, "... p, m -> ... p m")
        # (..., p, m, 2) -> (..., p, 2m): [sin f0, cos f0, sin f1, cos f1, ...]
        waves = torch.stack((phases.sin(), phases.cos()), -1).flatten(-2, -1)
        return torch.cat((waves, points.unsqueeze(-1)), -1)

    @property
    def d_out(self):
        """Size of the last axis returned by ``forward``."""
        return 2 * len(self.frequencies) + 1

    @property
    def periods(self) -> Float[Tensor, " frequency"]:
        """Spatial period of each encoding frequency."""
        return 2 * torch.pi / self.frequencies


def numpy_procrustes(X, Y, scaling=True, reflection='best'):
    """Find the similarity transform that best maps the points ``Y`` onto ``X``.

    Port of MATLAB's ``procrustes``. The returned transform satisfies
    ``Z = tform['scale'] * Y @ tform['rotation'] + tform['translation']``, with
    ``Z`` the least-squares approximation of ``X``.

    Args:
        X: target coordinates, array of shape (n, m).
        Y: input coordinates, array of shape (n, my) with my <= m. When
            my < m, Y is padded with zero columns before fitting.
        scaling: if False, the scale is fixed to 1.
        reflection: 'best' lets the solution include a reflection when that
            fits better; True forces a reflection; False forbids one.

    Returns:
        d: residual sum of squares divided by X's centred sum of squares.
        Z: Y mapped onto X, shape (n, m).
        tform: dict with 'rotation' (my, m), 'scale' and 'translation' (m,).
    """
    n, m = X.shape
    _, my = Y.shape

    muX = X.mean(0)
    muY = Y.mean(0)

    X0 = X - muX
    Y0 = Y - muY

    ssX = (X0**2.).sum()
    ssY = (Y0**2.).sum()

    # centred Frobenius norm
    normX = np.sqrt(ssX)
    normY = np.sqrt(ssY)

    # scale to equal (unit) norm
    X0 /= normX
    Y0 /= normY

    if my < m:
        # Pad Y with zero *columns* up to X's dimensionality. The shape must be a
        # tuple: np.zeros(n, m - my) would treat the second argument as a dtype.
        Y0 = np.concatenate((Y0, np.zeros((n, m - my))), axis=1)

    # optimum rotation matrix of Y
    A = np.dot(X0.T, Y0)
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    V = Vt.T
    T = np.dot(V, U.T)

    if reflection != 'best':

        # does the current solution use a reflection?
        have_reflection = np.linalg.det(T) < 0

        # if that's not what was specified, force another reflection
        if reflection != have_reflection:
            V[:, -1] *= -1
            s[-1] *= -1
            T = np.dot(V, U.T)

    traceTA = s.sum()

    if scaling:

        # optimum scaling of Y
        b = traceTA * normX / normY

        # standardised distance between X and b*Y*T + c
        d = 1 - traceTA**2

        # transformed coords
        Z = normX * traceTA * np.dot(Y0, T) + muX

    else:
        b = 1
        d = 1 + ssY / ssX - 2 * traceTA * normY / normX
        Z = normY * np.dot(Y0, T) + muX

    # transformation matrix (drop the rows that act on Y's padding columns)
    if my < m:
        T = T[:my, :]
    c = muX - b * np.dot(muY, T)

    tform = {'rotation': T, 'scale': b, 'translation': c}

    return d, Z, tform

def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to quaternions with the real part first.

    NOTE(review): this redefines an identical function declared earlier in
    the module; this later definition is the one bound at import time.

    Args:
        matrix: rotation matrices of shape (..., 3, 3).

    Returns:
        Quaternions (r, i, j, k) of shape (..., 4).

    Raises:
        ValueError: if the last two dims are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    (m00, m01, m02,
     m10, m11, m12,
     m20, m21, m22) = matrix.reshape(batch_dim + (9,)).unbind(-1)

    diag_terms = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    # Twice the magnitudes |r|, |i|, |j|, |k|.
    q_abs = _sqrt_positive_part(diag_terms)

    sq = q_abs ** 2
    # Each row is the quaternion multiplied by (4 times) one of its components.
    candidates_scaled = torch.stack(
        [
            torch.stack([sq[..., 0], m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, sq[..., 1], m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, sq[..., 2], m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, sq[..., 3]], dim=-1),
        ],
        dim=-2,
    )

    # A floor of 0.1 keeps the division finite; such small rows are never chosen.
    lower = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    candidates = candidates_scaled / (2.0 * q_abs[..., None].max(lower))

    # Up to sign all rows agree; keep the one with the largest denominator.
    chosen = F.one_hot(q_abs.argmax(dim=-1), num_classes=4).bool()
    return candidates[chosen].reshape(batch_dim + (4,))

def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """
    Extract the first or third Euler angle from the two members of
    the matrix which are positive constant times its sine and cosine.

    Args:
        axis: Axis label "X" or "Y or "Z" for the angle we are finding.
        other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """

    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])
def matrix_to_euler_angles(matrix: torch.Tensor, convention="XYZ") -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", "Z"}; the middle letter must differ from both others.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: if ``convention`` is malformed or ``matrix`` is not (..., 3, 3).
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    def _index_from_letter(letter: str) -> int:
        if letter == "X":
            return 0
        if letter == "Y":
            return 1
        if letter == "Z":
            return 2
        raise ValueError("letter must be either X, Y or Z.")

    # Validate all letters, including the middle one, which is otherwise only
    # used for a parity test and would silently produce wrong angles.
    for letter in convention:
        _index_from_letter(letter)

    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    # Rounding can push an entry of a valid rotation matrix slightly outside
    # [-1, 1]; clamp so asin/acos do not return NaN.
    if tait_bryan:
        central_angle = torch.asin(
            (matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)).clamp(-1.0, 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., i0, i0].clamp(-1.0, 1.0))

    o = (
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)


def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X" or "Y or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """

    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = "XYZ") -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    The result is ``R(c0, a0) @ R(c1, a1) @ R(c2, a2)``, where ``c`` are the
    convention letters and ``a`` the angles along the last dim.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", "Z"}. Defaults to "XYZ" (the previously hard-coded value).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if the last dim of ``euler_angles`` is not 3, ``convention``
            is not 3 letters long, or a letter is not X, Y or Z.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(convention, torch.unbind(euler_angles, -1))
    ]
    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
def lift_to_poses(poses_, scale=1):
    """Convert compact pose vectors into 4x4 homogeneous pose matrices.

    Args:
        poses_: tensor of shape (..., 6) holding XYZ Euler angles (radians)
            followed by a translation, or (..., 7) holding a real-first
            quaternion (normalised here) followed by a translation.
        scale: multiplier applied to the translation part.

    Returns:
        Tensor of shape (..., 4, 4) in the default floating dtype. It is
        allocated on CUDA when CUDA is available (as the previous hard-coded
        ``.cuda()`` did) and on ``poses_.device`` otherwise, so CPU-only
        machines no longer crash.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else poses_.device
    # shape[:-1] (rather than shape[:-2] + size(-2)) also supports a single 1-D pose.
    poses = torch.zeros(*poses_.shape[:-1], 4, 4, device=device)
    poses[..., -1, -1] = 1

    if poses_.size(-1) == 7:
        quat = poses_[..., :-3]
        poses_ = torch.cat((quat / quat.norm(dim=-1, keepdim=True), poses_[..., -3:]), -1)

    to_rotation = euler_angles_to_matrix if poses_.size(-1) == 6 else quaternion_to_matrix
    poses[..., :3, :3] = to_rotation(poses_[..., :-3])
    poses[..., :3, -1] = poses_[..., -3:] * scale
    return poses
