# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by convolution layers."""

import itertools

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend


def convert_data_format(data_format, ndim):
    """Map a Keras `data_format` and a tensor rank to a TF layout string.

    Args:
      data_format: `"channels_last"` or `"channels_first"`.
      ndim: Rank of the tensor (3, 4 or 5), counting the batch and channel
        axes.

    Returns:
      The layout string, e.g. `"NHWC"` for `("channels_last", 4)` or
      `"NCDHW"` for `("channels_first", 5)`.

    Raises:
      ValueError: If `data_format` is not recognized, or `ndim` is not one of
        3, 4, 5.
    """
    if data_format == "channels_last":
        layouts = ("NWC", "NHWC", "NDHWC")
    elif data_format == "channels_first":
        layouts = ("NCW", "NCHW", "NCDHW")
    else:
        raise ValueError(
            f"Invalid data_format: {data_format}. "
            'Expected values are ["channels_first", "channels_last"]'
        )

    # Compare with `==` (not a dict lookup) so any rank-like value that
    # equals 3/4/5 is accepted, exactly as an if/elif chain would.
    for rank, layout in zip((3, 4, 5), layouts):
        if ndim == rank:
            return layout
    raise ValueError(
        f"Input rank not supported: {ndim}. "
        "Expected values are [3, 4, 5]"
    )


def normalize_tuple(value, n, name, allow_zero=False):
    """Transforms non-negative/positive integer/integers into an integer tuple.

    Args:
      value: The value to validate and convert. Could be an int, or any
        iterable of ints.
      n: The size of the tuple to be returned.
      name: The name of the argument being validated, e.g. "strides" or
        "kernel_size". This is only used to format error messages.
      allow_zero: Defaults to False. A ValueError will be raised if zero is
        received and this param is False.

    Returns:
      A tuple of n values. An int `value` is repeated `n` times; the elements
      of an iterable `value` are returned as-is (they are only checked to be
      convertible with `int()`, not converted).

    Raises:
      ValueError: If `value` is neither an int nor an iterable, if the
        iterable does not have exactly `n` elements, if an element cannot be
        converted with `int()`, or if a value is negative (or zero, when
        `allow_zero` is False).
    """
    error_msg = (
        f"The `{name}` argument must be a tuple of {n} "
        f"integers. Received: {value}"
    )

    if isinstance(value, int):
        value_tuple = (value,) * n
    else:
        try:
            value_tuple = tuple(value)
        except TypeError as e:
            raise ValueError(error_msg) from e
        if len(value_tuple) != n:
            raise ValueError(error_msg)
        for single_value in value_tuple:
            try:
                int(single_value)
            except (ValueError, TypeError) as e:
                # Leading space keeps this clause separate from "Received: …".
                error_msg += (
                    f" including element {single_value} of "
                    f"type {type(single_value)}"
                )
                raise ValueError(error_msg) from e

    if allow_zero:
        unqualified_values = {v for v in value_tuple if v < 0}
        req_msg = ">= 0"
    else:
        unqualified_values = {v for v in value_tuple if v <= 0}
        req_msg = "> 0"

    if unqualified_values:
        error_msg += (
            f" including {unqualified_values}"
            f" that does not satisfy the requirement `{req_msg}`."
        )
        raise ValueError(error_msg)

    return value_tuple


def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
    """Compute the length a convolution produces along one spatial axis.

    Args:
        input_length: Integer input length, or `None` if unknown.
        filter_size: Integer kernel size.
        padding: One of `"same"`, `"valid"`, `"full"`, `"causal"`.
        stride: Integer stride.
        dilation: Integer dilation rate.

    Returns:
        The output length (integer), or `None` when `input_length` is `None`.
    """
    if input_length is None:
        return None
    assert padding in {"same", "valid", "full", "causal"}

    # Effective extent of the kernel once `dilation - 1` gaps are inserted
    # between each pair of neighbouring taps.
    span = filter_size + (filter_size - 1) * (dilation - 1)

    if padding == "valid":
        padded_length = input_length - span + 1
    elif padding == "full":
        padded_length = input_length + span - 1
    else:
        # "same" and "causal" keep the input length before striding.
        padded_length = input_length

    # Ceiling division by the stride.
    return (padded_length + stride - 1) // stride


