from __future__ import annotations

import dataclasses
import enum
import gc
import itertools
import operator
import sys
import textwrap
import timeit
from collections.abc import Iterable, Sized
from importlib import metadata
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar

# Do not import additional dependencies at top-level
if TYPE_CHECKING:
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib import axes, figure

import pydantic

PYTHON_VERSION = '.'.join(map(str, sys.version_info))
PYDANTIC_VERSION = metadata.version('pydantic')


# New implementation of pydantic.BaseModel.__eq__ to test


class OldImplementationModel(pydantic.BaseModel, frozen=True):
    def __eq__(self, other: Any) -> bool:
        if isinstance(other, pydantic.BaseModel):
            # When comparing instances of generic types for equality, as long as all field values are equal,
            # only require their generic origin types to be equal, rather than exact type equality.
            # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
            self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
            other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

            return (
                self_type == other_type
                and self.__dict__ == other.__dict__
                and self.__pydantic_private__ == other.__pydantic_private__
                and self.__pydantic_extra__ == other.__pydantic_extra__
            )
        else:
            return NotImplemented  # delegate to the other item in the comparison


class DictComprehensionEqModel(pydantic.BaseModel, frozen=True):
    def __eq__(self, other: Any) -> bool:
        if isinstance(other, pydantic.BaseModel):
            # When comparing instances of generic types for equality, as long as all field values are equal,
            # only require their generic origin types to be equal, rather than exact type equality.
            # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
            self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
            other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

            field_names = type(self).model_fields.keys()

            return (
                self_type == other_type
                and ({k: self.__dict__[k] for k in field_names} == {k: other.__dict__[k] for k in field_names})
                and self.__pydantic_private__ == other.__pydantic_private__
                and self.__pydantic_extra__ == other.__pydantic_extra__
            )
        else:
            return NotImplemented  # delegate to the other item in the comparison


class ItemGetterEqModel(pydantic.BaseModel, frozen=True):
    def __eq__(self, other: Any) -> bool:
        if isinstance(other, pydantic.BaseModel):
            # When comparing instances of generic types for equality, as long as all field values are equal,
            # only require their generic origin types to be equal, rather than exact type equality.
            # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
            self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
            other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

            model_fields = type(self).model_fields.keys()
            getter = operator.itemgetter(*model_fields) if model_fields else lambda _: None

            return (
                self_type == other_type
                and getter(self.__dict__) == getter(other.__dict__)
                and self.__pydantic_private__ == other.__pydantic_private__
                and self.__pydantic_extra__ == other.__pydantic_extra__
            )
        else:
            return NotImplemented  # delegate to the other item in the comparison


class ItemGetterEqModelFastPath(pydantic.BaseModel, frozen=True):
    def __eq__(self, other: Any) -> bool:
        if isinstance(other, pydantic.BaseModel):
            # When comparing instances of generic types for equality, as long as all field values are equal,
            # only require their generic origin types to be equal, rather than exact type equality.
            # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
            self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
            other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

            # Perform common checks first
            if not (
                self_type == other_type
                and self.__pydantic_private__ == other.__pydantic_private__
                and self.__pydantic_extra__ == other.__pydantic_extra__
            ):
                return False

            # Fix GH-7444 by comparing only pydantic fields
            # We provide a fast-path for performance: __dict__ comparison is *much* faster
            # See tests/benchmarks/test_basemodel_eq_performances.py and GH-7825 for benchmarks
            if self.__dict__ == other.__dict__:
                # If the check above passes, then pydantic fields are equal, we can return early
                return True
            else:
                # Else, we need to perform a more detailed, costlier comparison
                model_fields = type(self).model_fields.keys()
                getter = operator.itemgetter(*model_fields) if model_fields else lambda _: None
                return getter(self.__dict__) == getter(other.__dict__)
        else:
            return NotImplemented  # delegate to the other item in the comparison


K = TypeVar('K')
V = TypeVar('V')


