# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)

import bisect
import glob
import os
import numpy as np
import trimesh
from loop_utils.alignment_torch import robust_weighted_estimate_sim3_torch
from loop_utils.alignment_triton import robust_weighted_estimate_sim3_triton
from numba import njit
from sklearn.linear_model import LinearRegression, RANSACRegressor


def accumulate_sim3_transforms(transforms):
    """
    Accumulate adjacent SIM(3) transforms into transforms
    from the initial frame to each subsequent frame.

    Args:
    transforms: list, each element is a tuple (R, s, t)
        R: 3x3 rotation matrix (np.array)
        s: scale factor (scalar)
        t: 3x1 translation vector (np.array)

    Returns:
    Cumulative transforms list, each element is (R_cum, s_cum, t_cum)
        representing the transform from frame 0 to frame k
    """
    if not transforms:
        return []

    cumulative_transforms = [transforms[0]]

    for i in range(1, len(transforms)):
        s_cum_prev, R_cum_prev, t_cum_prev = cumulative_transforms[i - 1]
        s_next, R_next, t_next = transforms[i]
        R_cum_new = R_cum_prev @ R_next
        s_cum_new = s_cum_prev * s_next
        t_cum_new = s_cum_prev * (R_cum_prev @ t_next) + t_cum_prev
        cumulative_transforms.append((s_cum_new, R_cum_new, t_cum_new))

    return cumulative_transforms


def estimate_sim3(source_points, target_points):
    mu_src = np.mean(source_points, axis=0)
    mu_tgt = np.mean(target_points, axis=0)

    src_centered = source_points - mu_src
    tgt_centered = target_points - mu_tgt

    scale_src = np.sqrt((src_centered**2).sum(axis=1).mean())
    scale_tgt = np.sqrt((tgt_centered**2).sum(axis=1).mean())
    s = scale_tgt / scale_src

    src_scaled = src_centered * s

    H = src_scaled.T @ tgt_centered
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = mu_tgt - s * R @ mu_src
    return s, R, t


def align_point_maps(point_map1, conf1, point_map2, conf2, conf_threshold):
    """point_map2 -> point_map1"""
    b1, _, _, _ = point_map1.shape
    b2, _, _, _ = point_map2.shape
    b = min(b1, b2)

    aligned_points1 = []
    aligned_points2 = []

    for i in range(b):
        mask1 = conf1[i] > conf_threshold
        mask2 = conf2[i] > conf_threshold
        valid_mask = mask1 & mask2

        idx = np.where(valid_mask)
        if len(idx[0]) == 0:
            continue

        pts1 = point_map1[i][idx]
        pts2 = point_map2[i][idx]

        aligned_points1.append(pts1)
        aligned_points2.append(pts2)

    if len(aligned_points1) == 0:
        raise ValueError("No matching point pairs were found!")

    all_pts1 = np.concatenate(aligned_points1, axis=0)
    all_pts2 = np.concatenate(aligned_points2, axis=0)

    print(f"The number of corresponding points matched: {all_pts1.shape[0]}")
    s, R, t = estimate_sim3(all_pts2, all_pts1)

    mean_error = compute_alignment_error(
        point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t
    )
    print(f"Mean error: {mean_error}")

    return s, R, t


def apply_sim3(points, s, R, t):
    return (s * (R @ points.T)).T + t


def apply_sim3_direct(point_maps, s, R, t):
    # point_maps: (b, h, w, 3) -> (b, h, w, 3, 1)
    point_maps_expanded = point_maps[..., np.newaxis]  # (b, h, w, 3, 1)

    # R: (3, 3) -> (b, h, w, 3, 1) = (3, 3) @ (3, 1)
    rotated = np.matmul(R, point_maps_expanded)  # (b, h, w, 3, 1)
    rotated = rotated.squeeze(-1)  # (b, h, w, 3)
    transformed = s * rotated + t  # (b, h, w, 3)

    return transformed


def compute_alignment_error(point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t):
    """
    Compute the average point alignment error (using only original inputs)

    Args:
    point_map1: target point map (b, h, w, 3)
    conf1: target confidence map (b, h, w)
    point_map2: source point map (b, h, w, 3)
    conf2: source confidence map (b, h, w)
    conf_threshold: confidence threshold
    s, R, t: transformation parameters
    """
    b1, h1, w1, _ = point_map1.shape
    b2, h2, w2, _ = point_map2.shape
    b = min(b1, b2)
    h = min(h1, h2)
    w = min(w1, w2)

    target_points = []
    source_points = []

    for i in range(b):
        mask1 = conf1[i, :h, :w] > conf_threshold
        mask2 = conf2[i, :h, :w] > conf_threshold
        valid_mask = mask1 & mask2

        idx = np.where(valid_mask)
        if len(idx[0]) == 0:
            continue

        t_pts = point_map1[i, :h, :w][idx]
        s_pts = point_map2[i, :h, :w][idx]

        target_points.append(t_pts)
        source_points.append(s_pts)

    if len(target_points) == 0:
        print("Warning: No matching point pairs found for error calculation")
        return np.nan

    all_target = np.concatenate(target_points, axis=0)
    all_source = np.concatenate(source_points, axis=0)

    transformed = (s * (R @ all_source.T)).T + t

    errors = np.linalg.norm(transformed - all_target, axis=1)

    mean_error = np.mean(errors)
    std_error = np.std(errors)
    median_error = np.median(errors)
    max_error = np.max(errors)

    print(
        f"Alignment error statistics [using {len(errors)} points]: "
        f"mean={mean_error:.4f}, std={std_error:.4f}, "
        f"median={median_error:.4f}, max={max_error:.4f}"
    )

    return mean_error


