# `_elementwise_iterative_method.py` includes tools for writing functions that
# - are vectorized to work elementwise on arrays,
# - implement non-trivial, iterative algorithms with a callback interface, and
# - return rich objects with iteration count, termination status, etc.
#
# Examples include:
# `scipy.optimize._chandrupatla._chandrupatla for scalar rootfinding,
# `scipy.optimize._chandrupatla._chandrupatla_minimize for scalar minimization,
# `scipy.optimize._differentiate._differentiate for numerical differentiation,
# `scipy.optimize._bracket._bracket_root for finding rootfinding brackets,
# `scipy.optimize._bracket._bracket_minimize for finding minimization brackets,
# `scipy.integrate._tanhsinh._tanhsinh` for numerical quadrature.

import numpy as np
from ._util import _RichResult, _call_callback_maybe_halt

_ESIGNERR = -1
_ECONVERR = -2
_EVALUEERR = -3
_ECALLBACK = -4
_ECONVERGED = 0
_EINPROGRESS = 1

def _initialize(func, xs, args, complex_ok=False, preserve_shape=None):
    """Initialize abscissa, function, and args arrays for elementwise function

    Parameters
    ----------
    func : callable
        An elementwise function with signature

            func(x: ndarray, *args) -> ndarray

        where each element of ``x`` is a finite real and ``args`` is a tuple,
        which may contain an arbitrary number of arrays that are broadcastable
        with ``x``.
    xs : tuple of arrays
        Finite real abscissa arrays. Must be broadcastable.
    args : tuple, optional
        Additional positional arguments to be passed to `func`.
    preserve_shape : bool, default:False
        When ``preserve_shape=False`` (default), `func` may be passed
        arguments of any shape; `_scalar_optimization_loop` is permitted
        to reshape and compress arguments at will. When
        ``preserve_shape=False``, arguments passed to `func` must have shape
        `shape` or ``shape + (n,)``, where ``n`` is any integer.

    Returns
    -------
    xs, fs, args : tuple of arrays
        Broadcasted, writeable, 1D abscissa and function value arrays (or
        NumPy floats, if appropriate). The dtypes of the `xs` and `fs` are
        `xfat`; the dtype of the `args` are unchanged.
    shape : tuple of ints
        Original shape of broadcasted arrays.
    xfat : NumPy dtype
        Result dtype of abscissae, function values, and args determined using
        `np.result_type`, except integer types are promoted to `np.float64`.

    Raises
    ------
    ValueError
        If the result dtype is not that of a real scalar

    Notes
    -----
    Useful for initializing the input of SciPy functions that accept
    an elementwise callable, abscissae, and arguments; e.g.
    `scipy.optimize._chandrupatla`.
    """
    nx = len(xs)

    # Try to preserve `dtype`, but we need to ensure that the arguments are at
    # least floats before passing them into the function; integers can overflow
    # and cause failure.
    # There might be benefit to combining the `xs` into a single array and
    # calling `func` once on the combined array. For now, keep them separate.
    xas = np.broadcast_arrays(*xs, *args)  # broadcast and rename
    xat = np.result_type(*[xa.dtype for xa in xas])
    xat = np.float64 if np.issubdtype(xat, np.integer) else xat
    xs, args = xas[:nx], xas[nx:]
    xs = [x.astype(xat, copy=False)[()] for x in xs]
    fs = [np.asarray(func(x, *args)) for x in xs]
    shape = xs[0].shape
    fshape = fs[0].shape

    if preserve_shape:
        # bind original shape/func now to avoid late-binding gotcha
        def func(x, *args, shape=shape, func=func,  **kwargs):
            i = (0,)*(len(fshape) - len(shape))
            return func(x[i], *args, **kwargs)
        shape = np.broadcast_shapes(fshape, shape)
        xs = [np.broadcast_to(x, shape) for x in xs]
        args = [np.broadcast_to(arg, shape) for arg in args]

    message = ("The shape of the array returned by `func` must be the same as "
               "the broadcasted shape of `x` and all other `args`.")
    if preserve_shape is not None:  # only in tanhsinh for now
        message = f"When `preserve_shape=False`, {message.lower()}"
    shapes_equal = [f.shape == shape for f in fs]
    if not np.all(shapes_equal):
        raise ValueError(message)

    # These algorithms tend to mix the dtypes of the abscissae and function
    # values, so figure out what the result will be and convert them all to
    # that type from the outset.
    xfat = np.result_type(*([f.dtype for f in fs] + [xat]))
    if not complex_ok and not np.issubdtype(xfat, np.floating):
        raise ValueError("Abscissae and function output must be real numbers.")
    xs = [x.astype(xfat, copy=True)[()] for x in xs]
    fs = [f.astype(xfat, copy=True)[()] for f in fs]

    # To ensure that we can do indexing, we'll work with at least 1d arrays,
    # but remember the appropriate shape of the output.
    xs = [x.ravel() for x in xs]
    fs = [f.ravel() for f in fs]
    args = [arg.flatten() for arg in args]
    return func, xs, fs, args, shape, xfat


