# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
from typing import Optional

import torch
from peft import LoraConfig, set_peft_model_state_dict

from cosmos_predict2._src.imaginaire.config import Config
from cosmos_predict2._src.imaginaire.flags import INTERNAL, SMOKE
from cosmos_predict2._src.imaginaire.lazy_config import instantiate
from cosmos_predict2._src.imaginaire.model import ImaginaireModel
from cosmos_predict2._src.imaginaire.utils import distributed, log, misc
from cosmos_predict2._src.imaginaire.utils.config_helper import get_config_module, override
from cosmos_predict2._src.imaginaire.utils.easy_io import easy_io
from cosmos_predict2._src.imaginaire.utils.fsdp_helper import hsdp_device_mesh
from cosmos_predict2._src.predict2.checkpointer.dcp import (
    DefaultLoadPlanner,
    DistributedCheckpointer,
    ModelWrapper,
    dcp_load_state_dict,
)


def load_model_from_checkpoint(
    experiment_name,
    s3_checkpoint_dir,
    config_file="cosmos_predict2/_src/predict2/configs/video2world/config.py",
    enable_fsdp=False,
    load_ema_to_reg=False,
    instantiate_ema=True,
    seed=0,
    local_cache_dir=None,
    override_cache: bool = False,
    experiment_opts: Optional[list[str]] = None,
    skip_load_model: bool = False,
    adapter_checkpoint_paths: Optional[list[str]] = None,
    cache_text_encoder: bool = False,
    to_device: Optional[str] = None,
):
    """
    Load model from checkpoint with optional multi-adapter support.

    Args:
        experiment_name: experiment name
        s3_checkpoint_dir: s3 path to iteration_model
        config_file: config file path
        enable_fsdp: enable fsdp
        load_ema_to_reg: load ema as regular model
        instantiate_ema: whether to instantiate EMA
        seed: random seed
        local_cache_dir: local cache directory, if None, do not cache
        override_cache: override cache, if True, override cache if local cache exists
        experiment_opts: experiment options
        skip_load_model: skip loading model weights
        adapter_checkpoint_paths: list of checkpoint paths for loading multiple adapters
            Supports both .pt and DCP checkpoint formats (auto-detected by file extension).
            Example:
                adapter_checkpoint_paths=[
                    "s3://bucket/exp1/checkpoints/model.pt",  # .pt format
                    "s3://bucket/exp2/checkpoints"  # DCP format
                ]
        cache_text_encoder: cache text encoder, if True, cache text encoder. This is default to False to avoid race condition if multiple nodes are running inference concurrently (e.g., running inference pipeline).

    Returns:
        model: loaded model
        config: config object
    """
    if experiment_opts is None:
        experiment_opts = []
    config_module = get_config_module(config_file)
    config = importlib.import_module(config_module).make_config()
    config = override(config, ["--", f"experiment={experiment_name}"] + experiment_opts)

    # Override checkpoint path if provided
    if s3_checkpoint_dir:
        log.info(f"Overriding config checkpoint path with: {s3_checkpoint_dir}")
        config.checkpoint.load_path = str(s3_checkpoint_dir)

    if load_ema_to_reg:
        config.model.config.ema.enabled = False

    if instantiate_ema is False and config.model.config.ema.enabled:
        config.model.config.ema.enabled = False

    # Check that the config is valid
    config.validate()
    # Freeze the config so developers don't change it during training.
    config.freeze()  # type: ignore
    misc.set_random_seed(seed=seed, by_rank=True)
    # Initialize cuDNN.
    torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
    torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
    # Floating-point precision settings.
    torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True

    log.info(f"Loading model from {s3_checkpoint_dir}")

    if not enable_fsdp:
        # disable fsdp
        config.model.config.fsdp_shard_size = 1
    with misc.timer("instantiate model"):
        model = instantiate(config.model)
        if to_device is not None:
            model.to(torch.device(to_device))
        # Convert the model parameters to bf16
        model.on_train_start()

    if not skip_load_model:
        # Handle different adapter loading scenarios
        if adapter_checkpoint_paths:
            # First load base model
            model = load_model_state_dict_from_checkpoint(
                model, config, s3_checkpoint_dir, load_ema_to_reg, local_cache_dir, override_cache
            )
            # Then load additional adapters from different checkpoints
            log.info(f"Loading {len(adapter_checkpoint_paths)} adapters from different checkpoints")
            adapter_names = [f"adapter_{i}" for i in range(len(adapter_checkpoint_paths))]

            for adapter_name, checkpoint_path in zip(adapter_names, adapter_checkpoint_paths):
                log.info(f"Loading adapter '{adapter_name}' from {checkpoint_path}")
                lora_config = LoraConfig(
                    r=model.config.lora_rank,
                    lora_alpha=model.config.lora_alpha,
                    init_lora_weights=model.config.init_lora_weights,
                    target_modules=[module.strip() for module in model.config.lora_target_modules.split(",")],
                    use_dora=model.config.use_dora,
                )
                model.net.add_adapter(adapter_name, lora_config)

                if checkpoint_path.endswith(".pt"):
                    # adapter_state_dict = easy_io.load(checkpoint_path)
                    adapter_state_dict = torch.load(checkpoint_path, map_location="cpu")
                    old_keys = list(adapter_state_dict.keys())
                    for key in old_keys:
                        if "lora_" in key:
                            net_prefix = "net." if load_ema_to_reg else "net_ema."
                            new_key = key.replace(net_prefix, "base_model.model.").replace("default.", "")
                            adapter_state_dict[new_key] = adapter_state_dict.pop(key)
                    load_result = set_peft_model_state_dict(model.net, adapter_state_dict, adapter_name=adapter_name)
                    # for key in load_result.missing_keys:
                    #     log.warning(f"Missing key: {key}")
                    for key in load_result.unexpected_keys:
                        log.warning(f"Unexpected key: {key}")
                        assert False, "Unexpected key found"
                else:
                    log.info(f"Loading adapter '{adapter_name}' from s3 {checkpoint_path}")
                    if checkpoint_path.rstrip("/").endswith("/model"):
                        cur_key_ckpt_full_path = checkpoint_path
                    else:
                        cur_key_ckpt_full_path = os.path.join(checkpoint_path, "model")

                    checkpointer = DistributedCheckpointer(
                        config.checkpoint, config.job, callbacks=None, disable_async=True
                    )

                    _model_wrapper = ModelWrapper(model, load_ema_to_reg=load_ema_to_reg)
                    mapping_keys = {
                        adapter_name + ".": "default.",
                    }
                    _state_dict = _model_wrapper.state_dict(mapping_keys=mapping_keys)
                    storage_reader = checkpointer.get_storage_reader(cur_key_ckpt_full_path)
                    load_planner = DefaultLoadPlanner(allow_partial_load=True)
                    dcp_load_state_dict(_state_dict, storage_reader, load_planner)
                    _model_wrapper.load_state_dict(_state_dict)

                log.info(f"Loaded adapter '{adapter_name}'")

            # Activate first adapter
            model.net.set_adapter(adapter_names[0])
            log.info(f"Activated adapter: {adapter_names[0]}")
        else:
            # Load normally (single checkpoint)
            model = load_model_state_dict_from_checkpoint(
                model, config, s3_checkpoint_dir, load_ema_to_reg, local_cache_dir, override_cache
            )

    return model, config