def conv_input_length(output_length, filter_size, padding, stride):
    """Invert `conv_output_length`: recover the input length of a convolution.

    Args:
        output_length: Integer output length, or `None` if unknown.
        filter_size: Integer kernel size.
        padding: One of `"same"`, `"valid"`, `"full"`.
        stride: Integer stride.

    Returns:
        The input length (integer), or `None` when `output_length` is `None`.
    """
    if output_length is None:
        return None
    assert padding in {"same", "valid", "full"}

    # Amount of padding applied on each side for every supported mode.
    pad_per_side = {
        "same": filter_size // 2,
        "valid": 0,
        "full": filter_size - 1,
    }[padding]

    return (output_length - 1) * stride - 2 * pad_per_side + filter_size


def deconv_output_length(
    input_length,
    filter_size,
    padding,
    output_padding=None,
    stride=0,
    dilation=1,
):
    """Compute the output length of a transposed convolution on one axis.

    Args:
        input_length: Integer input length, or `None` if unknown.
        filter_size: Integer kernel size.
        padding: One of `"same"`, `"valid"`, `"full"`.
        output_padding: Integer, extra padding along the output dimension.
          When `None`, the output length is inferred from `padding` alone.
        stride: Integer stride.
        dilation: Integer dilation rate.

    Returns:
        The output length (integer), or `None` when `input_length` is `None`.
    """
    assert padding in {"same", "valid", "full"}
    if input_length is None:
        return None

    # Kernel extent including the gaps introduced by dilation.
    kernel_extent = filter_size + (filter_size - 1) * (dilation - 1)

    if output_padding is not None:
        # Exact length from the explicit per-side padding.
        if padding == "same":
            pad = kernel_extent // 2
        elif padding == "full":
            pad = kernel_extent - 1
        else:
            pad = 0
        return (
            (input_length - 1) * stride + kernel_extent - 2 * pad + output_padding
        )

    # Inferred length when no output padding is given.
    if padding == "valid":
        return input_length * stride + max(kernel_extent - stride, 0)
    if padding == "full":
        return input_length * stride - (stride + kernel_extent - 2)
    return input_length * stride


def normalize_data_format(value):
    """Validate a `data_format` argument and return it in lowercase.

    Args:
      value: `"channels_first"` or `"channels_last"` (case-insensitive), or
        `None` to use `backend.image_data_format()`.

    Returns:
      The lowercased data format string.

    Raises:
      ValueError: If the (resolved) value is not a recognized data format.
    """
    resolved = backend.image_data_format() if value is None else value
    data_format = resolved.lower()
    if data_format in {"channels_first", "channels_last"}:
        return data_format
    raise ValueError(
        "The `data_format` argument must be one of "
        f'"channels_first", "channels_last". Received: {resolved}'
    )


def normalize_padding(value):
    """Validate a `padding` argument.

    Args:
      value: Either a list/tuple of explicit paddings, or one of the strings
        `"valid"`, `"same"`, `"causal"` (case-insensitive).

    Returns:
      `value` unchanged if it is a list/tuple, otherwise the lowercased
      padding string.

    Raises:
      ValueError: If `value` is a string that is not a supported padding mode.
    """
    if isinstance(value, (list, tuple)):
        return value
    padding = value.lower()
    if padding not in {"valid", "same", "causal"}:
        # Report what the caller passed (as `normalize_data_format` does),
        # not the lowercased form.
        raise ValueError(
            "The `padding` argument must be a list/tuple or one of "
            '"valid", "same" (or "causal", only for `Conv1D`). '
            f"Received: {value}"
        )
    return padding