def _loop(work, callback, shape, maxiter, func, args, dtype, pre_func_eval,
          post_func_eval, check_termination, post_termination_check,
          customize_result, res_work_pairs, preserve_shape=False):
    """Main loop of a vectorized scalar optimization algorithm

    Parameters
    ----------
    work : _RichResult
        All variables that need to be retained between iterations. Must
        contain attributes `nit`, `nfev`, and `success`
    callback : callable
        User-specified callback function
    shape : tuple of ints
        The shape of all output arrays
    maxiter :
        Maximum number of iterations of the algorithm
    func : callable
        The user-specified callable that is being optimized or solved
    args : tuple
        Additional positional arguments to be passed to `func`.
    dtype : NumPy dtype
        The common dtype of all abscissae and function values
    pre_func_eval : callable
        A function that accepts `work` and returns `x`, the active elements
        of `x` at which `func` will be evaluated. May modify attributes
        of `work` with any algorithmic steps that need to happen
         at the beginning of an iteration, before `func` is evaluated,
    post_func_eval : callable
        A function that accepts `x`, `func(x)`, and `work`. May modify
        attributes of `work` with any algorithmic steps that need to happen
         in the middle of an iteration, after `func` is evaluated but before
         the termination check.
    check_termination : callable
        A function that accepts `work` and returns `stop`, a boolean array
        indicating which of the active elements have met a termination
        condition.
    post_termination_check : callable
        A function that accepts `work`. May modify `work` with any algorithmic
        steps that need to happen after the termination check and before the
        end of the iteration.
    customize_result : callable
        A function that accepts `res` and `shape` and returns `shape`. May
        modify `res` (in-place) according to preferences (e.g. rearrange
        elements between attributes) and modify `shape` if needed.
    res_work_pairs : list of (str, str)
        Identifies correspondence between attributes of `res` and attributes
        of `work`; i.e., attributes of active elements of `work` will be
        copied to the appropriate indices of `res` when appropriate. The order
        determines the order in which _RichResult attributes will be
        pretty-printed.

    Returns
    -------
    res : _RichResult
        The final result object

    Notes
    -----
    Besides providing structure, this framework provides several important
    services for a vectorized optimization algorithm.

    - It handles common tasks involving iteration count, function evaluation
      count, a user-specified callback, and associated termination conditions.
    - It compresses the attributes of `work` to eliminate unnecessary
      computation on elements that have already converged.

    """
    cb_terminate = False

    # Initialize the result object and active element index array
    n_elements = int(np.prod(shape))
    active = np.arange(n_elements)  # in-progress element indices
    res_dict = {i: np.zeros(n_elements, dtype=dtype) for i, j in res_work_pairs}
    res_dict['success'] = np.zeros(n_elements, dtype=bool)
    res_dict['status'] = np.full(n_elements, _EINPROGRESS)
    res_dict['nit'] = np.zeros(n_elements, dtype=int)
    res_dict['nfev'] = np.zeros(n_elements, dtype=int)
    res = _RichResult(res_dict)
    work.args = args

    active = _check_termination(work, res, res_work_pairs, active,
                                check_termination, preserve_shape)

    if callback is not None:
        temp = _prepare_result(work, res, res_work_pairs, active, shape,
                               customize_result, preserve_shape)
        if _call_callback_maybe_halt(callback, temp):
            cb_terminate = True

    while work.nit < maxiter and active.size and not cb_terminate and n_elements:
        x = pre_func_eval(work)

        if work.args and work.args[0].ndim != x.ndim:
            # `x` always starts as 1D. If the SciPy function that uses
            # _loop added dimensions to `x`, we need to
            # add them to the elements of `args`.
            dims = np.arange(x.ndim, dtype=np.int64)
            work.args = [np.expand_dims(arg, tuple(dims[arg.ndim:]))
                         for arg in work.args]

        x_shape = x.shape
        if preserve_shape:
            x = x.reshape(shape + (-1,))
        f = func(x, *work.args)
        f = np.asarray(f, dtype=dtype)
        if preserve_shape:
            x = x.reshape(x_shape)
            f = f.reshape(x_shape)
        work.nfev += 1 if x.ndim == 1 else x.shape[-1]

        post_func_eval(x, f, work)

        work.nit += 1
        active = _check_termination(work, res, res_work_pairs, active,
                                    check_termination, preserve_shape)

        if callback is not None:
            temp = _prepare_result(work, res, res_work_pairs, active, shape,
                                   customize_result, preserve_shape)
            if _call_callback_maybe_halt(callback, temp):
                cb_terminate = True
                break
        if active.size == 0:
            break

        post_termination_check(work)

    work.status[:] = _ECALLBACK if cb_terminate else _ECONVERR
    return _prepare_result(work, res, res_work_pairs, active, shape,
                           customize_result, preserve_shape)


