"""
Handles dispatching array operations to the correct backend library, as well
as converting arrays to backend formats and then potentially storing them as
constants.
"""

import importlib

import numpy

from . import object_arrays
from . import cupy as _cupy
from . import jax as _jax
from . import tensorflow as _tensorflow
from . import theano as _theano
from . import torch as _torch

__all__ = ["get_func", "has_einsum", "has_tensordot", "build_expression", "evaluate_constants", "has_backend"]

# known non top-level imports
_aliases = {
    'dask': 'dask.array',
    'theano': 'theano.tensor',
    'torch': 'opt_einsum.backends.torch',
    'jax': 'jax.numpy',
    'autograd': 'autograd.numpy',
    'mars': 'mars.tensor',
}


def _import_func(func, backend, default=None):
    """Try and import ``{backend}.{func}``.
    If library is installed and func is found, return the func;
    otherwise if default is provided, return default;
    otherwise raise an error.
    """
    try:
        lib = importlib.import_module(_aliases.get(backend, backend))
        return getattr(lib, func) if default is None else getattr(lib, func, default)
    except AttributeError:
        error_msg = ("{} doesn't seem to provide the function {} - see "
                     "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
                     "for details on which functions are required for which contractions.")
        raise AttributeError(error_msg.format(backend, func))


# manually cache functions as python2 doesn't support functools.lru_cache
#     other libs will be added to this if needed, but pre-populate with numpy
_cached_funcs = {
    ('tensordot', 'numpy'): numpy.tensordot,
    ('transpose', 'numpy'): numpy.transpose,
    ('einsum', 'numpy'): numpy.einsum,
    # also pre-populate with the arbitrary object backend
    ('tensordot', 'object'): numpy.tensordot,
    ('transpose', 'object'): numpy.transpose,
    ('einsum', 'object'): object_arrays.object_einsum,
}


def get_func(func, backend='numpy', default=None):
    """Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
    or a default func if provided. Cache result.
    """
    try:
        return _cached_funcs[func, backend]
    except KeyError:
        fn = _import_func(func, backend, default)
        _cached_funcs[func, backend] = fn
        return fn


# mark libs with einsum, else try to use tensordot/tranpose as much as possible
_has_einsum = {}


def has_einsum(backend):
    """Check if ``{backend}.einsum`` exists, cache result for performance.
    """
    try:
        return _has_einsum[backend]
    except KeyError:
        try:
            get_func('einsum', backend)
            _has_einsum[backend] = True
        except AttributeError:
            _has_einsum[backend] = False

        return _has_einsum[backend]


_has_tensordot = {}


def has_tensordot(backend):
    """Check if ``{backend}.tensordot`` exists, cache result for performance.
    """
    try:
        return _has_tensordot[backend]
    except KeyError:
        try:
            get_func('tensordot', backend)
            _has_tensordot[backend] = True
        except AttributeError:
            _has_tensordot[backend] = False

        return _has_tensordot[backend]


# Dispatch to correct expression backend
#    these are the backends which support explicit to-and-from numpy conversion
CONVERT_BACKENDS = {
    'tensorflow': _tensorflow.build_expression,
    'theano': _theano.build_expression,
    'cupy': _cupy.build_expression,
    'torch': _torch.build_expression,
    'jax': _jax.build_expression,
}

EVAL_CONSTS_BACKENDS = {
    'tensorflow': _tensorflow.evaluate_constants,
    'theano': _theano.evaluate_constants,
    'cupy': _cupy.evaluate_constants,
    'torch': _torch.evaluate_constants,
    'jax': _jax.evaluate_constants,
}


def build_expression(backend, arrays, expr):
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.
    """
    return CONVERT_BACKENDS[backend](arrays, expr)


def evaluate_constants(backend, arrays, expr):
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.
    """
    return EVAL_CONSTS_BACKENDS[backend](arrays, expr)


def has_backend(backend):
    """Checks if the backend is known.
    """
    return backend.lower() in CONVERT_BACKENDS
