# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras utilities for DTensor unit test."""

import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized

# isort: off
from tensorflow.dtensor.python import api as dtensor_api
from tensorflow.python.eager import context

_DEFAULT_GPU_MEMORY_LIMIT = 200  # MB


class DTensorBaseTest(tf.test.TestCase, parameterized.TestCase):
    """Provides comparison helper for dtensor vs local results."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

    def tearDown(self):
        super().tearDown()
        # Make sure all async ops finish.
        context.async_wait()

        # TODO(hthu): Remove the reset once the CopyToMesh with DefaultMesh
        # placement issue is fixed.
        reset_dtensor()

    @staticmethod
    def configTestMesh(device_type_mesh_map):
        """Configs corresponding mesh given test context.

        If runs on a CPU mesh, set virtual device on CPU.
        If runs on a GPU mesh, sets virtual device on GPU with proper memory
        limits.
        if runs on a TPU mesh, initializes TPU system.

        Args:
          device_type_mesh_map: A dictionary containing device_type -> mesh
            mapping.

        Returns:
          A properly configured mesh for use in the test.
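
        Example (a minimal sketch; assumes a 1-D, two-device CPU mesh built
        with the helpers in this module and the public
        `tf.experimental.dtensor.Mesh` constructor):

          device_ids = create_device_ids_array((2,))
          mesh = tf.experimental.dtensor.Mesh(
              ["batch"],
              device_ids,
              np.ravel(device_ids).tolist(),
              create_device_list((2,), "CPU"),
          )
          configured_mesh = DTensorBaseTest.configTestMesh({"CPU": mesh})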
        """
        reset_context()

        def get_mesh(device_type):
            mesh = device_type_mesh_map.get(device_type, None)
            if mesh is None:
                raise ValueError(
                    "Requires a %s mesh to run test on %s."
                    % (device_type, device_type)
                )
            return mesh

        mesh = None
        if tf.config.list_physical_devices("GPU"):
            mesh = get_mesh("GPU")
            reset_logical_devices("GPU", np.prod(mesh.shape()))
        else:
            mesh = get_mesh("CPU")
            reset_logical_devices("CPU", np.prod(mesh.shape()))

        context.ensure_initialized()
        return mesh


def create_device_array(shape, device_type):
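    """Returns a numpy array of `tf.DeviceSpec` with the given mesh shape."""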
    device_count = np.prod(shape)
    return np.asarray(
        [
            tf.DeviceSpec(
                job="localhost/replica:0/task:0",
                device_type=device_type,
                device_index=i,
            )
            for i in range(device_count)
        ]
    ).reshape(shape)


def create_device_list(shape, device_type):
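    """Returns a flat list of `tf.DeviceSpec` for the given mesh shape."""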
    devices = create_device_array(shape, device_type)
    return np.ravel(devices).tolist()


def create_device_ids_array(shape):
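    """Returns device ids 0..N-1 as a numpy array reshaped to `shape`."""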
    device_count = np.prod(shape)
    return np.arange(device_count).reshape(shape)


def reset_context():
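    """Resets the TF eager context (relies on a private, test-only API)."""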
    context._reset_context()


def reset_logical_devices(device_type, count):
    """Resets logical devices for CPU/GPU.

    Logical devices can only be instantiated once on a particular context. For
    now, reusing a context triggers function duplication errors, so the
    context is reset on each call.

    Args:
      device_type: The device type to reset, either "CPU" or "GPU".
      count: The number of virtual devices to configure.
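
    Example (a minimal sketch):
      reset_logical_devices("CPU", 4)  # Configures 4 virtual CPU devices.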
    """
    reset_context()
    devices = tf.config.list_physical_devices(device_type)
    if device_type.upper() == "CPU":
        tf.config.set_logical_device_configuration(
            devices[0],
            [
                tf.config.LogicalDeviceConfiguration(),
            ]
            * count,
        )
    elif device_type.upper() == "GPU":
        tf.config.set_logical_device_configuration(
            devices[0],
            [
                tf.config.LogicalDeviceConfiguration(
                    memory_limit=_DEFAULT_GPU_MEMORY_LIMIT
                ),
            ]
            * count,
        )
    else:
        raise ValueError(
            "Cannot reset logical devices for unsupported device type: %s"
            % device_type
        )


def reset_dtensor():
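    """Clears global DTensor state between tests (uses a private API)."""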
    dtensor_api._reset()