def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
    """Compute a mask representing the connectivity of a convolution operation.

    Assume a convolution with given parameters is applied to an input having N
    spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
    output with shape `(d_out1, ..., d_outN)`. This method returns a boolean
    array of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True`
    entries indicating pairs of input and output locations that are connected by
    a weight.

    Example:

      >>> input_shape = (4,)
      >>> kernel_shape = (2,)
      >>> strides = (1,)
      >>> padding = "valid"
      >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
      array([[ True, False, False],
             [ True,  True, False],
             [False,  True,  True],
             [False, False,  True]])

      where rows and columns correspond to inputs and outputs respectively.


    Args:
      input_shape: tuple (or list) of size N: `(d_in1, ..., d_inN)`, spatial
        shape of the input.
      kernel_shape: tuple of size N (or a single int used for every dimension),
        spatial shape of the convolutional kernel / receptive field.
      strides: tuple of size N (or a single int used for every dimension),
        strides along each spatial dimension.
      padding: type of padding, string `"same"` or `"valid"`.
        `"valid"` means no padding. `"same"` results in padding evenly to
        the left/right or up/down of the input such that output has the same
        height/width dimension as the input.

    Returns:
      A boolean 2N-D `np.ndarray` of shape
      `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
      is the spatial shape of the output. `True` entries in the mask represent
      pairs of input-output locations that are connected by a weight.

    Raises:
      ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
          same number of dimensions.
      NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
    """
    if padding not in {"same", "valid"}:
        raise NotImplementedError(
            f"Padding type {padding} not supported. "
            'Only "valid" and "same" are implemented.'
        )

    # Normalize to a tuple so `input_shape + output_shape` below also works
    # when callers pass a list.
    input_shape = tuple(input_shape)
    in_dims = len(input_shape)
    if isinstance(kernel_shape, int):
        kernel_shape = (kernel_shape,) * in_dims
    if isinstance(strides, int):
        strides = (strides,) * in_dims

    kernel_dims = len(kernel_shape)
    stride_dims = len(strides)
    if kernel_dims != in_dims or stride_dims != in_dims:
        raise ValueError(
            "Number of strides, input and kernel dimensions must all "
            f"match. Received: stride_dims={stride_dims}, "
            f"in_dims={in_dims}, kernel_dims={kernel_dims}"
        )

    output_shape = conv_output_shape(
        input_shape, kernel_shape, strides, padding
    )

    mask_shape = input_shape + output_shape
    # Use the builtin `bool`: the `np.bool` alias raises AttributeError on
    # NumPy 1.24-1.26, where it was removed after deprecation.
    mask = np.zeros(mask_shape, dtype=bool)

    output_axes_ticks = [range(dim) for dim in output_shape]
    for output_position in itertools.product(*output_axes_ticks):
        input_axes_ticks = conv_connected_inputs(
            input_shape, kernel_shape, output_position, strides, padding
        )
        for input_position in itertools.product(*input_axes_ticks):
            mask[input_position + output_position] = True

    return mask


