# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)

import numba as nb
import numpy as np
import pypose as pp
import sim3solve
import torch
from einops import parse_shape, rearrange
from scipy.spatial.transform import Rotation as R


def make_pypose_Sim3(rot, t, s):
    q = R.from_matrix(rot).as_quat()
    data = np.concatenate([t, q, np.array(s).reshape((1,))])
    return pp.Sim3(data)


def SE3_to_Sim3(x: pp.SE3):
    out = torch.cat((x.data, torch.ones_like(x.data[..., :1])), dim=-1)
    return pp.Sim3(out)


@nb.njit(cache=True)
def _format(es):
    return np.asarray(es, dtype=np.int64).reshape((-1, 2))[1:]


@nb.njit(cache=True)
def reduce_edges(flow_mag, ii, jj, max_num_edges, nms):
    es = [(-1, -1)]

    if ii.size == 0:
        return _format(es)

    Ni, Nj = (ii.max() + 1), (jj.max() + 1)
    ignore_lookup = np.zeros((Ni, Nj), dtype=nb.bool_)

    idxs = np.argsort(flow_mag)
    for idx in idxs:  # edge index

        if len(es) > max_num_edges:
            break

        i = ii[idx]
        j = jj[idx]
        mag = flow_mag[idx]

        if (j - i) < 30:
            continue

        if mag >= 1000:  # i.e., inf
            continue

        if ignore_lookup[i, j]:
            continue

        es.append((i, j))

        for di in range(-nms, nms + 1):
            i1 = i + di

            if 0 <= i1 < Ni:
                ignore_lookup[i1, j] = True

    return _format(es)


@nb.njit(cache=True)
def umeyama_alignment(x: np.ndarray, y: np.ndarray):
    """
    The following function was copied from:
    https://github.com/MichaelGrupp/evo/blob/3067541b350528fe46375423e5bc3a7c42c06c63/evo/core/geometry.py#L35

    Computes the least squares solution parameters of an Sim(m) matrix
    that minimizes the distance between a set of registered points.
    Umeyama, Shinji: Least-squares estimation of transformation parameters
                     between two point patterns. IEEE PAMI, 1991
    :param x: mxn matrix of points, m = dimension, n = nr. of data points
    :param y: mxn matrix of points, m = dimension, n = nr. of data points
    :param with_scale: set to True to align also the scale (default: 1.0 scale)
    :return: r, t, c - rotation matrix, translation vector and scale factor
    """

    # m = dimension, n = nr. of data points
    m, n = x.shape

    # means, eq. 34 and 35
    mean_x = x.sum(axis=1) / n
    mean_y = y.sum(axis=1) / n

    # variance, eq. 36
    # "transpose" for column subtraction
    sigma_x = 1.0 / n * (np.linalg.norm(x - mean_x[:, np.newaxis]) ** 2)

    # covariance matrix, eq. 38
    outer_sum = np.zeros((m, m))
    for i in range(n):
        outer_sum += np.outer((y[:, i] - mean_y), (x[:, i] - mean_x))
    cov_xy = np.multiply(1.0 / n, outer_sum)

    # SVD (text betw. eq. 38 and 39)
    u, d, v = np.linalg.svd(cov_xy)
    if np.count_nonzero(d > np.finfo(d.dtype).eps) < m - 1:
        return None, None, None  # Degenerate covariance rank, Umeyama alignment is not possible

    # S matrix, eq. 43
    s = np.eye(m)
    if np.linalg.det(u) * np.linalg.det(v) < 0.0:
        # Ensure a RHS coordinate system (Kabsch algorithm).
        s[m - 1, m - 1] = -1

    # rotation, eq. 40
    r = u.dot(s).dot(v)

    # scale & translation, eq. 42 and 41
    c = 1 / sigma_x * np.trace(np.diag(d).dot(s))
    t = mean_y - np.multiply(c, r.dot(mean_x))

    return r, t, c