def save_confident_pointcloud(
    points, colors, confs, output_path, conf_threshold, sample_ratio=1.0
):
    """
    Filter points based on confidence threshold
    and save as PLY file, with optional random sampling ratio.

    Args:
    - points: np.ndarray, shape (H, W, 3) or (N, 3)
    - colors: np.ndarray, shape (H, W, 3) or (N, 3)
    - confs: np.ndarray, shape (H, W) or (N,)
    - output_path: str, output PLY file path
    - conf_threshold: float, confidence threshold for point filtering
    - sample_ratio: float, sampling ratio (0 < sample_ratio <= 1.0)
    """
    points = points.reshape(-1, 3).astype(np.float32, copy=False)
    colors = colors.reshape(-1, 3).astype(np.uint8, copy=False)
    confs = confs.reshape(-1).astype(np.float32, copy=False)

    conf_mask = (confs >= conf_threshold) & (confs > 1e-5)
    points = points[conf_mask]
    colors = colors[conf_mask]

    if 0 < sample_ratio < 1.0 and len(points) > 0:
        num_samples = int(len(points) * sample_ratio)
        indices = np.random.choice(len(points), num_samples, replace=False)
        points = points[indices]
        colors = colors[indices]

    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

    print(f"shape of sampled point: {points.shape}")
    trimesh.PointCloud(points, colors=colors).export(output_path)
    print(f"Saved point cloud with {len(points)} points to {output_path}")


def save_confident_pointcloud_batch(
    points, colors, confs, output_path, conf_threshold, sample_ratio=1.0, batch_size=1000000
):
    """
    - points: np.ndarray,  (b, H, W, 3) / (N, 3)
    - colors: np.ndarray,  (b, H, W, 3) / (N, 3)
    - confs: np.ndarray,  (b, H, W) / (N,)
    - output_path: str
    - conf_threshold: float,
    - sample_ratio: float (0 < sample_ratio <= 1.0)
    - batch_size: int
    """
    if points.ndim == 2:
        b = 1
        points = points[np.newaxis, ...]
        colors = colors[np.newaxis, ...]
        confs = confs[np.newaxis, ...]
    elif points.ndim == 4:
        b = points.shape[0]
    else:
        raise ValueError("Unsupported points dimension. Must be 2 (N,3) or 4 (b,H,W,3)")

    total_valid = 0
    for i in range(b):
        cfs = confs[i].reshape(-1)
        total_valid += np.count_nonzero((cfs >= conf_threshold) & (cfs > 1e-5))

    num_samples = int(total_valid * sample_ratio) if sample_ratio < 1.0 else total_valid

    if num_samples == 0:
        save_ply(np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8), output_path)
        return

    if sample_ratio == 1.0:
        with open(output_path, "wb") as f:
            write_ply_header(f, num_samples)

            for i in range(b):
                pts = points[i].reshape(-1, 3).astype(np.float32)
                cls = colors[i].reshape(-1, 3).astype(np.uint8)
                cfs = confs[i].reshape(-1).astype(np.float32)

                mask = (cfs >= conf_threshold) & (cfs > 1e-5)
                valid_pts = pts[mask]
                valid_cls = cls[mask]

                for j in range(0, len(valid_pts), batch_size):
                    batch_pts = valid_pts[j : j + batch_size]
                    batch_cls = valid_cls[j : j + batch_size]
                    write_ply_batch(f, batch_pts, batch_cls)

    else:
        reservoir_pts = np.zeros((num_samples, 3), dtype=np.float32)
        reservoir_clr = np.zeros((num_samples, 3), dtype=np.uint8)
        count = 0

        for i in range(b):
            pts = points[i].reshape(-1, 3).astype(np.float32)
            cls = colors[i].reshape(-1, 3).astype(np.uint8)
            cfs = confs[i].reshape(-1).astype(np.float32)

            mask = (cfs >= conf_threshold) & (cfs > 1e-5)
            valid_pts = pts[mask]
            valid_cls = cls[mask]
            n_valid = len(valid_pts)

            if count < num_samples:
                fill_count = min(num_samples - count, n_valid)

                reservoir_pts[count : count + fill_count] = valid_pts[:fill_count]
                reservoir_clr[count : count + fill_count] = valid_cls[:fill_count]
                count += fill_count

                if fill_count < n_valid:
                    remaining_pts = valid_pts[fill_count:]
                    remaining_cls = valid_cls[fill_count:]

                    count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling(
                        remaining_pts, remaining_cls, count, reservoir_pts, reservoir_clr
                    )
            else:
                count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling(
                    valid_pts, valid_cls, count, reservoir_pts, reservoir_clr
                )

        save_ply(reservoir_pts, reservoir_clr, output_path)


