# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

# ------------------------------------------------------------------------
# Plain-DETR
# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia.
# Licensed under The MIT License [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
Utilities for bounding box manipulation and GIoU.
"""
import numpy as np
import torch


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


def delta2bbox(
    proposals, deltas, max_shape=None, wh_ratio_clip=16 / 1000, clip_border=True, add_ctr_clamp=False, ctr_clamp=32
):
    dxy = deltas[..., :2]
    dwh = deltas[..., 2:]

    # Compute width/height of each roi
    pxy = proposals[..., :2]
    pwh = proposals[..., 2:]

    dxy_wh = pwh * dxy

    max_ratio = np.abs(np.log(wh_ratio_clip))
    if add_ctr_clamp:
        dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
        dwh = torch.clamp(dwh, max=max_ratio)
    else:
        dwh = dwh.clamp(min=-max_ratio, max=max_ratio)

    gxy = pxy + dxy_wh
    gwh = pwh * dwh.exp()
    x1y1 = gxy - (gwh * 0.5)
    x2y2 = gxy + (gwh * 0.5)
    bboxes = torch.cat([x1y1, x2y2], dim=-1)
    if clip_border and max_shape is not None:
        bboxes[..., 0::2].clamp_(min=0).clamp_(max=max_shape[1])
        bboxes[..., 1::2].clamp_(min=0).clamp_(max=max_shape[0])
    return bboxes


def bbox2delta(proposals, gt, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
    # hack for matcher
    if proposals.size() != gt.size():
        proposals = proposals[:, None]
        gt = gt[None]

    proposals = proposals.float()
    gt = gt.float()
    px, py, pw, ph = proposals.unbind(-1)
    gx, gy, gw, gh = gt.unbind(-1)

    dx = (gx - px) / (pw + 0.1)
    dy = (gy - py) / (ph + 0.1)
    dw = torch.log(gw / (pw + 0.1))
    dh = torch.log(gh / (ph + 0.1))
    deltas = torch.stack([dx, dy, dw, dh], dim=-1)

    # avoid unnecessary sync point if not needed
    if means != (0.0, 0.0, 0.0, 0.0) or stds != (1.0, 1.0, 1.0, 1.0):
        means = deltas.new_tensor(means).unsqueeze(0)
        stds = deltas.new_tensor(stds).unsqueeze(0)
        deltas = deltas.sub_(means).div_(stds)

    return deltas