@nb.njit(cache=True)
def ransac_umeyama(src_points, dst_points, iterations=1, threshold=0.1):
    best_inliers = 0
    best_R = None
    best_t = None
    best_s = None
    for _ in range(iterations):
        # Randomly select three points
        indices = np.random.choice(src_points.shape[0], 3, replace=False)
        src_sample = src_points[indices]
        dst_sample = dst_points[indices]

        # Estimate transformation
        R, t, s = umeyama_alignment(src_sample.T, dst_sample.T)
        if t is None:
            continue

        # Apply transformation
        transformed = (src_points @ (R * s).T) + t

        # Count inliers (not ideal because depends on scene scale)
        distances = np.sum((transformed - dst_points) ** 2, axis=1) ** 0.5
        inlier_mask = distances < threshold
        inliers = np.sum(inlier_mask)

        # Update best transformation
        if inliers > best_inliers:
            best_inliers = inliers
            best_R, best_t, best_s = umeyama_alignment(
                src_points[inlier_mask].T, dst_points[inlier_mask].T
            )

    return best_R, best_t, best_s, best_inliers


def batch_jacobian(func, x):
    def _func_sum(*x):
        return func(*x).sum(dim=0)

    _, b, c = torch.autograd.functional.jacobian(_func_sum, x, vectorize=True)
    return rearrange(torch.stack((b, c)), "N O B I -> N B O I", N=2)


def _residual(C, Gi, Gj):
    assert parse_shape(C, "N _") == parse_shape(Gi, "N _") == parse_shape(Gj, "N _")
    out = C @ pp.Exp(Gi) @ pp.Exp(Gj).Inv()
    return out.Log().tensor()


def residual(Ginv, input_poses, dSloop, ii, jj, jacobian=False):

    # prep
    device = Ginv.device
    assert parse_shape(input_poses, "_ d") == dict(d=7)
    pred_inv_poses = SE3_to_Sim3(input_poses).Inv()

    # free variables
    n, _ = pred_inv_poses.shape
    kk = torch.arange(1, n, device=device)
    ll = kk - 1

    # constants
    Ti = pred_inv_poses[kk]
    Tj = pred_inv_poses[ll]
    dSij = Tj @ Ti.Inv()

    constants = torch.cat((dSij, dSloop), dim=0)
    iii = torch.cat((kk, ii))
    jjj = torch.cat((ll, jj))
    resid = _residual(constants, Ginv[iii], Ginv[jjj])

    if not jacobian:
        return resid

    J_Ginv_i, J_Ginv_j = batch_jacobian(_residual, (constants, Ginv[iii], Ginv[jjj]))
    return resid, (J_Ginv_i, J_Ginv_j, iii, jjj)


def perform_updates(
    input_poses, dSloop, ii_loop, jj_loop, iters=30, ep=0.0, lmbda=1e-6, fix_opt_window=False
):
    """Run the Levenberg Marquardt algorithm"""

    input_poses = input_poses.clone()

    if fix_opt_window:
        freen = torch.cat((ii_loop, jj_loop)).max().item() + 1
    else:
        freen = -1

    Ginv = SE3_to_Sim3(input_poses).Inv().Log()

    residual_history = []

    for itr in range(iters):
        resid, (J_Ginv_i, J_Ginv_j, iii, jjj) = residual(
            Ginv, input_poses, dSloop, ii_loop, jj_loop, jacobian=True
        )
        residual_history.append(resid.square().mean().item())
        print(f"resid: {resid.square().mean().item()}")
        (delta_pose,) = sim3solve.solve_system(
            J_Ginv_i, J_Ginv_j, iii, jjj, resid, ep, lmbda, freen
        )
        assert Ginv.shape == delta_pose.shape
        Ginv_tmp = Ginv + delta_pose

        new_resid = residual(Ginv_tmp, input_poses, dSloop, ii_loop, jj_loop)
        if new_resid.square().mean() < residual_history[-1]:
            Ginv = Ginv_tmp
            lmbda /= 2
        else:
            lmbda *= 2

        if (
            (residual_history[-1] < 1e-5)
            and (itr >= 4)
            and ((residual_history[-5] / residual_history[-1]) < 1.5)
        ):
            break

    return pp.Exp(Ginv).Inv()


def pose_refinement(pred_poses, loop_poses, loop_ii, loop_jj):

    final_est = perform_updates(pred_poses, loop_poses, loop_ii, loop_jj, iters=30)

    safe_i = loop_ii.max().item() + 1
    aa = SE3_to_Sim3(pred_poses.cpu())
    final_est = (aa[[safe_i]] * final_est[[safe_i]].Inv()) * final_est
    output = final_est[:safe_i]

    return output