""" The following function is deprecated"""

# def vectorized_reservoir_sampling(new_pts, new_cls, current_count, reservoir_pts, reservoir_clr):
#     """
#     - new_pts:  (M, 3)
#     - new_cls:  (M, 3)
#     - current_count
#     - reservoir_pts:  (K, 3)
#     - reservoir_clr:  (K, 3)

#     """
#     k = len(reservoir_pts)
#     n_new = len(new_pts)

#     rand_indices = np.random.randint(0, current_count + n_new, size=n_new)

#     replace_mask = rand_indices < k
#     replace_indices = rand_indices[replace_mask]
#     replace_pts = new_pts[replace_mask]
#     replace_cls = new_cls[replace_mask]

#     reservoir_pts[replace_indices] = replace_pts
#     reservoir_clr[replace_indices] = replace_cls

#     return current_count + n_new, reservoir_pts, reservoir_clr


"""
    Function `vectorized_reservoir_sampling`  is not mathematically accurate in sampling.
    This leads to inconsistent density in the downsampled point clouds.
    The `optimized_vectorized_reservoir_sampling` function has fixed this bug.

    Special thanks to @Horace89 for the detailed analysis and code assistance.

    See https://github.com/DengKaiCQ/VGGT-Long/issues/28 for details
"""


def optimized_vectorized_reservoir_sampling(
    new_points: np.ndarray,
    new_colors: np.ndarray,
    current_count: int,
    reservoir_points: np.ndarray,
    reservoir_colors: np.ndarray,
) -> tuple[int, np.ndarray, np.ndarray]:
    """
    Optimized vectorized reservoir sampling with batch probability calculations.

    This maintains mathematical correctness while improving performance through
    vectorized operations where possible.

    Args:
        new_points: New point coordinates to consider, shape (M, 3)
        new_colors: New point colors to consider, shape (M, 3)
        current_count: Number of elements seen so far
        reservoir_points: Current reservoir of sampled points, shape (K, 3)
        reservoir_colors: Current reservoir of sampled colors, shape (K, 3)

    Returns:
        Tuple of (updated_count, updated_reservoir_points, updated_reservoir_colors)
    """
    random_gen = np.random

    reservoir_size = len(reservoir_points)
    num_new_points = len(new_points)

    if num_new_points == 0:
        return current_count, reservoir_points, reservoir_colors

    # Calculate sequential indices for each new point
    point_indices = np.arange(current_count + 1, current_count + num_new_points + 1)

    # Generate random numbers for each point
    random_values = random_gen.randint(0, point_indices, size=num_new_points)

    # Determine which points should replace reservoir elements
    replacement_mask = random_values < reservoir_size
    replacement_positions = random_values[replacement_mask]

    # Apply replacements
    if np.any(replacement_mask):
        points_to_replace = new_points[replacement_mask]
        colors_to_replace = new_colors[replacement_mask]

        reservoir_points[replacement_positions] = points_to_replace
        reservoir_colors[replacement_positions] = colors_to_replace

    return current_count + num_new_points, reservoir_points, reservoir_colors


def write_ply_header(f, num_vertices):
    header = [
        "ply",
        "format binary_little_endian 1.0",
        f"element vertex {num_vertices}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "end_header",
    ]
    f.write("\n".join(header).encode() + b"\n")


def write_ply_batch(f, points, colors):
    structured = np.zeros(
        len(points),
        dtype=[
            ("x", np.float32),
            ("y", np.float32),
            ("z", np.float32),
            ("red", np.uint8),
            ("green", np.uint8),
            ("blue", np.uint8),
        ],
    )

    structured["x"] = points[:, 0]
    structured["y"] = points[:, 1]
    structured["z"] = points[:, 2]
    structured["red"] = colors[:, 0]
    structured["green"] = colors[:, 1]
    structured["blue"] = colors[:, 2]

    f.write(structured.tobytes())


def save_ply(points, colors, filename):
    with open(filename, "wb") as f:
        write_ply_header(f, len(points))
        write_ply_batch(f, points, colors)


def find_chunk_index(chunks, idx):
    """
    Find the 0-based chunk index that contains the given index idx.
    chunks: List of (begin_idx, end_idx).
    idx: The index to search for.
    Returns the 0-based chunk index.
    """
    starts = [chunk[0] for chunk in chunks]
    pos = bisect.bisect_right(starts, idx) - 1  # Find position of idx in starts
    if pos < 0 or pos >= len(chunks):
        raise ValueError(f"Index {idx} not found in any chunk")
    chunk_begin, chunk_end = chunks[pos]
    if idx < chunk_begin or idx > chunk_end:
        raise ValueError(f"Index {idx} not found in any chunk")
    return pos