# We need a sentinel value for missing fields when comparing models
# Models are equals if-and-only-if they miss the same fields, and since None is a legitimate value
# we can't default to None
# We use the single-value enum trick to allow correct typing when using a sentinel
class _SentinelType(enum.Enum):
    SENTINEL = enum.auto()


_SENTINEL = _SentinelType.SENTINEL


@dataclasses.dataclass
class _SafeGetItemProxy(Generic[K, V]):
    """Wrapper redirecting `__getitem__` to `get` and a sentinel value

    This makes is safe to use in `operator.itemgetter` when some keys may be missing
    """

    wrapped: dict[K, V]

    def __getitem__(self, key: K, /) -> V | _SentinelType:
        return self.wrapped.get(key, _SENTINEL)

    def __contains__(self, key: K, /) -> bool:
        return self.wrapped.__contains__(key)


class SafeItemGetterEqModelFastPath(pydantic.BaseModel, frozen=True):
    def __eq__(self, other: Any) -> bool:
        if isinstance(other, pydantic.BaseModel):
            # When comparing instances of generic types for equality, as long as all field values are equal,
            # only require their generic origin types to be equal, rather than exact type equality.
            # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
            self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
            other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

            # Perform common checks first
            if not (
                self_type == other_type
                and self.__pydantic_private__ == other.__pydantic_private__
                and self.__pydantic_extra__ == other.__pydantic_extra__
            ):
                return False

            # Fix GH-7444 by comparing only pydantic fields
            # We provide a fast-path for performance: __dict__ comparison is *much* faster
            # See tests/benchmarks/test_basemodel_eq_performances.py and GH-7825 for benchmarks
            if self.__dict__ == other.__dict__:
                # If the check above passes, then pydantic fields are equal, we can return early
                return True
            else:
                # Else, we need to perform a more detailed, costlier comparison
                model_fields = type(self).model_fields.keys()
                getter = operator.itemgetter(*model_fields) if model_fields else lambda _: None
                return getter(_SafeGetItemProxy(self.__dict__)) == getter(_SafeGetItemProxy(other.__dict__))
        else:
            return NotImplemented  # delegate to the other item in the comparison


class ItemGetterEqModelFastPathFallback(pydantic.BaseModel, frozen=True):
    def __eq__(self, other: Any) -> bool:
        if isinstance(other, pydantic.BaseModel):
            # When comparing instances of generic types for equality, as long as all field values are equal,
            # only require their generic origin types to be equal, rather than exact type equality.
            # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
            self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
            other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

            # Perform common checks first
            if not (
                self_type == other_type
                and self.__pydantic_private__ == other.__pydantic_private__
                and self.__pydantic_extra__ == other.__pydantic_extra__
            ):
                return False

            # Fix GH-7444 by comparing only pydantic fields
            # We provide a fast-path for performance: __dict__ comparison is *much* faster
            # See tests/benchmarks/test_basemodel_eq_performances.py and GH-7825 for benchmarks
            if self.__dict__ == other.__dict__:
                # If the check above passes, then pydantic fields are equal, we can return early
                return True
            else:
                # Else, we need to perform a more detailed, costlier comparison
                model_fields = type(self).model_fields.keys()
                getter = operator.itemgetter(*model_fields) if model_fields else lambda _: None
                try:
                    return getter(self.__dict__) == getter(other.__dict__)
                except KeyError:
                    return getter(_SafeGetItemProxy(self.__dict__)) == getter(_SafeGetItemProxy(other.__dict__))
        else:
            return NotImplemented  # delegate to the other item in the comparison


IMPLEMENTATIONS = {
    # Commented out because it is too slow for benchmark to complete in reasonable time
    # "dict comprehension": DictComprehensionEqModel,
    'itemgetter': ItemGetterEqModel,
    'itemgetter+fastpath': ItemGetterEqModelFastPath,
    # Commented-out because it is too slow to run with run_benchmark_random_unequal
    #'itemgetter+safety+fastpath': SafeItemGetterEqModelFastPath,
    'itemgetter+fastpath+safe-fallback': ItemGetterEqModelFastPathFallback,
}

