# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Model inference module for Depth Anything 3 Gradio app.

This module handles all model-related operations including inference,
data processing, and result preparation.
"""

import glob
import os
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch

from depth_anything_3.api import DepthAnything3
from depth_anything_3.utils.memory import cleanup_cuda_memory
from depth_anything_3.utils.export.glb import export_to_glb
from depth_anything_3.utils.export.gs import export_to_gs_video


class ModelInference:
    """
    Handles model inference and data processing for Depth Anything 3.
    """

    def __init__(self):
        """Initialize the model inference handler."""
        self.model = None

    def initialize_model(self, device: str = "cuda") -> None:
        """
        Initialize the DepthAnything3 model.

        Args:
            device: Device to load the model on
        """
        if self.model is None:
            # Get model directory from environment variable or use default
            model_dir = os.environ.get(
                "DA3_MODEL_DIR", "/dev/shm/da3_models/DA3HF-VITG-METRIC_VITL"
            )
            self.model = DepthAnything3.from_pretrained(model_dir)
            self.model = self.model.to(device)
        else:
            self.model = self.model.to(device)

        self.model.eval()

    def run_inference(
        self,
        target_dir: str,
        filter_black_bg: bool = False,
        filter_white_bg: bool = False,
        process_res_method: str = "upper_bound_resize",
        show_camera: bool = True,
        save_percentage: float = 30.0,
        num_max_points: int = 1_000_000,
        infer_gs: bool = False,
        ref_view_strategy: str = "saddle_balanced",
        gs_trj_mode: str = "extend",
        gs_video_quality: str = "high",
    ) -> Tuple[Any, Dict[int, Dict[str, Any]]]:
        """
        Run DepthAnything3 model inference on images.

        Args:
            target_dir: Directory containing images
            filter_black_bg: Whether to filter black background
            filter_white_bg: Whether to filter white background
            process_res_method: Method for resizing input images
            show_camera: Whether to show camera in 3D view
            save_percentage: Percentage of points to save (0-100)
            num_max_points: Maximum number of points in point cloud
            infer_gs: Whether to infer 3D Gaussian Splatting
            ref_view_strategy: Reference view selection strategy
            gs_trj_mode: Trajectory mode for 3DGS
            gs_video_quality: Video quality for 3DGS

        Returns:
            Tuple of (prediction, processed_data)
        """
        print(f"Processing images from {target_dir}")

        # Device check
        device = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device)

        # Initialize model if needed
        self.initialize_model(device)

        # Get image paths
        print("Loading images...")
        image_folder_path = os.path.join(target_dir, "images")
        all_image_paths = sorted(glob.glob(os.path.join(image_folder_path, "*")))

        # Filter for image files
        image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
        all_image_paths = [
            path
            for path in all_image_paths
            if any(path.lower().endswith(ext) for ext in image_extensions)
        ]

        print(f"Found {len(all_image_paths)} images")
        print(f"All image paths: {all_image_paths}")

        # Use sorted image order (reference view will be selected automatically)
        image_paths = all_image_paths
        print(f"Reference view selection strategy: {ref_view_strategy}")

        if len(image_paths) == 0:
            raise ValueError("No images found. Check your upload.")

        # Map UI options to actual method names
        method_mapping = {"high_res": "lower_bound_resize", "low_res": "upper_bound_resize"}
        actual_method = method_mapping.get(process_res_method, "upper_bound_crop")

        # Run model inference
        print(f"Running inference with method: {actual_method}")
        with torch.no_grad():
            prediction = self.model.inference(
                image_paths,
                export_dir=None,
                process_res_method=actual_method,
                infer_gs=infer_gs,
                ref_view_strategy=ref_view_strategy,
            )
        # num_max_points: int = 1_000_000,
        export_to_glb(
            prediction,
            filter_black_bg=filter_black_bg,
            filter_white_bg=filter_white_bg,
            export_dir=target_dir,
            show_cameras=show_camera,
            conf_thresh_percentile=save_percentage,
            num_max_points=int(num_max_points),
        )

        # export to gs video if needed
        if infer_gs:
            mode_mapping = {"extend": "extend", "smooth": "interpolate_smooth"}
            print(f"GS mode: {gs_trj_mode}; Backend mode: {mode_mapping[gs_trj_mode]}")
            export_to_gs_video(
                prediction,
                export_dir=target_dir,
                chunk_size=4,
                trj_mode=mode_mapping.get(gs_trj_mode, "extend"),
                enable_tqdm=True,
                vis_depth="hcat",
                video_quality=gs_video_quality,
            )

        # Save predictions.npz for caching metric depth data
        self._save_predictions_cache(target_dir, prediction)

        # Process results
        processed_data = self._process_results(target_dir, prediction, image_paths)

        # Clean up using centralized memory utilities for consistency with backend
        cleanup_cuda_memory()

        return prediction, processed_data

    def _save_predictions_cache(self, target_dir: str, prediction: Any) -> None:
        """
        Save predictions data to predictions.npz for caching.

        Args:
            target_dir: Directory to save the cache
            prediction: Model prediction object
        """
        try:
            output_file = os.path.join(target_dir, "predictions.npz")

            # Build save dict with prediction data
            save_dict = {}

            # Save processed images if available
            if prediction.processed_images is not None:
                save_dict["images"] = prediction.processed_images

            # Save depth data
            if prediction.depth is not None:
                save_dict["depths"] = np.round(prediction.depth, 6)

            # Save confidence if available
            if prediction.conf is not None:
                save_dict["conf"] = np.round(prediction.conf, 2)

            # Save camera parameters
            if prediction.extrinsics is not None:
                save_dict["extrinsics"] = prediction.extrinsics
            if prediction.intrinsics is not None:
                save_dict["intrinsics"] = prediction.intrinsics

            # Save to file
            np.savez_compressed(output_file, **save_dict)
            print(f"Saved predictions cache to: {output_file}")

        except Exception as e:
            print(f"Warning: Failed to save predictions cache: {e}")

    def _process_results(
        self, target_dir: str, prediction: Any, image_paths: list
    ) -> Dict[int, Dict[str, Any]]:
        """
        Process model results into structured data.

        Args:
            target_dir: Directory containing results
            prediction: Model prediction object
            image_paths: List of input image paths

        Returns:
            Dictionary containing processed data for each view
        """
        processed_data = {}

        # Read generated depth visualization files
        depth_vis_dir = os.path.join(target_dir, "depth_vis")

        if os.path.exists(depth_vis_dir):
            depth_files = sorted(glob.glob(os.path.join(depth_vis_dir, "*.jpg")))
            for i, depth_file in enumerate(depth_files):
                # Use processed images directly from API
                processed_image = None
                if prediction.processed_images is not None and i < len(
                    prediction.processed_images
                ):
                    processed_image = prediction.processed_images[i]

                processed_data[i] = {
                    "depth_image": depth_file,
                    "image": processed_image,
                    "original_image_path": image_paths[i] if i < len(image_paths) else None,
                    "depth": prediction.depth[i] if i < len(prediction.depth) else None,
                    "intrinsics": (
                        prediction.intrinsics[i]
                        if prediction.intrinsics is not None and i < len(prediction.intrinsics)
                        else None
                    ),
                    "mask": None,  # No mask information available
                }

        return processed_data

    # cleanup() removed: call cleanup_cuda_memory() directly where needed.
