"""
Rudimentary Apache Arrow-backed ExtensionArray.

At the moment, just a boolean array / type is implemented.
Eventually, we'll want to parametrize the type and support
multiple dtypes. Not all methods are implemented yet, and the
current implementation is not efficient.
"""
from __future__ import annotations

import itertools
import operator

import numpy as np
import pyarrow as pa

from pandas._typing import type_t

import pandas as pd
from pandas.api.extensions import (
    ExtensionDtype,
    register_extension_dtype,
    take,
)
from pandas.api.types import is_scalar
from pandas.core.arrays.arrow import ArrowExtensionArray as _ArrowExtensionArray
from pandas.core.construction import extract_array


@register_extension_dtype
class ArrowBoolDtype(ExtensionDtype):

    type = np.bool_
    kind = "b"
    name = "arrow_bool"
    na_value = pa.NULL

    @classmethod
    def construct_array_type(cls) -> type_t[ArrowBoolArray]:
        """
        Return the array type associated with this dtype.

        Returns
        -------
        type
        """
        return ArrowBoolArray

    @property
    def _is_boolean(self) -> bool:
        return True


@register_extension_dtype
class ArrowStringDtype(ExtensionDtype):

    type = str
    kind = "U"
    name = "arrow_string"
    na_value = pa.NULL

    @classmethod
    def construct_array_type(cls) -> type_t[ArrowStringArray]:
        """
        Return the array type associated with this dtype.

        Returns
        -------
        type
        """
        return ArrowStringArray


class ArrowExtensionArray(_ArrowExtensionArray):
    _data: pa.ChunkedArray

    @classmethod
    def _from_sequence(cls, values, dtype=None, copy=False):
        # TODO: respect dtype, copy

        if isinstance(values, cls):
            # in particular for empty cases the pa.array(np.asarray(...))
            #  does not round-trip
            return cls(values._data)

        elif not len(values):
            if isinstance(values, list):
                dtype = bool if cls is ArrowBoolArray else str
                values = np.array([], dtype=dtype)

        arr = pa.chunked_array([pa.array(np.asarray(values))])
        return cls(arr)

    def __repr__(self):
        return f"{type(self).__name__}({repr(self._data)})"

    def __contains__(self, obj) -> bool:
        if obj is None or obj is self.dtype.na_value:
            # None -> EA.__contains__ only checks for self._dtype.na_value, not
            #  any compatible NA value.
            # self.dtype.na_value -> <pa.NullScalar:None> isn't recognized by pd.isna
            return bool(self.isna().any())
        return bool(super().__contains__(obj))

    def __getitem__(self, item):
        if is_scalar(item):
            return self._data.to_pandas()[item]
        else:
            vals = self._data.to_pandas()[item]
            return type(self)._from_sequence(vals)

    def astype(self, dtype, copy=True):
        # needed to fix this astype for the Series constructor.
        if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
            if copy:
                return self.copy()
            return self
        return super().astype(dtype, copy)

    @property
    def dtype(self):
        return self._dtype

    def _logical_method(self, other, op):
        if not isinstance(other, type(self)):
            raise NotImplementedError()

        result = op(np.array(self._data), np.array(other._data))
        return ArrowBoolArray(
            pa.chunked_array([pa.array(result, mask=pd.isna(self._data.to_pandas()))])
        )

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            # TODO: use some pyarrow function here?
            return np.asarray(self).__eq__(other)

        return self._logical_method(other, operator.eq)

    def take(self, indices, allow_fill=False, fill_value=None):
        data = self._data.to_pandas()
        data = extract_array(data, extract_numpy=True)

        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value

        result = take(data, indices, fill_value=fill_value, allow_fill=allow_fill)
        return self._from_sequence(result, dtype=self.dtype)

    @classmethod
    def _concat_same_type(cls, to_concat):
        chunks = list(itertools.chain.from_iterable(x._data.chunks for x in to_concat))
        arr = pa.chunked_array(chunks)
        return cls(arr)

    def __invert__(self):
        return type(self)._from_sequence(~self._data.to_pandas())

    def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
        if skipna:
            arr = self[~self.isna()]
        else:
            arr = self

        try:
            op = getattr(arr, name)
        except AttributeError as err:
            raise TypeError from err
        return op(**kwargs)

    def any(self, axis=0, out=None):
        # Explicitly return a plain bool to reproduce GH-34660
        return bool(self._data.to_pandas().any())

    def all(self, axis=0, out=None):
        # Explicitly return a plain bool to reproduce GH-34660
        return bool(self._data.to_pandas().all())


class ArrowBoolArray(ArrowExtensionArray):
    def __init__(self, values) -> None:
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError

        assert values.type == pa.bool_()
        self._data = values
        self._dtype = ArrowBoolDtype()  # type: ignore[assignment]


class ArrowStringArray(ArrowExtensionArray):
    def __init__(self, values) -> None:
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError

        assert values.type == pa.string()
        self._data = values
        self._dtype = ArrowStringDtype()  # type: ignore[assignment]