def get_frame_range(chunk, idx, half_window=10):
    """
    Calculate the frame range centered at idx with half_window
    frames on each side within chunk boundaries.
    If near boundaries, take 2 * half_window frames starting from the boundary.
    chunk: (begin_idx, end_idx).
    idx: Center index.
    half_window: Number of frames to take on each side of center index.
    Returns (start, end).
    """
    begin, end = chunk
    window_size = 2 * half_window

    if idx - half_window < begin:
        start = begin
        end_candidate = begin + window_size
        end = min(end, end_candidate)

    elif idx + half_window > end:
        end_candidate = end
        start_candidate = end - window_size
        start = max(begin, start_candidate)

    else:
        start = idx - half_window
        end = idx + half_window
    return (start, end)


def process_loop_list(chunk_index, loop_list, half_window=10):
    """
    Process loop_list and return chunk indices and frame ranges for each (idx1, idx2) pair.
    chunk_index: List of (begin_idx, end_idx) tuples.
    loop_list: List of (idx1, idx2) tuples.
    half_window: Number of frames to take on each side of center index (default 10).
    Returns list of (chunk_idx1, range1, chunk_idx2, range2) tuples where:
      - chunk_idx1, chunk_idx2: Chunk indices (1-based).
      - range1, range2: Frame range tuples (start, end).
    """
    results = []
    for idx1, idx2 in loop_list:
        try:
            chunk_idx1_0based = find_chunk_index(chunk_index, idx1)
            chunk1 = chunk_index[chunk_idx1_0based]
            range1 = get_frame_range(chunk1, idx1, half_window)

            chunk_idx2_0based = find_chunk_index(chunk_index, idx2)
            chunk2 = chunk_index[chunk_idx2_0based]
            range2 = get_frame_range(chunk2, idx2, half_window)

            result = (chunk_idx1_0based, range1, chunk_idx2_0based, range2)
            results.append(result)
        except ValueError as e:
            print(f"Skipping pair ({idx1}, {idx2}): {e}")
    return results


def compute_sim3_ab(S_a, S_b):

    s_a, R_a, T_a = S_a
    s_b, R_b, T_b = S_b

    s_ab = s_b / s_a
    R_ab = R_b @ R_a.T
    T_ab = T_b - s_ab * (R_ab @ T_a)

    return (s_ab, R_ab, T_ab)


def merge_ply_files(input_dir, output_path):
    """
    Merge all PLY files in a directory into one file (without loading into memory)

    Args:
    - input_dir: Input directory containing multiple '{idx}_pcd.ply' files
    - output_path: Output file path (e.g., 'combined.ply')
    """

    print("Merging PLY files...")

    input_files = sorted(glob.glob(os.path.join(input_dir, "*_pcd.ply")))

    if not input_files:
        print("No PLY files found")
        return

    idx_file = 0
    len(input_files)

    total_vertices = 0
    for file in input_files:  # Count total vertices
        with open(file, "rb") as f:
            for line in f:
                if line.startswith(b"element vertex"):
                    vertex_count = int(line.split()[-1])
                    total_vertices += vertex_count
                elif line.startswith(b"end_header"):
                    break

    with open(output_path, "wb") as out_f:
        # Write new header
        out_f.write(b"ply\n")
        out_f.write(b"format binary_little_endian 1.0\n")
        out_f.write(f"element vertex {total_vertices}\n".encode())
        out_f.write(b"property float x\n")
        out_f.write(b"property float y\n")
        out_f.write(b"property float z\n")
        out_f.write(b"property uchar red\n")
        out_f.write(b"property uchar green\n")
        out_f.write(b"property uchar blue\n")
        out_f.write(b"end_header\n")

        for file in input_files:
            print(f"Processing {idx_file}/{len(input_files)}: {file}")
            idx_file += 1
            with open(file, "rb") as in_f:
                # Skip the head
                in_header = True
                while in_header:
                    line = in_f.readline()
                    if line.startswith(b"end_header"):
                        in_header = False
                data = in_f.read()
                out_f.write(data)

    print(f"Merge completed! Total points: {total_vertices}")
    print(f"Output file: {output_path}")


def weighted_estimate_se3(source_points, target_points, weights):
    """
    source_points:  (Nx3)
    target_points:  (Nx3)
    :weights:  (N,) [0,1]
    """
    total_weight = np.sum(weights)
    if total_weight < 1e-6:
        raise ValueError("Total weight too small for meaningful estimation")

    normalized_weights = weights / total_weight

    mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
    mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)

    src_centered = source_points - mu_src
    tgt_centered = target_points - mu_tgt

    weighted_src = src_centered * np.sqrt(normalized_weights)[:, None]
    weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]

    H = weighted_src.T @ weighted_tgt

    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = mu_tgt - R @ mu_src

    return 1.0, R, t


