# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Model loading and state dict conversion utilities.
"""

from typing import Dict, Tuple
import torch

from depth_anything_3.utils.logger import logger


def convert_general_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Rename the keys of a general-model state dict to the current parameter names.

    Each rename is a plain substring replacement applied to every key, one
    rule at a time and in a fixed order, so later rules see the result of
    earlier ones. The input dictionary is left unmodified; values are carried
    over to the returned dictionary unchanged.

    Args:
        state_dict: Original state dictionary

    Returns:
        Converted state dictionary
    """

    def rename_all(mapping, rules):
        # One full pass over the keys per rule, preserving the rule order.
        for old, new in rules:
            mapping = {name.replace(old, new): value for name, value in mapping.items()}
        return mapping

    # Applied before the stale camera-token entry is dropped.
    prefix_rules = (
        ("module.", "model."),
        (".net.", ".backbone."),
    )

    # Applied after the stale camera-token entry is dropped.
    naming_rules = (
        # Camera token naming
        (".camera_token_extra", ".camera_token"),
        # Head naming
        ("model.all_heads.camera_cond_head", "model.cam_enc"),
        ("model.all_heads.camera_head", "model.cam_dec"),
        (".more_mlps.", ".backbone."),
        (".fc_rot.", ".fc_qvec."),
        ("model.all_heads.head", "model.head"),
        # Output naming
        ("output_conv2_additional.sky_mask", "sky_output_conv2"),
        ("_ray.", "_aux."),
        # GS-DPT head naming
        ("gaussian_param_head.", "gs_head."),
    )

    converted = rename_all(state_dict, prefix_rules)

    # Discard the existing camera token (if any) before ".camera_token_extra"
    # is renamed onto the same suffix by the first naming rule.
    converted.pop("model.backbone.pretrained.camera_token", None)

    return rename_all(converted, naming_rules)


def convert_metric_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a metric-model state dict to the current parameter names.

    Every key is first prefixed with "module." and the result is then passed
    through convert_general_state_dict, which performs the actual renaming.

    Args:
        state_dict: Original metric state dictionary

    Returns:
        Converted state dictionary
    """
    prefixed = {}
    for key, value in state_dict.items():
        # Prepend "module." so the general conversion rules apply to these keys.
        prefixed["module." + key] = value

    return convert_general_state_dict(prefixed)


def load_pretrained_weights(model, model_path: str, is_metric: bool = False) -> Tuple[list, list]:
    """
    Load a checkpoint from disk into a single model.

    The checkpoint is read onto the CPU, its keys are converted with
    convert_metric_state_dict (when is_metric is true) or
    convert_general_state_dict (otherwise), and the result is loaded with
    strict=False. Missing and unexpected keys are logged and returned instead
    of raising.

    Args:
        model: Model instance to load weights into
        model_path: Path to the pretrained weights
        is_metric: Whether this is a metric model

    Returns:
        Tuple of (missed_keys, unexpected_keys)
    """
    # NOTE(review): torch.load unpickles the file, so model_path should only
    # point at trusted checkpoints; consider weights_only=True if these files
    # contain plain tensors only -- TODO confirm.
    checkpoint = torch.load(model_path, map_location="cpu")

    # Pick the key converter that matches the checkpoint flavour.
    convert = convert_metric_state_dict if is_metric else convert_general_state_dict
    converted = convert(checkpoint)

    # strict=False: key mismatches are reported below rather than raised.
    load_result = model.load_state_dict(converted, strict=False)
    missed, unexpected = load_result

    logger.info("Missed keys:", missed)
    logger.info("Unexpected keys:", unexpected)

    return missed, unexpected


def load_pretrained_nested_weights(
    model, main_model_path: str, metric_model_path: str
) -> Tuple[list, list]:
    """
    Load pretrained weights for a nested model with both main and metric branches.

    Both checkpoints are read onto the CPU and converted to the current key
    names. In the main checkpoint every "model." substring is then rewritten
    to "model.da3.", and in the metric checkpoint to "model.da3_metric.". The
    two dictionaries are merged (on a key collision the metric entry wins)
    and loaded with strict=False; missing and unexpected keys are logged and
    returned instead of raising.

    Args:
        model: Nested model instance
        main_model_path: Path to main model weights
        metric_model_path: Path to metric model weights

    Returns:
        Tuple of (missed_keys, unexpected_keys)
    """
    # NOTE(review): torch.load unpickles the files, so both paths should only
    # point at trusted checkpoints; consider weights_only=True if these files
    # contain plain tensors only -- TODO confirm.

    # Load main model weights
    state_dict0 = torch.load(main_model_path, map_location="cpu")
    state_dict0 = convert_general_state_dict(state_dict0)
    state_dict0 = {k.replace("model.", "model.da3."): v for k, v in state_dict0.items()}

    # Load metric model weights
    state_dict1 = torch.load(metric_model_path, map_location="cpu")
    state_dict1 = convert_metric_state_dict(state_dict1)
    state_dict1 = {k.replace("model.", "model.da3_metric."): v for k, v in state_dict1.items()}

    # Combine state dictionaries; metric entries override main entries on collision
    combined_state_dict = {**state_dict0, **state_dict1}

    # strict=False: key mismatches are reported below rather than raised.
    missed, unexpected = model.load_state_dict(combined_state_dict, strict=False)

    # Report through the module logger (same call form as load_pretrained_weights)
    # instead of printing directly to stdout.
    logger.info("Missed keys:", missed)
    logger.info("Unexpected keys:", unexpected)

    return missed, unexpected