def load_model_state_dict_from_checkpoint(
    model,
    config,
    s3_checkpoint_dir,
    load_ema_to_reg=False,
    local_cache_dir=None,
    override_cache: bool = False,
):
    if s3_checkpoint_dir is not None:
        s3_checkpoint_dir = str(s3_checkpoint_dir)
    checkpoint_format = "pt" if s3_checkpoint_dir.endswith(".pt") else "dcp"
    if s3_checkpoint_dir.startswith("s3:"):
        if checkpoint_format == "pt":
            cur_key_ckpt_full_path = s3_checkpoint_dir
        elif s3_checkpoint_dir.rstrip("/").endswith("/model"):
            cur_key_ckpt_full_path = s3_checkpoint_dir
        else:
            cur_key_ckpt_full_path = os.path.join(s3_checkpoint_dir, "model")
    else:
        cur_key_ckpt_full_path = s3_checkpoint_dir

    from cosmos_predict2._src.imaginaire.utils.checkpoint_db import get_checkpoint_path

    load_from_local = True
    local_s3_ckpt_fp = get_checkpoint_path(cur_key_ckpt_full_path)

    if SMOKE:
        return model

    if load_from_local:
        # Load on rank0 only and broadcast
        if distributed.is_rank0():
            log.info(f"Loading model cached locally from {local_s3_ckpt_fp}")
            local_state_dict = easy_io.load(local_s3_ckpt_fp, weights_only=INTERNAL)

            # Handle LoRA key mapping if the model uses LoRA and checkpoint is in .pt format
            if (
                hasattr(model, "config")
                and hasattr(model.config, "use_lora")
                and model.config.use_lora
                and checkpoint_format == "pt"
            ):
                log.info("Model uses LoRA, mapping checkpoint keys to model keys with base_layer...")
                mapped_state_dict = {}
                mapped_keys = []
                missing_keys = []

                # Get current model state dict to understand what keys are expected
                model_state_dict = model.state_dict()

                for model_key in model_state_dict.keys():
                    if "base_layer." in model_key or "base_model.model." in model_key:
                        # This is a LoRA layer - map from checkpoint key (without base_layer)
                        checkpoint_key = model_key.replace("base_layer.", "").replace("base_model.model.", "")
                        if checkpoint_key in local_state_dict:
                            mapped_state_dict[model_key] = local_state_dict[checkpoint_key]
                            mapped_keys.append(f"{checkpoint_key} -> {model_key}")
                        else:
                            missing_keys.append(model_key)
                    elif model_key in local_state_dict:
                        # Direct mapping for non-LoRA keys
                        mapped_state_dict[model_key] = local_state_dict[model_key]
                    else:
                        missing_keys.append(model_key)

                if mapped_keys:
                    log.info(f"Mapped {len(mapped_keys)} LoRA keys from checkpoint to model (showing first 5):")
                    for mapped_key in mapped_keys[:5]:
                        log.info(f"  {mapped_key}")
                if missing_keys:
                    log.warning(f"Missing keys in checkpoint: {missing_keys[:10]}... (showing first 10)")

                local_state_dict = mapped_state_dict

            # `strict=False` is needed to avoid errors: `Skipping key ... introduced by TransformerEngine for FP8 in the checkpoint.`
            load_info = model.load_state_dict(local_state_dict, strict=False)
            log.info(f"Checkpoint weights loaded from {local_s3_ckpt_fp}: {load_info}")
            if load_info is None:
                log.info("Checkpoint weights loaded successfully (strict=True).")
            elif not load_info.missing_keys and not load_info.unexpected_keys:
                log.info("Checkpoint weights loaded successfully (all keys matched).")
            else:
                log.warning("Checkpoint weights loaded; review missing_keys/unexpected_keys above.")

        # Synchronize model states from rank 0 to all other ranks
        # Skip EMA parameters and buffers to avoid OOM - they are on CPU now, and will be moved to CUDA and synced via copy from main model after FSDP
        params_and_buffers_to_ignore = set()
        if hasattr(model, "net_ema") and model.net_ema is not None:
            # Add all parameters
            for param_name, _ in model.net_ema.named_parameters():
                params_and_buffers_to_ignore.add(f"net_ema.{param_name}")
            # Add all buffers (e.g., running_mean, running_var in BatchNorm)
            for buffer_name, _ in model.net_ema.named_buffers():
                params_and_buffers_to_ignore.add(f"net_ema.{buffer_name}")
            log.info(
                f"Skipping sync for {len(params_and_buffers_to_ignore)} EMA parameters and buffers to avoid OOM during initialization"
            )

        distributed.sync_model_states(model, src=0, params_and_buffers_to_ignore=params_and_buffers_to_ignore)
    else:
        log.info(f"Loading model from s3 {s3_checkpoint_dir}")

        checkpointer = DistributedCheckpointer(config.checkpoint, config.job, callbacks=None, disable_async=True)

        _model_wrapper = ModelWrapper(
            model,
            load_ema_to_reg=load_ema_to_reg if checkpoint_format == "dcp" else False,
        )
        _state_dict = _model_wrapper.state_dict()
        if checkpoint_format == "dcp":
            storage_reader = checkpointer.get_storage_reader(cur_key_ckpt_full_path)
            load_planner = DefaultLoadPlanner(allow_partial_load=True)
            dcp_load_state_dict(_state_dict, storage_reader, load_planner)
            _model_wrapper.load_state_dict(_state_dict)
        else:  # pt format - load on rank0 only and broadcast
            if distributed.is_rank0():
                if "s3://" in s3_checkpoint_dir:
                    pt_state_dict = easy_io.load(
                        s3_checkpoint_dir,
                        backend_args={
                            "backend": "s3",
                            "s3_credential_path": "credentials/s3_training.secret",
                        },
                    )
                else:
                    pt_state_dict = easy_io.load(s3_checkpoint_dir)
                # Handle different .pt checkpoint formats
                if "model" in pt_state_dict:
                    # Checkpoint contains multiple components (model, optimizer, etc.)
                    model_state = pt_state_dict["model"]
                elif "state_dict" in pt_state_dict:
                    # Alternative format
                    model_state = pt_state_dict["state_dict"]
                else:
                    # Assume the checkpoint is the state dict itself
                    model_state = pt_state_dict
                # Update the state dict with loaded weights
                # Handle potential key mismatches
                missing_keys = []
                unexpected_keys = []
                for key in _state_dict.keys():
                    if key in model_state:
                        _state_dict[key] = model_state[key]
                    else:
                        missing_keys.append(key)

                for key in model_state.keys():
                    if key not in _state_dict:
                        unexpected_keys.append(key)

                if missing_keys:
                    log.warning(f"Missing keys in checkpoint: {missing_keys[:10]}... (showing first 10)")
                if unexpected_keys:
                    log.warning(f"Unexpected keys in checkpoint: {unexpected_keys[:10]}... (showing first 10)")

                # only load on rank0
                _model_wrapper.load_state_dict(_state_dict)

            # Synchronize model states from rank 0 to all other ranks
            distributed.sync_model_states(model, src=0)

        # Cache the model state dict only on rank0 to be consistent with loading
        if local_cache_dir is not None and distributed.is_rank0():
            log.info(f"Caching model state dict to {local_s3_ckpt_fp}")
            easy_io.dump(model.state_dict(), local_s3_ckpt_fp)

    # Clear unused reserved memory from fp32
    torch.cuda.empty_cache()
    return model