# Benchmark running & plotting code


def plot_all_benchmark(
    bases: dict[str, type[pydantic.BaseModel]],
    sizes: list[int],
) -> figure.Figure:
    import matplotlib.pyplot as plt

    n_rows, n_cols = len(BENCHMARKS), 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 6, n_rows * 4))

    for row, (name, benchmark) in enumerate(BENCHMARKS.items()):
        for col, mimic_cached_property in enumerate([False, True]):
            plot_benchmark(
                f'{name}, {mimic_cached_property=}',
                benchmark,
                bases=bases,
                sizes=sizes,
                mimic_cached_property=mimic_cached_property,
                ax=axes[row, col],
            )
    for ax in axes.ravel():
        ax.legend()
    fig.suptitle(f'python {PYTHON_VERSION}, pydantic {PYDANTIC_VERSION}')
    return fig


def plot_benchmark(
    title: str,
    benchmark: Callable,
    bases: dict[str, type[pydantic.BaseModel]],
    sizes: list[int],
    mimic_cached_property: bool,
    ax: axes.Axes | None = None,
):
    import matplotlib.pyplot as plt
    import numpy as np

    ax = ax or plt.gca()
    arr_sizes = np.asarray(sizes)

    baseline = benchmark(
        title=f'{title}, baseline',
        base=OldImplementationModel,
        sizes=sizes,
        mimic_cached_property=mimic_cached_property,
    )
    ax.plot(sizes, baseline / baseline, label='baseline')
    for name, base in bases.items():
        times = benchmark(
            title=f'{title}, {name}',
            base=base,
            sizes=sizes,
            mimic_cached_property=mimic_cached_property,
        )
        mask_valid = ~np.isnan(times)
        ax.plot(arr_sizes[mask_valid], times[mask_valid] / baseline[mask_valid], label=name)

    ax.set_title(title)
    ax.set_xlabel('Number of pydantic fields')
    ax.set_ylabel('Average time relative to baseline')
    return ax


class SizedIterable(Sized, Iterable):
    pass


def run_benchmark_nodiff(
    title: str,
    base: type[pydantic.BaseModel],
    sizes: SizedIterable,
    mimic_cached_property: bool,
    n_execution: int = 10_000,
    n_repeat: int = 5,
) -> np.ndarray:
    setup = textwrap.dedent(
        """
        import pydantic

        Model = pydantic.create_model(
            "Model",
            __base__=Base,
            **{f'x{i}': (int, i) for i in range(%(size)d)}
        )
        left = Model()
        right = Model()
        """
    )
    if mimic_cached_property:
        # Mimic functools.cached_property editing __dict__
        # NOTE: we must edit both objects, otherwise the dict don't have the same size and
        # dict.__eq__ has a very fast path. This makes our timing comparison incorrect
        # However, the value must be different, otherwise *our* __dict__ == right.__dict__
        # fast-path prevents our correct code from running
        setup += textwrap.dedent(
            """
            object.__setattr__(left, 'cache', None)
            object.__setattr__(right, 'cache', -1)
            """
        )
    statement = 'left == right'
    namespace = {'Base': base}
    return run_benchmark(
        title,
        setup=setup,
        statement=statement,
        n_execution=n_execution,
        n_repeat=n_repeat,
        globals=namespace,
        params={'size': sizes},
    )