def _check_termination(work, res, res_work_pairs, active, check_termination,
                       preserve_shape):
    # Checks termination conditions, updates elements of `res` with
    # corresponding elements of `work`, and compresses `work`.

    stop = check_termination(work)

    if np.any(stop):
        # update the active elements of the result object with the active
        # elements for which a termination condition has been met
        _update_active(work, res, res_work_pairs, active, stop, preserve_shape)

        if preserve_shape:
            stop = stop[active]

        proceed = ~stop
        active = active[proceed]

        if not preserve_shape:
            # compress the arrays to avoid unnecessary computation
            for key, val in work.items():
                work[key] = val[proceed] if isinstance(val, np.ndarray) else val
            work.args = [arg[proceed] for arg in work.args]

    return active


def _update_active(work, res, res_work_pairs, active, mask, preserve_shape):
    # Update `active` indices of the arrays in result object `res` with the
    # contents of the scalars and arrays in `update_dict`. When provided,
    # `mask` is a boolean array applied both to the arrays in `update_dict`
    # that are to be used and to the arrays in `res` that are to be updated.
    update_dict = {key1: work[key2] for key1, key2 in res_work_pairs}
    update_dict['success'] = work.status == 0

    if mask is not None:
        if preserve_shape:
            active_mask = np.zeros_like(mask)
            active_mask[active] = 1
            active_mask = active_mask & mask
            for key, val in update_dict.items():
                res[key][active_mask] = (val[active_mask] if np.size(val) > 1
                                         else val)
        else:
            active_mask = active[mask]
            for key, val in update_dict.items():
                res[key][active_mask] = val[mask] if np.size(val) > 1 else val
    else:
        for key, val in update_dict.items():
            if preserve_shape and not np.isscalar(val):
                val = val[active]
            res[key][active] = val


def _prepare_result(work, res, res_work_pairs, active, shape, customize_result,
                    preserve_shape):
    # Prepare the result object `res` by creating a copy, copying the latest
    # data from work, running the provided result customization function,
    # and reshaping the data to the original shapes.
    res = res.copy()
    _update_active(work, res, res_work_pairs, active, None, preserve_shape)

    shape = customize_result(res, shape)

    for key, val in res.items():
        res[key] = np.reshape(val, shape)[()]
    res['_order_keys'] = ['success'] + [i for i, j in res_work_pairs]
    return _RichResult(**res)