def conv_kernel_idxs(
    input_shape,
    kernel_shape,
    strides,
    padding,
    filters_in,
    filters_out,
    data_format,
):
    """Yields output-input tuples of indices in a CNN layer.

    The generator iterates over all `(output_idx, input_idx)` tuples, where
    `output_idx` is an integer index in a flattened tensor representing a single
    output image of a convolutional layer that is connected (via the layer
    weights) to the respective single input image at `input_idx`.

    Example:

      >>> input_shape = (2, 2)
      >>> kernel_shape = (2, 1)
      >>> strides = (1, 1)
      >>> padding = "valid"
      >>> filters_in = 1
      >>> filters_out = 1
      >>> data_format = "channels_last"
      >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
      ...                       filters_in, filters_out, data_format))
      [(0, 0), (0, 2), (1, 1), (1, 3)]

    Args:
      input_shape: tuple (or list) of size N: `(d_in1, ..., d_inN)`, spatial
        shape of the input.
      kernel_shape: tuple of size N (or a single int used for every dimension),
        spatial shape of the convolutional kernel / receptive field.
      strides: tuple of size N (or a single int used for every dimension),
        strides along each spatial dimension.
      padding: type of padding, string `"same"` or `"valid"`.
        `"valid"` means no padding. `"same"` results in padding evenly to
        the left/right or up/down of the input such that output has the same
        height/width dimension as the input.
      filters_in: `int`, number of filters in the input to the layer.
      filters_out: `int`, number of filters in the output of the layer.
      data_format: string, "channels_first" or "channels_last".

    Yields:
      The next tuple `(output_idx, input_idx)`, where `output_idx` is an integer
      index in a flattened tensor representing a single output image of a
      convolutional layer that is connected (via the layer weights) to the
      respective single input image at `input_idx`.

    Raises:
        ValueError: if `data_format` is neither `"channels_last"` nor
          `"channels_first"`, or if number of strides, input, and kernel number
          of dimensions do not match.

        NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.
    """
    if padding not in ("same", "valid"):
        raise NotImplementedError(
            f"Padding type {padding} not supported. "
            'Only "valid" and "same" are implemented.'
        )

    # Normalize to a tuple: `concat_idxs` below concatenates tuples, which
    # would fail for a list `input_shape`.
    input_shape = tuple(input_shape)
    in_dims = len(input_shape)
    if isinstance(kernel_shape, int):
        kernel_shape = (kernel_shape,) * in_dims
    if isinstance(strides, int):
        strides = (strides,) * in_dims

    kernel_dims = len(kernel_shape)
    stride_dims = len(strides)
    if kernel_dims != in_dims or stride_dims != in_dims:
        raise ValueError(
            "Number of strides, input and kernel dimensions must all "
            f"match. Received: stride_dims={stride_dims}, "
            f"in_dims={in_dims}, kernel_dims={kernel_dims}"
        )

    output_shape = conv_output_shape(
        input_shape, kernel_shape, strides, padding
    )
    output_axes_ticks = [range(dim) for dim in output_shape]

    # Place the filter axis first or last to match the flattened layout.
    if data_format == "channels_first":
        concat_idxs = (
            lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx
        )
    elif data_format == "channels_last":
        concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (
            filter_idx,
        )
    else:
        raise ValueError(
            f"Data format `{data_format}` not recognized."
            '`data_format` must be "channels_first" or "channels_last".'
        )

    output_dims = concat_idxs(output_shape, filters_out)
    input_dims = concat_idxs(input_shape, filters_in)

    for output_position in itertools.product(*output_axes_ticks):
        # Flat output indices depend only on the output position and f_out,
        # so compute them once per output position instead of once per
        # (input position, f_in) pair.
        out_idxs = [
            np.ravel_multi_index(
                multi_index=concat_idxs(output_position, f_out),
                dims=output_dims,
            )
            for f_out in range(filters_out)
        ]
        input_axes_ticks = conv_connected_inputs(
            input_shape, kernel_shape, output_position, strides, padding
        )
        for input_position in itertools.product(*input_axes_ticks):
            for f_in in range(filters_in):
                in_idx = np.ravel_multi_index(
                    multi_index=concat_idxs(input_position, f_in),
                    dims=input_dims,
                )
                # Same yield order as before: f_out varies fastest.
                for out_idx in out_idxs:
                    yield (out_idx, in_idx)


def conv_connected_inputs(
    input_shape, kernel_shape, output_position, strides, padding
):
    """Return the input region that feeds a single convolution output.

    For an N-D convolution over an input of spatial shape
    `input_shape = (d_in1, ..., d_inN)`, returns one `range` per spatial
    dimension describing which input positions were combined with the kernel
    to produce the output at `output_position = (p_out1, ..., p_outN)`.
    Ranges are clipped to `[0, d_in)`.

    Example:

      >>> input_shape = (4, 4)
      >>> kernel_shape = (2, 1)
      >>> output_position = (1, 1)
      >>> strides = (1, 1)
      >>> padding = "valid"
      >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
      ...                       strides, padding)
      [range(1, 3), range(1, 2)]

    Args:
      input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
        input.
      kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
        receptive field.
      output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single
        position in the output of the convolution.
      strides: tuple of size N, strides along each spatial dimension.
      padding: type of padding, string `"same"` or `"valid"`.
        `"valid"` means no padding. `"same"` results in padding evenly to
        the left/right or up/down of the input such that output has the same
        height/width dimension as the input.

    Returns:
      A list of N `range` objects, one per spatial dimension, specifying the
      region of the input connected to `output_position`.
    """

    def _axis_range(axis):
        # Split the kernel around its "center" tap: `before` taps precede it,
        # `after` taps (including the center) follow it.
        width = kernel_shape[axis]
        before = int(width / 2)
        after = width - before

        anchor = output_position[axis] * strides[axis]
        if padding == "valid":
            # Without padding the window starts at the strided position, so
            # shift the anchor to the kernel center.
            anchor += before

        first = max(0, anchor - before)
        stop = min(input_shape[axis], anchor + after)
        return range(first, stop)

    return [_axis_range(axis) for axis in range(len(input_shape))]


