# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private utilities for locally-connected layers."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend
from keras.utils import conv_utils


def get_locallyconnected_mask(
    input_shape, kernel_shape, strides, padding, data_format
):
    """Return a mask representing connectivity of a locally-connected operation.

    This method returns a masking numpy array of 0s and 1s that, when
    element-wise multiplied with a fully-connected weight tensor, masks out the
    weights between disconnected input-output pairs and thus implements local
    connectivity through a sparse fully-connected weight tensor.

    Assume an unshared convolution with given parameters is applied to an input
    having N spatial dimensions with `input_shape = (d_in1, ..., d_inN)`
    to produce an output with spatial shape `(d_out1, ..., d_outN)` (determined
    by layer parameters such as `strides`).

    The returned mask can be broadcast-multiplied (element-wise) with a
    2*(N+1)-D weight tensor (equivalent to a fully-connected layer between
    (N+1)-D activations: N spatial + 1 channel dimensions for both input and
    output) to make it perform an unshared convolution with the given
    `kernel_shape`, `strides`, `padding` and `data_format`.

    Args:
      input_shape: tuple of size N: `(d_in1, ..., d_inN)` spatial shape of the
        input.
      kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
        receptive field.
      strides: tuple of size N, strides along each spatial dimension.
      padding: type of padding, string `"same"` or `"valid"`.
      data_format: a string, `"channels_first"` or `"channels_last"`.

    Returns:
      an `np.ndarray` of shape
      `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`
      if `data_format == "channels_first"`, or
      `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
      if `data_format == "channels_last"`.
      Apart from the two inserted singleton axes the array is returned exactly
      as produced by `conv_utils.conv_kernel_mask`; no dtype conversion is done
      here. NOTE(review): earlier documentation described the result as
      `np.float32` -- confirm against `conv_kernel_mask`.

    Raises:
      ValueError: if `data_format` is neither `"channels_first"` nor
                  `"channels_last"`.
    """
    # Validate first so that an invalid `data_format` fails fast, before the
    # (potentially large) connectivity mask is allocated.
    if data_format not in ("channels_first", "channels_last"):
        raise ValueError("Unrecognized data_format: " + str(data_format))

    mask = conv_utils.conv_kernel_mask(
        input_shape=input_shape,
        kernel_shape=kernel_shape,
        strides=strides,
        padding=padding,
    )

    # The mask has N input axes followed by N output axes.
    ndims = mask.ndim // 2

    if data_format == "channels_first":
        # Singleton channel axis goes in front of each spatial group:
        # (1, d_in..., 1, d_out...).
        mask = np.expand_dims(mask, 0)
        mask = np.expand_dims(mask, -ndims - 1)
    else:
        # "channels_last": singleton channel axis goes after each spatial
        # group: (d_in..., 1, d_out..., 1).
        mask = np.expand_dims(mask, ndims)
        mask = np.expand_dims(mask, -1)

    return mask