def create_model_from_consolidated_checkpoint_with_fsdp(config: Config) -> ImaginaireModel:
    """
    Instantiate a model, load weights from a consolidated checkpoint, and initialize FSDP if required.

    Args:
        config: The configuration object for the experiment.

    Returns:
        model: The loaded and (optionally) FSDP-wrapped model.
    """
    # To avoid DTensor issues, load the model from a consolidated checkpoint in Tensor format before applying FSDP.
    fsdp_shard_size = config.model.config.fsdp_shard_size
    config.model.config.fsdp_shard_size = 1  # Set to 1 to disable FSDP during model instantiation.
    model = instantiate(config.model).cuda()
    # DCP checkpointer does not support loading from a consolidated checkpoint, so we support it here.
    model = load_model_state_dict_from_checkpoint(
        model=model,
        config=config,
        s3_checkpoint_dir=config.checkpoint.load_path,
        load_ema_to_reg=config.checkpoint.load_ema_to_reg,
    )
    # If FSDP is enabled, apply FSDP to the model.
    if fsdp_shard_size > 1:
        config.model.config.fsdp_shard_size = fsdp_shard_size
        fsdp_device_mesh = hsdp_device_mesh(
            sharding_group_size=fsdp_shard_size,
        )
        if hasattr(model, "apply_fsdp") and callable(model.apply_fsdp):
            model.apply_fsdp(fsdp_device_mesh)
        else:
            raise AttributeError(
                "Model does not implement 'apply_fsdp'. Please implement this method to enable FSDP after consolidated checkpoint loading."
            )

    return model