def weighted_estimate_sim3(source_points, target_points, weights):
    """
    source_points:  (Nx3)
    target_points:  (Nx3)
    :weights:  (N,) [0,1]
    """
    total_weight = np.sum(weights)
    if total_weight < 1e-6:
        raise ValueError("Total weight too small for meaningful estimation")

    normalized_weights = weights / total_weight

    mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
    mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)

    src_centered = source_points - mu_src
    tgt_centered = target_points - mu_tgt

    scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1)))
    scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1)))
    s = scale_tgt / scale_src

    weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None]
    weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]

    H = weighted_src.T @ weighted_tgt

    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = mu_tgt - s * R @ mu_src
    return s, R, t


def huber_loss(r, delta):
    abs_r = np.abs(r)
    return np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))


def robust_weighted_estimate_sim3(
    src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
):
    """
    src:  (Nx3)
    tgt:  (Nx3)
    init_weights:  (N,)
    """
    if align_method == "sim3":
        s, R, t = weighted_estimate_sim3(src, tgt, init_weights)
    elif align_method == "se3" or align_method == "scale+se3":
        s, R, t = weighted_estimate_se3(src, tgt, init_weights)

    prev_error = float("inf")

    for iter in range(max_iters):

        transformed = s * (src @ R.T) + t
        residuals = np.linalg.norm(tgt - transformed, axis=1)  # (N,)
        print(f"Residuals: {np.mean(residuals)}")

        abs_res = np.abs(residuals)
        huber_weights = np.ones_like(residuals)
        large_res_mask = abs_res > delta
        huber_weights[large_res_mask] = delta / abs_res[large_res_mask]

        combined_weights = init_weights * huber_weights

        combined_weights /= np.sum(combined_weights) + 1e-12

        if align_method == "se3":
            s_new, R_new, t_new = weighted_estimate_se3(src, tgt, combined_weights)
        elif align_method == "sim3" or align_method == "scale+se3":
            s_new, R_new, t_new = weighted_estimate_sim3(src, tgt, combined_weights)

        param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
        rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
        current_error = np.sum(huber_loss(residuals, delta) * init_weights)

        if (param_change < tol and rot_angle < np.radians(0.1)) or (
            abs(prev_error - current_error) < tol * prev_error
        ):
            break

        s, R, t = s_new, R_new, t_new
        prev_error = current_error

    return s, R, t


# ===== Speed Up Begin =====


@njit(cache=True)
def _weighted_estimate_se3_numba(source_points, target_points, weights):
    # Ensure float32
    source_points = source_points.astype(np.float32)
    target_points = target_points.astype(np.float32)
    weights = weights.astype(np.float32)

    total_weight = np.sum(weights)
    if total_weight < 1e-6:
        return (
            1.0,
            np.zeros(3, dtype=np.float32),
            np.zeros(3, dtype=np.float32),
            np.zeros((3, 3), dtype=np.float32),
        )

    normalized_weights = weights / total_weight

    mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
    mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)

    src_centered = source_points - mu_src
    tgt_centered = target_points - mu_tgt

    weighted_src = src_centered * np.sqrt(normalized_weights)[:, None]
    weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]

    H = weighted_src.T @ weighted_tgt

    return 1.0, mu_src, mu_tgt, H


@njit(cache=True)
def _weighted_estimate_sim3_numba(source_points, target_points, weights):
    # Ensure float32
    source_points = source_points.astype(np.float32)
    target_points = target_points.astype(np.float32)
    weights = weights.astype(np.float32)

    total_weight = np.sum(weights)
    if total_weight < 1e-6:
        return (
            -1.0,
            np.zeros(3, dtype=np.float32),
            np.zeros(3, dtype=np.float32),
            np.zeros((3, 3), dtype=np.float32),
        )

    normalized_weights = weights / total_weight

    mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
    mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)

    src_centered = source_points - mu_src
    tgt_centered = target_points - mu_tgt

    scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1)))
    scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1)))
    s = scale_tgt / scale_src

    weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None]
    weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]

    H = weighted_src.T @ weighted_tgt

    return s, mu_src, mu_tgt, H


def weighted_estimate_sim3_numba(source_points, target_points, weights, align_method="sim3"):
    if align_method == "sim3":
        s, mu_src, mu_tgt, H = _weighted_estimate_sim3_numba(source_points, target_points, weights)
    elif align_method == "se3" or align_method == "scale+se3":
        s, mu_src, mu_tgt, H = _weighted_estimate_se3_numba(source_points, target_points, weights)

    if s < 0:
        raise ValueError("Total weight too small for meaningful estimation")

    # Ensure float32
    H = H.astype(np.float32)
    U, _, Vt = np.linalg.svd(H.astype(np.float32))  # float32 SVD

    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    if align_method == "se3" or align_method == "scale+se3":
        t = mu_tgt - R @ mu_src
    else:
        t = mu_tgt - s * R @ mu_src

    return s, R, t