def local_conv_matmul(inputs, kernel, kernel_mask, output_shape):
    """Apply N-D convolution with un-shared weights using a single matmul call.

    This method outputs `inputs . (kernel * kernel_mask)`
    (with `.` standing for matrix-multiply and `*` for element-wise multiply)
    and requires a precomputed `kernel_mask` to zero-out weights in `kernel` and
    hence perform the same operation as a convolution with un-shared
    (the remaining entries in `kernel`) weights. It also does the necessary
    reshapes to make `inputs` and `kernel` 2-D and `output` (N+2)-D.

    Args:
        inputs: (N+2)-D tensor with shape `(batch_size, channels_in, d_in1, ...,
          d_inN)` or `(batch_size, d_in1, ..., d_inN, channels_in)`.
        kernel: the unshared weights for N-D convolution,
            a 2*(N+1)-D tensor of shape: `(d_in1, ..., d_inN, channels_in,
            d_out1, ..., d_outN, channels_out)` or `(channels_in, d_in1, ...,
            d_inN, channels_out, d_out1, ..., d_outN)`, with the ordering of
            channels and spatial dimensions matching that of the input. Each
            entry is the weight between a particular input and output location,
            similarly to a fully-connected weight matrix.
        kernel_mask: a 0/1 mask tensor of shape: `(d_in1, ..., d_inN, 1,
          d_out1, ..., d_outN, 1)` or `(1, d_in1, ..., d_inN, 1, d_out1, ...,
          d_outN)`, with the ordering of singleton and spatial dimensions
          matching that of the input. Mask represents the connectivity pattern
          of the layer and is precomputed elsewhere based on layer parameters:
          stride, padding, and the receptive field shape.
        output_shape: the (N+2)-element output shape, given either as a
          `tf.TensorShape` or as a tuple/list of ints (the batch entry may be
          `None`): `(batch_size, channels_out, d_out1, ..., d_outN)` or
          `(batch_size, d_out1, ..., d_outN, channels_out)`, with the ordering
          of channels and spatial dimensions matching that of the input. Only
          the non-batch entries are used; the batch size is taken from
          `inputs` at run time.

    Returns:
        Output (N+2)-D tensor with shape `output_shape`.
    """
    # Collapse every non-batch axis of the input into a single axis.
    inputs_flat = backend.reshape(inputs, (backend.shape(inputs)[0], -1))

    # Zero out the weights between disconnected input/output locations, then
    # flatten the kernel into an (input_size, output_size) matrix. The first
    # half of the kernel axes index the input, the second half the output.
    kernel = kernel_mask * kernel
    kernel = make_2d(kernel, split_dim=backend.ndim(kernel) // 2)

    # `b_is_sparse=True` only hints that most kernel entries are zero; the
    # kernel is still a dense tensor.
    output_flat = tf.matmul(inputs_flat, kernel, b_is_sparse=True)

    # Normalising through `tf.TensorShape` accepts both a `TensorShape` and a
    # plain tuple/list, as the documented contract promises.
    trailing_dims = tf.TensorShape(output_shape).as_list()[1:]
    output = backend.reshape(
        output_flat,
        [backend.shape(output_flat)[0]] + trailing_dims,
    )
    return output


def local_conv_sparse_matmul(
    inputs, kernel, kernel_idxs, kernel_shape, output_shape
):
    """Apply N-D convolution with un-shared weights using a single sparse matmul.

    This method computes `K . transpose(inputs_flat)` and transposes the
    result, where `K = tf.sparse.SparseTensor(indices=kernel_idxs,
    values=kernel, dense_shape=kernel_shape)`, `inputs_flat` is `inputs`
    reshaped to 2-D and `.` stands for matrix-multiply. The result is then
    reshaped to (N+2)-D.

    Args:
        inputs: (N+2)-D tensor with shape `(batch_size, channels_in, d_in1, ...,
          d_inN)` or `(batch_size, d_in1, ..., d_inN, channels_in)`.
        kernel: a 1-D tensor with shape `(len(kernel_idxs),)` containing all the
          weights of the layer.
        kernel_idxs:  a list of integer tuples representing indices in a sparse
          matrix performing the un-shared convolution as a matrix-multiply.
        kernel_shape: the dense shape of the sparse kernel matrix. Because the
          sparse matrix is the left operand and is multiplied with the
          transposed flattened input, it has to be `(output_size, input_size)`,
          where `input_size = channels_in * d_in1 * ... * d_inN` and
          `output_size = channels_out * d_out1 * ... * d_outN`.
        output_shape: the (N+2)-element output shape, given either as a
          `tf.TensorShape` or as a tuple/list of ints (the batch entry may be
          `None`): `(batch_size, channels_out, d_out1, ..., d_outN)` or
          `(batch_size, d_out1, ..., d_outN, channels_out)`, with the ordering
          of channels and spatial dimensions matching that of the input. Only
          the non-batch entries are used; the batch size is taken from the
          computed result at run time.

    Returns:
        Output (N+2)-D dense tensor with shape `output_shape`.
    """
    # Collapse every non-batch axis of the input into a single axis.
    inputs_flat = backend.reshape(inputs, (backend.shape(inputs)[0], -1))

    # The sparse operand must be on the left, so compute
    # K . inputs_flat^T -> (output_size, batch) and transpose back afterwards.
    output_flat = tf.sparse.sparse_dense_matmul(
        sp_a=tf.SparseTensor(kernel_idxs, kernel, kernel_shape),
        b=inputs_flat,
        adjoint_b=True,
    )
    output_flat_transpose = backend.transpose(output_flat)

    # Normalising through `tf.TensorShape` accepts both a `TensorShape` and a
    # plain tuple/list, as the documented contract promises.
    trailing_dims = tf.TensorShape(output_shape).as_list()[1:]
    output_reshaped = backend.reshape(
        output_flat_transpose,
        [backend.shape(output_flat_transpose)[0]] + trailing_dims,
    )
    return output_reshaped


def make_2d(tensor, split_dim):
    """Flatten an N-D tensor into a matrix by grouping its axes in two.

    All axes strictly before `split_dim` are merged into the row axis, and the
    axes from `split_dim` onwards (inclusive) are merged into the column axis.

    Args:
      tensor: a tensor of shape `(d0, ..., d(N-1))`.
      split_dim: an integer from 1 to N-1, index of the first axis that belongs
        to the column group.

    Returns:
      Tensor of shape
      `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
    """
    # Use the dynamic shape so the sizes are computed as tensors at run time.
    dims = tf.shape(tensor)

    num_rows = tf.reduce_prod(dims[:split_dim])
    num_cols = tf.reduce_prod(dims[split_dim:])

    return tf.reshape(tensor, (num_rows, num_cols))
