# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility object to handle partial batches for TPUStrategy."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend


class PartialBatchPaddingHandler:
    """A container that holds info about partial batches for `predict()`.

    Batches smaller than `padded_batch_size` are zero-padded along the batch
    dimension (`pad_batch`). A 1-D mask recording which rows are real is
    built up with `update_mask`. `apply_mask` drops the prediction rows that
    originate from padding.
    """

    def __init__(self, output_shape):
        # Size that `pad_batch` pads the batch dimension up to. Starts at 0;
        # presumably assigned by the caller before padding (verify in caller).
        self.padded_batch_size = 0
        # 1-D mask: 1 for real rows, 0 for padded rows. Read by `apply_mask`.
        self.padding_mask = tf.zeros(0)
        # Only `len(output_shape)` is used here. A length of 1 selects the
        # single-output path in `apply_mask`; otherwise one entry per output.
        self.output_shape = output_shape

    def get_real_batch_size(self, dataset_batch):
        """Returns the number of elements in a potentially partial batch.

        Args:
          dataset_batch: A (possibly nested) structure of tensors. If it is a
            tuple or list, only its first element (presumably the features)
            is inspected.

        Returns:
          An int64 scalar tensor with the leading dimension of the first
          tensor found in the flattened structure.

        Raises:
          ValueError: If the inspected structure contains no tensor.
        """
        if isinstance(dataset_batch, (tuple, list)):
            dataset_batch = dataset_batch[0]

        assert tf.nest.flatten(dataset_batch)

        def _find_any_tensor(batch_features):
            tensors = [
                x for x in tf.nest.flatten(batch_features) if tf.is_tensor(x)
            ]
            if not tensors:
                raise ValueError("Cannot find any Tensor in features dict.")
            return tensors[0]

        return backend.cast(
            backend.shape(_find_any_tensor(dataset_batch))[0], dtype="int64"
        )

    def update_mask(self, padding_mask, dataset_batch):
        """Returns `padding_mask` extended with the mask for `dataset_batch`.

        The appended segment contains a 1 for each real element of
        `dataset_batch`. It is followed by a 0 for each padded slot needed to
        reach `padded_batch_size`. This method does not assign to
        `self.padding_mask`; it only returns the concatenated mask.
        """
        original_batch_size = self.get_real_batch_size(dataset_batch)
        missing_count = self.padded_batch_size - original_batch_size
        mask = backend.concatenate(
            [tf.ones(original_batch_size), tf.zeros(missing_count)], axis=0
        )
        return backend.concatenate([padding_mask, mask], axis=0)

    def pad_batch(self, *dataset_batch_elements):
        """Pads the batch dimension of a tensor to the complete batch size.

        Each element is zero-padded at the end of axis 0 up to
        `padded_batch_size`. Dict elements are padded value by value,
        recursively.

        Returns:
          The padded element when exactly one element is given, otherwise a
          tuple of padded elements in the original order.
        """

        def _pad(batch):
            """Helper function to pad nested data within each batch element."""
            if isinstance(batch, dict):
                return {key: _pad(value) for key, value in batch.items()}

            rank = len(batch.shape)
            assert rank > 0
            missing_count = self.padded_batch_size - self.get_real_batch_size(
                batch
            )
            # Pad only the end of the batch axis; leave other axes untouched.
            padding = backend.stack(
                [[0, missing_count]] + [[0, 0]] * (rank - 1)
            )
            return tf.pad(batch, padding, "constant")

        if len(dataset_batch_elements) == 1:
            return _pad(dataset_batch_elements[0])

        return tuple(_pad(element) for element in dataset_batch_elements)

    @staticmethod
    def _remove_padding(prediction, padding_mask):
        """Returns the rows of `prediction` whose mask entry is nonzero.

        Only the first `len(prediction)` entries of the 1-D `padding_mask`
        are consulted. Rows are selected along axis 0 with 1-D indices, so
        no extra axis is introduced. All other dimensions, including any of
        size 1, are preserved.
        """
        keep = np.flatnonzero(padding_mask[: len(prediction)])
        return np.take(prediction, keep, axis=0)

    def apply_mask(self, prediction_result):
        """Removes prediction output that corresponds to padded input.

        Args:
          prediction_result: For a single-output model, the prediction array.
            Otherwise an indexable of per-output arrays, one per entry in
            `self.output_shape`.

        Returns:
          The filtered prediction array (single output) or a list of filtered
          arrays (multiple outputs). Only the batch axis is filtered.
        """
        padding_mask = backend.get_value(self.padding_mask)
        assert len(padding_mask.shape) == 1

        if len(self.output_shape) == 1:
            return self._remove_padding(prediction_result, padding_mask)

        # Filter only the batch axis of each output. An unqualified squeeze
        # here would also drop size-1 feature dims, and the batch dim itself
        # when just one real row is left.
        return [
            self._remove_padding(prediction_result[i], padding_mask)
            for i in range(len(self.output_shape))
        ]