def run_benchmark_first_diff(
    title: str,
    base: type[pydantic.BaseModel],
    sizes: SizedIterable,
    mimic_cached_property: bool,
    n_execution: int = 10_000,
    n_repeat: int = 5,
) -> np.ndarray:
    setup = textwrap.dedent(
        """
        import pydantic

        Model = pydantic.create_model(
            "Model",
            __base__=Base,
            **{f'x{i}': (int, i) for i in range(%(size)d)}
        )
        left = Model()
        right = Model(f0=-1) if %(size)d > 0 else Model()
        """
    )
    if mimic_cached_property:
        # Mimic functools.cached_property editing __dict__
        # NOTE: we must edit both objects, otherwise the dict don't have the same size and
        # dict.__eq__ has a very fast path. This makes our timing comparison incorrect
        # However, the value must be different, otherwise *our* __dict__ == right.__dict__
        # fast-path prevents our correct code from running
        setup += textwrap.dedent(
            """
            object.__setattr__(left, 'cache', None)
            object.__setattr__(right, 'cache', -1)
            """
        )
    statement = 'left == right'
    namespace = {'Base': base}
    return run_benchmark(
        title,
        setup=setup,
        statement=statement,
        n_execution=n_execution,
        n_repeat=n_repeat,
        globals=namespace,
        params={'size': sizes},
    )


def run_benchmark_last_diff(
    title: str,
    base: type[pydantic.BaseModel],
    sizes: SizedIterable,
    mimic_cached_property: bool,
    n_execution: int = 10_000,
    n_repeat: int = 5,
) -> np.ndarray:
    setup = textwrap.dedent(
        """
        import pydantic

        Model = pydantic.create_model(
            "Model",
            __base__=Base,
            # shift the range() so that there is a field named size
            **{f'x{i}': (int, i) for i in range(1, %(size)d + 1)}
        )
        left = Model()
        right = Model(f%(size)d=-1) if %(size)d > 0 else Model()
        """
    )
    if mimic_cached_property:
        # Mimic functools.cached_property editing __dict__
        # NOTE: we must edit both objects, otherwise the dict don't have the same size and
        # dict.__eq__ has a very fast path. This makes our timing comparison incorrect
        # However, the value must be different, otherwise *our* __dict__ == right.__dict__
        # fast-path prevents our correct code from running
        setup += textwrap.dedent(
            """
            object.__setattr__(left, 'cache', None)
            object.__setattr__(right, 'cache', -1)
            """
        )
    statement = 'left == right'
    namespace = {'Base': base}
    return run_benchmark(
        title,
        setup=setup,
        statement=statement,
        n_execution=n_execution,
        n_repeat=n_repeat,
        globals=namespace,
        params={'size': sizes},
    )


def run_benchmark_random_unequal(
    title: str,
    base: type[pydantic.BaseModel],
    sizes: SizedIterable,
    mimic_cached_property: bool,
    n_samples: int = 100,
    n_execution: int = 1_000,
    n_repeat: int = 5,
) -> np.ndarray:
    import numpy as np

    setup = textwrap.dedent(
        """
        import pydantic

        Model = pydantic.create_model(
            "Model",
            __base__=Base,
            **{f'x{i}': (int, i) for i in range(%(size)d)}
        )
        left = Model()
        right = Model(f%(field)d=-1)
        """
    )
    if mimic_cached_property:
        # Mimic functools.cached_property editing __dict__
        # NOTE: we must edit both objects, otherwise the dict don't have the same size and
        # dict.__eq__ has a very fast path. This makes our timing comparison incorrect
        # However, the value must be different, otherwise *our* __dict__ == right.__dict__
        # fast-path prevents our correct code from running
        setup += textwrap.dedent(
            """
            object.__setattr__(left, 'cache', None)
            object.__setattr__(right, 'cache', -1)
            """
        )
    statement = 'left == right'
    namespace = {'Base': base}
    arr_sizes = np.fromiter(sizes, dtype=int)
    mask_valid_sizes = arr_sizes > 0
    arr_valid_sizes = arr_sizes[mask_valid_sizes]  # we can't support 0 when sampling the field
    rng = np.random.default_rng()
    arr_fields = rng.integers(arr_valid_sizes, size=(n_samples, *arr_valid_sizes.shape))
    # broadcast the sizes against their sample, so we can iterate on (size, field) tuple
    # as parameters of the timing test
    arr_size_broadcast, _ = np.meshgrid(arr_valid_sizes, arr_fields[:, 0])

    results = run_benchmark(
        title,
        setup=setup,
        statement=statement,
        n_execution=n_execution,
        n_repeat=n_repeat,
        globals=namespace,
        params={'size': arr_size_broadcast.ravel(), 'field': arr_fields.ravel()},
    )
    times = np.empty(arr_sizes.shape, dtype=float)
    times[~mask_valid_sizes] = np.nan
    times[mask_valid_sizes] = results.reshape((n_samples, *arr_valid_sizes.shape)).mean(axis=0)
    return times


