from __future__ import annotations

import decimal
import numbers
import random
import sys

import numpy as np

from pandas._typing import type_t

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.common import (
    is_dtype_equal,
    is_float,
    pandas_dtype,
)

import pandas as pd
from pandas.api.extensions import (
    no_default,
    register_extension_dtype,
)
from pandas.api.types import (
    is_list_like,
    is_scalar,
)
from pandas.core import arraylike
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays import (
    ExtensionArray,
    ExtensionScalarOpsMixin,
)
from pandas.core.indexers import check_array_indexer


@register_extension_dtype
class DecimalDtype(ExtensionDtype):
    type = decimal.Decimal
    name = "decimal"
    na_value = decimal.Decimal("NaN")
    _metadata = ("context",)

    def __init__(self, context=None) -> None:
        self.context = context or decimal.getcontext()

    def __repr__(self) -> str:
        return f"DecimalDtype(context={self.context})"

    @classmethod
    def construct_array_type(cls) -> type_t[DecimalArray]:
        """
        Return the array type associated with this dtype.

        Returns
        -------
        type
        """
        return DecimalArray

    @property
    def _is_numeric(self) -> bool:
        return True


class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
    __array_priority__ = 1000

    def __init__(self, values, dtype=None, copy=False, context=None) -> None:
        for i, val in enumerate(values):
            if is_float(val):
                if np.isnan(val):
                    values[i] = DecimalDtype.na_value
                else:
                    values[i] = DecimalDtype.type(val)
            elif not isinstance(val, decimal.Decimal):
                raise TypeError("All values must be of type " + str(decimal.Decimal))
        values = np.asarray(values, dtype=object)

        self._data = values
        # Some aliases for common attribute names to ensure pandas supports
        # these
        self._items = self.data = self._data
        # those aliases are currently not working due to assumptions
        # in internal code (GH-20735)
        # self._values = self.values = self.data
        self._dtype = DecimalDtype(context)

    @property
    def dtype(self):
        return self._dtype

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        return cls(scalars)

    @classmethod
    def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
        return cls._from_sequence([decimal.Decimal(x) for x in strings], dtype, copy)

    @classmethod
    def _from_factorized(cls, values, original):
        return cls(values)

    _HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray)

    def to_numpy(
        self,
        dtype=None,
        copy: bool = False,
        na_value: object = no_default,
        decimals=None,
    ) -> np.ndarray:
        result = np.asarray(self, dtype=dtype)
        if decimals is not None:
            result = np.asarray([round(x, decimals) for x in result])
        return result

    def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
        #
        if not all(
            isinstance(t, self._HANDLED_TYPES + (DecimalArray,)) for t in inputs
        ):
            return NotImplemented

        result = arraylike.maybe_dispatch_ufunc_to_dunder_op(
            self, ufunc, method, *inputs, **kwargs
        )
        if result is not NotImplemented:
            # e.g. test_array_ufunc_series_scalar_other
            return result

        if "out" in kwargs:
            return arraylike.dispatch_ufunc_with_out(
                self, ufunc, method, *inputs, **kwargs
            )

        inputs = tuple(x._data if isinstance(x, DecimalArray) else x for x in inputs)
        result = getattr(ufunc, method)(*inputs, **kwargs)

        if method == "reduce":
            result = arraylike.dispatch_reduction_ufunc(
                self, ufunc, method, *inputs, **kwargs
            )
            if result is not NotImplemented:
                return result

        def reconstruct(x):
            if isinstance(x, (decimal.Decimal, numbers.Number)):
                return x
            else:
                return DecimalArray._from_sequence(x)

        if ufunc.nout > 1:
            return tuple(reconstruct(x) for x in result)
        else:
            return reconstruct(result)

    def __getitem__(self, item):
        if isinstance(item, numbers.Integral):
            return self._data[item]
        else:
            # array, slice.
            item = pd.api.indexers.check_array_indexer(self, item)
            return type(self)(self._data[item])

    def take(self, indexer, allow_fill=False, fill_value=None):
        from pandas.api.extensions import take

        data = self._data
        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value

        result = take(data, indexer, fill_value=fill_value, allow_fill=allow_fill)
        return self._from_sequence(result)

    def copy(self):
        return type(self)(self._data.copy(), dtype=self.dtype)

    def astype(self, dtype, copy=True):
        if is_dtype_equal(dtype, self._dtype):
            if not copy:
                return self
        dtype = pandas_dtype(dtype)
        if isinstance(dtype, type(self.dtype)):
            return type(self)(self._data, copy=copy, context=dtype.context)

        return super().astype(dtype, copy=copy)

    def __setitem__(self, key, value):
        if is_list_like(value):
            if is_scalar(key):
                raise ValueError("setting an array element with a sequence.")
            value = [decimal.Decimal(v) for v in value]
        else:
            value = decimal.Decimal(value)

        key = check_array_indexer(self, key)
        self._data[key] = value

    def __len__(self) -> int:
        return len(self._data)

    def __contains__(self, item) -> bool | np.bool_:
        if not isinstance(item, decimal.Decimal):
            return False
        elif item.is_nan():
            return self.isna().any()
        else:
            return super().__contains__(item)

    @property
    def nbytes(self) -> int:
        n = len(self)
        if n:
            return n * sys.getsizeof(self[0])
        return 0

    def isna(self):
        return np.array([x.is_nan() for x in self._data], dtype=bool)

    @property
    def _na_value(self):
        return decimal.Decimal("NaN")

    def _formatter(self, boxed=False):
        if boxed:
            return "Decimal: {}".format
        return repr

    @classmethod
    def _concat_same_type(cls, to_concat):
        return cls(np.concatenate([x._data for x in to_concat]))

    def _reduce(self, name: str, *, skipna: bool = True, **kwargs):

        if skipna:
            # If we don't have any NAs, we can ignore skipna
            if self.isna().any():
                other = self[~self.isna()]
                return other._reduce(name, **kwargs)

        if name == "sum" and len(self) == 0:
            # GH#29630 avoid returning int 0 or np.bool_(False) on old numpy
            return decimal.Decimal(0)

        try:
            op = getattr(self.data, name)
        except AttributeError as err:
            raise NotImplementedError(
                f"decimal does not support the {name} operation"
            ) from err
        return op(axis=0)

    def _cmp_method(self, other, op):
        # For use with OpsMixin
        def convert_values(param):
            if isinstance(param, ExtensionArray) or is_list_like(param):
                ovalues = param
            else:
                # Assume it's an object
                ovalues = [param] * len(self)
            return ovalues

        lvalues = self
        rvalues = convert_values(other)

        # If the operator is not defined for the underlying objects,
        # a TypeError should be raised
        res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]

        return np.asarray(res, dtype=bool)

    def value_counts(self, dropna: bool = True):
        from pandas.core.algorithms import value_counts

        return value_counts(self.to_numpy(), dropna=dropna)


def to_decimal(values, context=None):
    return DecimalArray([decimal.Decimal(x) for x in values], context=context)


def make_data():
    return [decimal.Decimal(random.random()) for _ in range(100)]


DecimalArray._add_arithmetic_ops()
