# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests utils for preprocessing layers."""

import collections

import numpy as np
import tensorflow.compat.v2 as tf


class ArrayLike:
    """Minimal wrapper exposing `values` through NumPy's `__array__` protocol.

    Useful for exercising code paths that accept "array-like" inputs that are
    neither lists nor `np.ndarray` instances.
    """

    def __init__(self, values):
        # Stored as given; conversion to an array is deferred to `__array__`.
        self.values = values

    def __array__(self, dtype=None, copy=None):
        """Return `self.values` converted to a NumPy array.

        Args:
          dtype: Optional dtype for the returned array. NumPy passes this
            positionally when the caller requests a dtype, e.g.
            `np.asarray(obj, dtype=...)`.
          copy: Accepted for compatibility with the NumPy 2 `__array__`
            signature. The array is always built with `np.array`, so the
            value of this flag does not change the result.

        Returns:
          An `np.ndarray` built from `self.values`.
        """
        return np.array(self.values, dtype=dtype)


class PreprocessingLayerTest(tf.test.TestCase):
    """Base test class for preprocessing layer API validation."""

    # TODO(b/137303934): Consider incorporating something like this Close vs All
    # behavior into core tf.test.TestCase.

    def assertAllCloseOrEqual(self, a, b, msg=None):
        """Asserts that elements are close (if numeric) or equal (if string).

        The comparison strategy is picked from `a`:
          * Either side is `None`: `assertAllEqual`.
          * `a` is a `list`/`tuple`: lengths must match, then each pair of
            elements is compared recursively.
          * `a` is a `Mapping`: lengths must match and every key of `a` must
            be present in `b`; values are compared recursively with the key
            appended to `msg`.
          * `a` is a `float`, or has a `dtype` that is a NumPy numeric type:
            `assertAllClose`.
          * Anything else: `assertAllEqual`.

        Args:
          a: Reference value; its type drives the dispatch above.
          b: Value compared against `a`.
          msg: Optional message attached to assertion failures.
        """
        if a is None or b is None:
            self.assertAllEqual(a, b, msg=msg)
        elif isinstance(a, (list, tuple)):
            # Check the length first: `zip` would silently stop at the
            # shorter sequence.
            self.assertEqual(len(a), len(b), msg=msg)
            for a_value, b_value in zip(a, b):
                self.assertAllCloseOrEqual(a_value, b_value, msg=msg)
        elif isinstance(a, collections.abc.Mapping):
            self.assertEqual(len(a), len(b), msg=msg)
            for key, a_value in a.items():
                try:
                    b_value = b[key]
                except KeyError:
                    # Report a test failure rather than leaking a bare
                    # KeyError when the two mappings have different keys.
                    missing = f"Key {key!r} of `a` is missing from `b`."
                    self.fail(f"{msg} : {missing}" if msg else missing)
                error_message = f"{msg} ({key})" if msg else None
                self.assertAllCloseOrEqual(a_value, b_value, error_message)
        elif isinstance(a, float) or (
            hasattr(a, "dtype") and np.issubdtype(a.dtype, np.number)
        ):
            self.assertAllClose(a, b, msg=msg)
        else:
            self.assertAllEqual(a, b, msg=msg)

    def assert_extracted_output_equal(self, combiner, acc1, acc2, msg=None):
        """Asserts that two accumulators extract to matching outputs.

        Args:
          combiner: Object whose `extract` method is applied to each
            accumulator.
          acc1: First accumulator.
          acc2: Second accumulator.
          msg: Optional message forwarded to the comparison.
        """
        extracted_first, extracted_second = [
            combiner.extract(accumulator) for accumulator in (acc1, acc2)
        ]
        self.assertAllCloseOrEqual(extracted_first, extracted_second, msg=msg)

    # This is an injection seam so that tests like TextVectorizationTest can
    # define their own methods for asserting that accumulators are equal.
    # The validators in this class compare accumulators through
    # `self.compare_accumulators`, so a subclass customizes them by
    # overriding this attribute. The alias is bound to the function object
    # when this class body executes, so overriding only
    # `assertAllCloseOrEqual` in a subclass does not change it.
    compare_accumulators = assertAllCloseOrEqual

    def validate_accumulator_computation(self, combiner, data, expected):
        """Validate that all compute/merge decompositions of `data` agree.

        `data` is split into three shards (element 0, element 1, and the
        remaining elements). The accumulator obtained by merging the three
        per-shard accumulators is then compared, via
        `self.compare_accumulators`, against several other ways of combining
        the same shards and finally against `expected`.

        Args:
          combiner: Object providing `compute(data[, accumulator])` and
            `merge(accumulators)`.
          data: Indexable input with at least 4 elements.
          expected: Accumulator the merged result must match.

        Raises:
          AssertionError: If `data` has fewer than 4 elements.
        """
        num_elements = len(data)
        if num_elements < 4:
            raise AssertionError(
                "Data must have at least 4 elements. Received "
                f"len(data)={num_elements}."
            )

        head = np.array([data[0]])
        second = np.array([data[1]])
        tail = np.array(data[2:])

        whole = combiner.compute(data)

        # Baseline that every other decomposition is checked against.
        all_merge = combiner.merge(
            [combiner.compute(shard) for shard in (head, second, tail)]
        )
        self.compare_accumulators(
            whole,
            all_merge,
            msg="Sharding data should not change the data output.",
        )

        def check_against_all_merge(candidate, reason):
            self.compare_accumulators(all_merge, candidate, msg=reason)

        # Same shards, merged in a different order.
        unordered_all_merge = combiner.merge(
            [combiner.compute(shard) for shard in (second, tail, head)]
        )
        check_against_all_merge(
            unordered_all_merge,
            "The order of merge arguments should not change the data output.",
        )

        # A merge result used as an input to another merge.
        second_acc = combiner.compute(second)
        tail_and_head = combiner.merge(
            [combiner.compute(tail), combiner.compute(head)]
        )
        hierarchical_merge = combiner.merge([second_acc, tail_and_head])
        check_against_all_merge(
            hierarchical_merge,
            "Nesting merge arguments should not change the data output.",
        )

        # Accumulators threaded through successive compute calls.
        tail_acc = combiner.compute(tail)
        second_onto_tail = combiner.compute(second, tail_acc)
        nested_compute = combiner.compute(head, second_onto_tail)
        check_against_all_merge(
            nested_compute,
            "Nesting compute arguments should not change the data output.",
        )

        # One plain compute merged with a chained compute.
        head_acc = combiner.compute(head)
        second_onto_tail = combiner.compute(second, combiner.compute(tail))
        mixed_compute = combiner.merge([head_acc, second_onto_tail])
        check_against_all_merge(
            mixed_compute,
            "Mixing merge and compute calls should not change the "
            "data output.",
        )

        # A merge over a single accumulator, merged with a chained compute.
        lone_head_merge = combiner.merge([combiner.compute(head)])
        second_onto_tail = combiner.compute(second, combiner.compute(tail))
        single_merge = combiner.merge([lone_head_merge, second_onto_tail])
        check_against_all_merge(
            single_merge,
            "Calling merge with a data length of 1 should not "
            "change the data output.",
        )

        self.compare_accumulators(
            expected,
            all_merge,
            msg="Calculated accumulators did not match expected accumulator.",
        )

    def validate_accumulator_extract(self, combiner, data, expected):
        """Validate that computing and then extracting yields `expected`.

        Args:
          combiner: Object providing `compute` and `extract`.
          data: Input passed to `combiner.compute`.
          expected: Value the extracted output must match.
        """
        actual = combiner.extract(combiner.compute(data))
        self.assertAllCloseOrEqual(expected, actual)

    def validate_accumulator_extract_and_restore(
        self, combiner, data, expected
    ):
        """Validate that an extract -> restore round trip loses no data.

        Args:
          combiner: Object providing `compute`, `extract` and `restore`.
          data: Input passed to `combiner.compute`.
          expected: Value the output extracted from the restored accumulator
            must match.
        """
        original_acc = combiner.compute(data)
        round_tripped_acc = combiner.restore(combiner.extract(original_acc))
        # The original and restored accumulators must extract to the same
        # output.
        self.assert_extracted_output_equal(
            combiner, original_acc, round_tripped_acc
        )
        round_tripped_output = combiner.extract(round_tripped_acc)
        self.assertAllCloseOrEqual(expected, round_tripped_output)

    def validate_accumulator_serialize_and_deserialize(
        self, combiner, data, expected
    ):
        """Validate that a serialize -> deserialize round trip loses no data.

        Args:
          combiner: Object providing `compute`, `serialize` and
            `deserialize`.
          data: Input passed to `combiner.compute`.
          expected: Accumulator the deserialized result must match.
        """
        accumulator = combiner.compute(data)
        round_tripped = combiner.deserialize(combiner.serialize(accumulator))
        # Compare against the in-memory accumulator first, then the
        # caller-supplied expectation.
        for reference in (accumulator, expected):
            self.compare_accumulators(reference, round_tripped)

    def validate_accumulator_uniqueness(self, combiner, data):
        """Validate that each compute call returns a distinct accumulator.

        Two accumulators computed from the same `data` must be different
        objects that still compare as equal under `compare_accumulators`.

        Args:
          combiner: Object providing `compute`.
          data: Input passed to `combiner.compute` twice.
        """
        first, second = combiner.compute(data), combiner.compute(data)
        self.assertIsNot(first, second)
        self.compare_accumulators(first, second)