BENCHMARKS = {
    'All field equals': run_benchmark_nodiff,
    'First field unequal': run_benchmark_first_diff,
    'Last field unequal': run_benchmark_last_diff,
    'Random unequal field': run_benchmark_random_unequal,
}


def run_benchmark(
    title: str,
    setup: str,
    statement: str,
    n_execution: int = 10_000,
    n_repeat: int = 5,
    globals: dict[str, Any] | None = None,
    progress_bar: bool = True,
    params: dict[str, SizedIterable] | None = None,
) -> np.ndarray:
    import numpy as np
    import tqdm.auto as tqdm

    namespace = globals or {}
    # fast-path
    if not params:
        length = 1
        packed_params = [()]
    else:
        length = len(next(iter(params.values())))
        # This iterator yields a tuple of (key, value) pairs
        # First, make a list of N iterator over (key, value), where the provided values are iterated
        param_pairs = [zip(itertools.repeat(name), value) for name, value in params.items()]
        # Then pack our individual parameter iterator into one
        packed_params = zip(*param_pairs)

    times = [
        # Take the min of the repeats as recommended by timeit doc
        min(
            timeit.Timer(
                setup=setup % dict(local_params),
                stmt=statement,
                globals=namespace,
            ).repeat(repeat=n_repeat, number=n_execution)
        )
        / n_execution
        for local_params in tqdm.tqdm(packed_params, desc=title, total=length, disable=not progress_bar)
    ]
    gc.collect()

    return np.asarray(times, dtype=float)


if __name__ == '__main__':
    # run with `uv run tests/benchmarks/test_basemodel_eq_performance.py`
    import argparse
    import pathlib

    try:
        import matplotlib  # noqa: F401
        import numpy  # noqa: F401
        import tqdm  # noqa: F401
    except ImportError as err:
        raise ImportError(
            'This benchmark additionally depends on numpy, matplotlib and tqdm. '
            'Install those in your environment to run the benchmark.'
        ) from err

    parser = argparse.ArgumentParser(
        description='Test the performance of various BaseModel.__eq__ implementations fixing GH-7444.'
    )
    parser.add_argument(
        '-o',
        '--output-path',
        default=None,
        type=pathlib.Path,
        help=(
            'Output directory or file in which to save the benchmark results. '
            'If a directory is passed, a default filename is used.'
        ),
    )
    parser.add_argument(
        '--min-n-fields',
        type=int,
        default=0,
        help=('Test the performance of BaseModel.__eq__ on models with at least this number of fields. Defaults to 0.'),
    )
    parser.add_argument(
        '--max-n-fields',
        type=int,
        default=100,
        help=('Test the performance of BaseModel.__eq__ on models with up to this number of fields. Defaults to 100.'),
    )

    args = parser.parse_args()

    import matplotlib.pyplot as plt

    sizes = list(range(args.min_n_fields, args.max_n_fields))
    fig = plot_all_benchmark(IMPLEMENTATIONS, sizes=sizes)
    plt.tight_layout()

    if args.output_path is None:
        plt.show()
    else:
        if args.output_path.suffix:
            filepath = args.output_path
        else:
            filepath = args.output_path / f'eq-benchmark_python-{PYTHON_VERSION.replace(".", "-")}.png'
        fig.savefig(
            filepath,
            dpi=200,
            facecolor='white',
            transparent=False,
        )
        print(f'wrote {filepath!s}', file=sys.stderr)