@njit(cache=True)
def huber_loss_numba(r, delta):
    r = r.astype(np.float32)
    delta = np.float32(delta)
    abs_r = np.abs(r)
    result = np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))
    return result.astype(np.float32)


@njit(cache=True)
def compute_residuals_numba(tgt, transformed):
    residuals = np.empty(tgt.shape[0], dtype=np.float32)
    for i in range(tgt.shape[0]):
        diff = tgt[i] - transformed[i]
        residuals[i] = np.sqrt(np.sum(diff**2))
    return residuals


@njit(cache=True)
def compute_huber_weights_numba(residuals, delta):
    weights = np.ones(residuals.shape, dtype=np.float32)
    for i in range(residuals.shape[0]):
        r = residuals[i]
        if r > delta:
            weights[i] = delta / r
    return weights


@njit(cache=True)
def apply_transformation_numba(src, s, R, t):
    transformed = np.empty_like(src)
    for i in range(src.shape[0]):
        p = src[i]
        transformed[i] = s * (R @ p) + t
    return transformed


def robust_weighted_estimate_sim3_numba(
    src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
):
    src = src.astype(np.float32)
    tgt = tgt.astype(np.float32)
    init_weights = init_weights.astype(np.float32)

    s, R, t = weighted_estimate_sim3_numba(src, tgt, init_weights, align_method=align_method)

    prev_error = float("inf")

    for iter in range(max_iters):
        transformed = apply_transformation_numba(src, s, R, t)
        residuals = compute_residuals_numba(tgt, transformed)

        print(f"Residuals: {np.mean(residuals)}")

        huber_weights = compute_huber_weights_numba(residuals, delta)
        combined_weights = init_weights * huber_weights
        combined_weights /= np.sum(combined_weights) + 1e-12

        s_new, R_new, t_new = weighted_estimate_sim3_numba(
            src, tgt, combined_weights, align_method=align_method
        )

        param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
        rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))

        current_error = np.sum(huber_loss_numba(residuals, delta) * init_weights)

        if (param_change < tol and rot_angle < np.radians(0.1)) or (
            abs(prev_error - current_error) < tol * prev_error
        ):
            break

        s, R, t = s_new, R_new, t_new
        prev_error = current_error

    return s, R, t


def warmup_numba():

    print("\nWarming up Numba JIT-compiled functions...")

    src = np.random.randn(50000, 3).astype(np.float32)
    tgt = np.random.randn(50000, 3).astype(np.float32)
    weights = np.ones(50000, dtype=np.float32)
    residuals = np.abs(np.random.randn(50000).astype(np.float32))
    R = np.eye(3, dtype=np.float32)
    t = np.zeros(3, dtype=np.float32)
    s = np.float32(1.0)
    delta = np.float32(1.0)

    try:
        _ = _weighted_estimate_sim3_numba(src, tgt, weights)
        print(" - _weighted_estimate_sim3_numba warmed up.")
    except Exception as e:
        print(" ! Failed to warm up _weighted_estimate_sim3_numba:", e)

    try:
        _ = _weighted_estimate_se3_numba(src, tgt, weights)
        print(" - _weighted_estimate_se3_numba warmed up.")
    except Exception as e:
        print(" ! Failed to warm up _weighted_estimate_se3_numba:", e)

    try:
        _ = huber_loss_numba(residuals, delta)
        print(" - huber_loss_numba warmed up.")
    except Exception as e:
        print(" ! Failed to warm up huber_loss_numba:", e)

    try:
        _ = compute_huber_weights_numba(residuals, delta)
        print(" - compute_huber_weights_numba warmed up.")
    except Exception as e:
        print(" ! Failed to warm up compute_huber_weights_numba:", e)

    try:
        _ = compute_residuals_numba(tgt, src)
        print(" - compute_residuals_numba warmed up.")
    except Exception as e:
        print(" ! Failed to warm up compute_residuals_numba:", e)

    try:
        _ = apply_transformation_numba(src, s, R, t)
        print(" - apply_transformation_numba warmed up.")
    except Exception as e:
        print(" ! Failed to warm up apply_transformation_numba:", e)

    print("Numba warm-up complete.\n")


# ===== Speed Up End =====

# ===== Scale precompute begin =====


