# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import logging
from enum import Enum
from typing import Sequence, Union

from dinov3.eval.dense.depth.models.embed import CenterPadding, StretchToMultiple
from torch import Tensor, nn

# Fixed, package-wide logger name (not derived from __name__), so records from this module
# go through the shared "dinov3" logger.
logger = logging.getLogger(name="dinov3")


class BackboneLayersSet(Enum):
    """Named presets selecting which intermediate backbone blocks to extract."""

    # Set of intermediate layers to take from the backbone
    LAST = "LAST"  # extracting only the last layer
    FOUR_LAST = "FOUR_LAST"  # extracting the last 4 layers
    FOUR_EVEN_INTERVALS = "FOUR_EVEN_INTERVALS"  # extracting outputs every 1/4 of the total number of blocks


def _get_backbone_out_indices(
    model: nn.Module,
    backbone_out_layers: Union[list[int], tuple[int, ...], str, BackboneLayersSet] = BackboneLayersSet.FOUR_EVEN_INTERVALS,
) -> list[int]:
    """
    Get indices for output layers of the ViT backbone.

    `backbone_out_layers` is either an explicit list/tuple of block indices, or a preset given as a
    `BackboneLayersSet` member or its string value (e.g. "LAST"). There are 3 presets:
    BackboneLayersSet.LAST : only extract the last layer, used in segmentation tasks with a bn head.
    BackboneLayersSet.FOUR_LAST : extract the last 4 layers, used in segmentation (multiscale setting)
    BackboneLayersSet.FOUR_EVEN_INTERVALS : extract outputs every 1/4 of the total number of blocks

    Reference outputs in 'FOUR_EVEN_INTERVALS' mode :
    ViT/S (12 blocks): [2, 5, 8, 11]
    ViT/B (12 blocks): [2, 5, 8, 11]
    ViT/L (24 blocks): [4, 11, 17, 23] -- deliberately kept for backward-compatibility;
        the evenly spaced indices would be [5, 11, 17, 23]
    ViT/g (40 blocks): [9, 19, 29, 39]

    The number of blocks is read from `model.n_blocks` and defaults to 1 when the attribute is absent.

    Returns:
        A new list of block indices (an explicit input list is copied, never returned as-is).

    Raises:
        ValueError: if `backbone_out_layers` is not a list/tuple and not a known preset, or if any
            resulting index is >= the number of blocks.
    """
    n_blocks = getattr(model, "n_blocks", 1)

    if isinstance(backbone_out_layers, (list, tuple)):
        # Explicit indices: copy so the returned list does not alias the caller's object.
        out_indices = list(backbone_out_layers)
    else:
        # Accept both enum members and their string values; unknown values raise ValueError here
        # instead of leaving `out_indices` unbound further down.
        preset = BackboneLayersSet(backbone_out_layers)
        if preset is BackboneLayersSet.LAST:
            out_indices = [n_blocks - 1]
        elif preset is BackboneLayersSet.FOUR_LAST:
            out_indices = list(range(n_blocks - 4, n_blocks))
        elif preset is BackboneLayersSet.FOUR_EVEN_INTERVALS:
            # XXX: Force (incorrect) out indices for backward-compatibility (ViT/L only)
            if n_blocks == 24:
                out_indices = [4, 11, 17, 23]
            else:
                out_indices = [i * (n_blocks // 4) - 1 for i in range(1, 5)]
        else:  # pragma: no cover - only reachable if a new preset is added without handling it
            raise ValueError(f"Unhandled backbone layers preset {preset=}")

    # Explicit check rather than `assert`, so validation is not stripped under `python -O`.
    too_large = [out_index for out_index in out_indices if out_index >= n_blocks]
    if too_large:
        raise ValueError(f"Backbone out indices {too_large} are out of range for a backbone with {n_blocks=}")
    return out_indices


class PatchSizeAdaptationStrategy(Enum):
    """How input images are adapted to the backbone's pad/patch size before the forward pass."""

    CENTER_PADDING = "center_padding"  # wrapped with CenterPadding(input_pad_size)
    STRETCH = "stretch"  # wrapped with StretchToMultiple(input_pad_size)
    NO_ADAPTATION = "never"  # nn.Identity(): inputs are passed through unchanged


class DinoVisionTransformerWrapper(nn.Module):
    """Frozen ViT backbone that returns features from a selected set of intermediate blocks.

    Inputs are first adapted to the backbone's pad/patch size (see `PatchSizeAdaptationStrategy`),
    then passed to `backbone.get_intermediate_layers`.
    """

    def __init__(
        self,
        backbone_model: nn.Module,
        backbone_out_layers: Union[str, list[int], tuple[int, ...]],
        use_backbone_norm: bool = False,
        adapt_to_patch_size: Union[PatchSizeAdaptationStrategy, str] = PatchSizeAdaptationStrategy.CENTER_PADDING,
    ):
        """
        Args:
            backbone_model: the ViT backbone. This constructor reads `embed_dims` (falling back to
                `embed_dim` and `n_blocks`) and `input_pad_size` (falling back to `patch_size`);
                `forward` calls its `get_intermediate_layers`.
            backbone_out_layers: explicit block indices (list or tuple), or a `BackboneLayersSet`
                member or its string value (e.g. "FOUR_EVEN_INTERVALS").
            use_backbone_norm: forwarded as `norm` to `get_intermediate_layers`.
            adapt_to_patch_size: a `PatchSizeAdaptationStrategy` member or its string value
                ("center_padding", "stretch" or "never").

        Raises:
            ValueError: if `backbone_out_layers` is not a known preset, or `adapt_to_patch_size`
                is not a known strategy.
        """
        super().__init__()

        self.final_norm = use_backbone_norm
        self.backbone = backbone_model

        # Explicit indices are passed through as a list; anything else is resolved as a preset.
        if isinstance(backbone_out_layers, (list, tuple)):
            out_layers = list(backbone_out_layers)
        else:
            out_layers = BackboneLayersSet(backbone_out_layers)
        self.backbone_out_indices = _get_backbone_out_indices(self.backbone, backbone_out_layers=out_layers)

        # If the backbone does not define embed_dims, use [embed_dim] * n_blocks
        try:
            embed_dims = self.backbone.embed_dims
        except AttributeError:
            embed_dim = self.backbone.embed_dim
            n_blocks = self.backbone.n_blocks
            logger.warning(f"Backbone does not define embed_dims, using {[embed_dim] * n_blocks=} instead")
            embed_dims = [embed_dim] * n_blocks
        # Channel count of each selected output, in the same order as backbone_out_indices.
        self.embed_dims: Sequence[int] = [embed_dims[idx] for idx in self.backbone_out_indices]

        # How to adapt input images to the patch size of the model?
        try:
            input_pad_size = self.backbone.input_pad_size
        except AttributeError:
            patch_size = self.backbone.patch_size
            logger.warning(f"Backbone does not define input_pad_size, using {patch_size=} instead")
            input_pad_size = patch_size

        # Accept the strategy's string value too (e.g. coming from a config file), like
        # backbone_out_layers above; keep the original error message for unknown values.
        try:
            strategy = PatchSizeAdaptationStrategy(adapt_to_patch_size)
        except ValueError as err:
            raise ValueError(f"Unknown value {adapt_to_patch_size=}") from err

        if strategy is PatchSizeAdaptationStrategy.CENTER_PADDING:
            self.patch_size_adapter = CenterPadding(input_pad_size)
        elif strategy is PatchSizeAdaptationStrategy.STRETCH:
            self.patch_size_adapter = StretchToMultiple(input_pad_size)
        elif strategy is PatchSizeAdaptationStrategy.NO_ADAPTATION:
            self.patch_size_adapter = nn.Identity()
        else:
            raise ValueError(f"Unknown value {adapt_to_patch_size=}")

        # Freeze backbone: disables gradients for its parameters.
        # NOTE: this does not switch the backbone to eval mode.
        self.backbone.requires_grad_(False)

    def forward(
        self,
        x: Tensor,  # [B, rgb, H, W]
    ) -> list[tuple[Tensor, Tensor]]:
        """Adapt `x` to the backbone's patch size and return the selected intermediate features.

        Returns the result of `backbone.get_intermediate_layers` for `self.backbone_out_indices`,
        called with `reshape=True` and `return_class_token=True`: one (patch features, class token)
        pair per selected block.
        """
        x = self.patch_size_adapter(x)
        outputs = self.backbone.get_intermediate_layers(
            x,
            n=self.backbone_out_indices,
            reshape=True,
            return_class_token=True,
            norm=self.final_norm,
        )  # List of (patch feats [B, C, h, w], class token [B, C])
        return outputs