def conv_output_shape(input_shape, kernel_shape, strides, padding):
    """Return the spatial output shape of an N-D convolution.

    Each axis length comes from `conv_output_length`, except that an axis
    whose input size is 0 is reported as 0 (empty inputs stay empty).

    Args:
      input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
        input.
      kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
        receptive field.
      strides: tuple of size N, strides along each spatial dimension.
      padding: type of padding, string `"same"` or `"valid"`.
        `"valid"` means no padding. `"same"` results in padding evenly to
        the left/right or up/down of the input such that output has the same
        height/width dimension as the input.

    Returns:
      tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
    """
    # Compute every axis length first (so any validation inside
    # `conv_output_length` runs for all axes), then zero out empty axes.
    lengths = []
    for axis in range(len(kernel_shape)):
        lengths.append(
            conv_output_length(
                input_shape[axis], kernel_shape[axis], padding, strides[axis]
            )
        )
    return tuple(
        0 if input_shape[axis] == 0 else length
        for axis, length in enumerate(lengths)
    )


def squeeze_batch_dims(inp, op, inner_rank):
    """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

    Where `squeeze_batch` reshapes `inp` to shape
    `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`
    and `unsqueeze_batch` does the reverse reshape but on the output.

    This allows `op` to be applied to an input with several leading batch
    dimensions by collapsing them into one axis first. (Presumably `op` only
    accepts a single batch axis — confirm against callers.)

    Args:
      inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
        is length `inner_rank`.
      op: A callable that takes a single input tensor and returns a single
        output tensor. NOTE(review): the final reshape assumes the output's
        elements can be laid out as `batch_shape + out_inner_shape`, i.e. that
        `op` preserves the size of the collapsed batch axis.
      inner_rank: A python integer.

    Returns:
      `unsqueeze_batch(op(squeeze_batch(inp)))`: a tensor with dims
      `batch_shape + out_inner_shape`, where `out_inner_shape` is the last
      `inner_rank` dims of `op`'s output.
    """
    with tf.name_scope("squeeze_batch_dims"):
        shape = inp.shape

        # Use static dims when fully known, otherwise fall back to the
        # dynamic shape tensor.
        inner_shape = shape[-inner_rank:]
        if not inner_shape.is_fully_defined():
            inner_shape = tf.shape(inp)[-inner_rank:]

        batch_shape = shape[:-inner_rank]
        if not batch_shape.is_fully_defined():
            batch_shape = tf.shape(inp)[:-inner_rank]

        # Collapse all leading batch dims into a single `-1` axis; a static
        # inner shape can be passed as a Python list, a dynamic one must be
        # concatenated as a tensor.
        if isinstance(inner_shape, tf.TensorShape):
            inp_reshaped = tf.reshape(inp, [-1] + inner_shape.as_list())
        else:
            inp_reshaped = tf.reshape(
                inp, tf.concat(([-1], inner_shape), axis=-1)
            )

        out_reshaped = op(inp_reshaped)

        out_inner_shape = out_reshaped.shape[-inner_rank:]
        if not out_inner_shape.is_fully_defined():
            out_inner_shape = tf.shape(out_reshaped)[-inner_rank:]

        # Restore the original batch dims in front of op's inner output dims.
        out = tf.reshape(
            out_reshaped, tf.concat((batch_shape, out_inner_shape), axis=-1)
        )

        # Re-attach whatever static shape info is known, since the reshape
        # above may have used a dynamic shape tensor.
        out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
        return out