def compute_scale_ransac(
    depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, max_samples=10000
):
    """
    Args:
        depth1: (n1, h, w)
        depth2: (n2, h, w)
        conf1: (n1, h, w)
        conf2: (n2, h, w)

    """

    depth1_flat = depth1.reshape(-1)
    depth2_flat = depth2.reshape(-1)
    conf1_flat = conf1.reshape(-1)
    conf2_flat = conf2.reshape(-1)

    conf_threshold = max(
        np.median(conf1_flat) * conf_threshold_ratio,
        np.median(conf2_flat) * conf_threshold_ratio,
        1e-6,
    )

    valid_mask = (
        (conf1_flat > conf_threshold)
        & (conf2_flat > conf_threshold)
        & (depth1_flat > 1e-3)
        & (depth2_flat > 1e-3)
        & (depth1_flat < 100)
        & (depth2_flat < 100)
    )

    if np.sum(valid_mask) < 100:
        print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0")
        return 1.0, 0.0

    valid_depth1 = depth1_flat[valid_mask]
    valid_depth2 = depth2_flat[valid_mask]

    if len(valid_depth1) > max_samples:
        indices = np.random.choice(len(valid_depth1), max_samples, replace=False)
        valid_depth1 = valid_depth1[indices]
        valid_depth2 = valid_depth2[indices]

    X = valid_depth2.reshape(-1, 1)
    y = valid_depth1

    base_estimator = LinearRegression(fit_intercept=False)
    ransac = RANSACRegressor(
        estimator=base_estimator,
        max_trials=1000,
        min_samples=max(10, len(X) // 100),
        residual_threshold=0.1,
        random_state=42,
    )

    ransac.fit(X, y)
    scale_factor = ransac.estimator_.coef_[0]
    inlier_mask = ransac.inlier_mask_
    inlier_ratio = np.sum(inlier_mask) / len(inlier_mask)

    print(f"RANSAC scale: {scale_factor:.6f}, inlier ratio: {inlier_ratio:.4f}")

    if 0.1 < scale_factor < 10.0:
        return scale_factor, inlier_ratio
    else:
        print(f"Warning: Unreasonable scale {scale_factor}, using 1.0")
        return 1.0, inlier_ratio


def compute_scale_weighted(
    depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, weight_power=2.0, robust_quantile=0.9
):
    """
    Args:
        depth1: (n1, h, w)
        depth2: (n2, h, w)
        conf1: (n1, h, w)
        conf2: (n2, h, w)
    """
    depth1_flat = depth1.reshape(-1)
    depth2_flat = depth2.reshape(-1)
    conf1_flat = conf1.reshape(-1)
    conf2_flat = conf2.reshape(-1)

    conf_threshold = max(
        np.median(conf1_flat) * conf_threshold_ratio,
        np.median(conf2_flat) * conf_threshold_ratio,
        1e-6,
    )

    valid_mask = (
        (conf1_flat > conf_threshold)
        & (conf2_flat > conf_threshold)
        & (depth1_flat > 1e-3)
        & (depth2_flat > 1e-3)
        & (depth1_flat < 100)
        & (depth2_flat < 100)
    )

    if np.sum(valid_mask) < 100:
        print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0")
        return 1.0, 0.0

    valid_depth1 = depth1_flat[valid_mask]
    valid_depth2 = depth2_flat[valid_mask]
    valid_conf1 = conf1_flat[valid_mask]
    valid_conf2 = conf2_flat[valid_mask]

    combined_weights = (valid_conf1 * valid_conf2) ** weight_power

    combined_weights = combined_weights / (np.sum(combined_weights) + 1e-8)

    ratios = valid_depth1 / (valid_depth2 + 1e-8)

    sorted_indices = np.argsort(ratios)
    sorted_ratios = ratios[sorted_indices]
    sorted_weights = combined_weights[sorted_indices]

    cumulative_weights = np.cumsum(sorted_weights)
    median_idx = np.searchsorted(cumulative_weights, 0.5)
    scale_median = sorted_ratios[median_idx] if median_idx < len(sorted_ratios) else 1.0

    quantile_idx = np.searchsorted(cumulative_weights, robust_quantile)
    scale_quantile = (
        sorted_ratios[quantile_idx] if quantile_idx < len(sorted_ratios) else scale_median
    )

    weight_entropy = -np.sum(combined_weights * np.log(combined_weights + 1e-8))
    max_entropy = np.log(len(combined_weights))
    confidence_score = 1.0 - (weight_entropy / max_entropy) if max_entropy > 0 else 0.0

    print(f"Weighted scale: {scale_quantile:.6f}, confidence: {confidence_score:.4f}")

    if 0.1 < scale_quantile < 10.0:
        return scale_quantile, confidence_score
    else:
        print(f"Warning: Unreasonable scale {scale_quantile}, using 1.0")
        return 1.0, confidence_score


def compute_chunk_scale_advanced(depth1, depth2, conf1, conf2, method="auto"):
    """
    method: 'auto', 'ransac', 'weighted'
    """
    if method == "ransac":
        scale, score = compute_scale_ransac(depth1, depth2, conf1, conf2)
        return scale, score, "ransac"

    elif method == "weighted":
        scale, score = compute_scale_weighted(depth1, depth2, conf1, conf2)
        return scale, score, "weighted"

    elif method == "auto":
        scale_ransac, inlier_ratio = compute_scale_ransac(depth1, depth2, conf1, conf2)
        scale_weighted, conf_score = compute_scale_weighted(depth1, depth2, conf1, conf2)

        ransac_quality = inlier_ratio
        weighted_quality = conf_score

        print(f"RANSAC quality: {ransac_quality:.4f}, Weighted quality: {weighted_quality:.4f}")

        if ransac_quality > 0.7 and weighted_quality > 0.7:
            # both method are good, we take both of them by average
            final_scale = (scale_ransac + scale_weighted) / 2
            final_method = "average"
        elif ransac_quality > weighted_quality:
            final_scale = scale_ransac
            final_method = "ransac"
        else:
            final_scale = scale_weighted
            final_method = "weighted"

        final_quality = max(ransac_quality, weighted_quality)
        return final_scale, final_quality, final_method


def precompute_scale_chunks_with_depth(
    chunk1_depth, chunk1_conf, chunk2_depth, chunk2_conf, method="auto"
):
    """
    Args:
        chunk1_depth: (n1, h, w)
        chunk1_conf: (n1, h, w)
        chunk2_depth: (n2, h, w)
        chunk2_conf: (n2, h, w)
        method: 'auto', 'ransac', 'weighted'
    """

    scale_factor, quality_score, method_used = compute_chunk_scale_advanced(
        chunk1_depth, chunk2_depth, chunk1_conf, chunk2_conf, method
    )

    print(f"Final scale: {scale_factor:.6f}, quality: {quality_score:.4f}, method: {method_used}")

    return scale_factor, quality_score, method_used


# ===== Scale precompute end =====


def weighted_align_point_maps(
    point_map1, conf1, point_map2, conf2, conf_threshold, config, precompute_scale=None
):
    """point_map2 -> point_map1"""
    b1, _, _, _ = point_map1.shape
    b2, _, _, _ = point_map2.shape
    b = min(b1, b2)

    if precompute_scale is not None:  # meaning we are using align method 'scale+se3'
        point_map2 *= precompute_scale

    aligned_points1 = []
    aligned_points2 = []
    confidence_weights = []

    for i in range(b):
        mask1 = conf1[i] > conf_threshold
        mask2 = conf2[i] > conf_threshold
        valid_mask = mask1 & mask2

        idx = np.where(valid_mask)
        if len(idx[0]) == 0:
            continue

        pts1 = point_map1[i][idx]
        pts2 = point_map2[i][idx]

        combined_conf = np.sqrt(conf1[i][idx] * conf2[i][idx])

        aligned_points1.append(pts1)
        aligned_points2.append(pts2)
        confidence_weights.append(combined_conf)

    if len(aligned_points1) == 0:
        raise ValueError("No matching point pairs were found!")

    all_pts1 = np.concatenate(aligned_points1, axis=0)
    all_pts2 = np.concatenate(aligned_points2, axis=0)
    all_weights = np.concatenate(confidence_weights, axis=0)

    print(f"The number of corresponding points matched: {all_pts1.shape[0]}")

    if config["Model"]["align_lib"] == "numba":
        s, R, t = robust_weighted_estimate_sim3_numba(
            all_pts2,
            all_pts1,
            all_weights,
            delta=config["Model"]["IRLS"]["delta"],
            max_iters=config["Model"]["IRLS"]["max_iters"],
            tol=eval(config["Model"]["IRLS"]["tol"]),
            align_method=config["Model"]["align_method"],
        )
    elif config["Model"]["align_lib"] == "numpy":  # numpy
        s, R, t = robust_weighted_estimate_sim3(
            all_pts2,
            all_pts1,
            all_weights,
            delta=config["Model"]["IRLS"]["delta"],
            max_iters=config["Model"]["IRLS"]["max_iters"],
            tol=eval(config["Model"]["IRLS"]["tol"]),
            align_method=config["Model"]["align_method"],
        )
    elif config["Model"]["align_lib"] == "torch":  # torch
        s, R, t = robust_weighted_estimate_sim3_torch(
            all_pts2,
            all_pts1,
            all_weights,
            delta=config["Model"]["IRLS"]["delta"],
            max_iters=config["Model"]["IRLS"]["max_iters"],
            tol=eval(config["Model"]["IRLS"]["tol"]),
            align_method=config["Model"]["align_method"],
        )
    elif config["Model"]["align_lib"] == "triton":  # triton
        s, R, t = robust_weighted_estimate_sim3_triton(
            all_pts2,
            all_pts1,
            all_weights,
            delta=config["Model"]["IRLS"]["delta"],
            max_iters=config["Model"]["IRLS"]["max_iters"],
            tol=eval(config["Model"]["IRLS"]["tol"]),
            align_method=config["Model"]["align_method"],
        )
    else:
        raise ValueError(f"Unknown align_lib: {config['Model']['align_lib']}")

    if precompute_scale is not None:  # meaning we are using align method 'scale+se3'
        # we need this precompute_scale for loop align
        s = precompute_scale

    mean_error = compute_alignment_error(
        point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t
    )
    print(f"Mean error: {mean_error}")

    return s, R, t
