# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""Callbacks: utilities called at certain points during model training."""

import os

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend
from keras import callbacks

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=["keras.callbacks.TensorBoard"])
class TensorBoard(callbacks.TensorBoard):

    """Enable visualizations for TensorBoard.

    TensorBoard is a visualization tool provided with TensorFlow.

    This callback logs events for TensorBoard, including:
    * Metrics summary plots
    * Training graph visualization
    * Activation histograms
    * Sampled profiling

    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:

    ```sh
    tensorboard --logdir=path_to_your_logs
    ```

    You can find more information about TensorBoard
    [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

    Args:
        log_dir: the path of the directory where to save the log files to be
          parsed by TensorBoard.
        histogram_freq: frequency (in epochs) at which to compute activation and
          weight histograms for the layers of the model. If set to 0, histograms
          won't be computed. Validation data (or split) must be specified for
          histogram visualizations.
        write_graph: whether to visualize the graph in TensorBoard. The log file
          can become quite large when write_graph is set to True.
        write_grads: whether to visualize gradient histograms in TensorBoard.
          `histogram_freq` must be greater than 0.
        batch_size: size of batch of inputs to feed to the network for
          histograms computation.
        write_images: whether to write model weights to visualize as image in
          TensorBoard.
        embeddings_freq: frequency (in epochs) at which selected embedding
          layers will be saved. If set to 0, embeddings won't be computed. Data
          to be visualized in TensorBoard's Embedding tab must be passed as
          `embeddings_data`.
        embeddings_layer_names: a list of names of layers to keep an eye on.
          If None or an empty list, all the embedding layers will be watched.
        embeddings_metadata: a dictionary which maps layer name to a file name
          in which metadata for this embedding layer is saved.
            [Here are details](
              https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
              about the metadata file format. If the same metadata file is
              used for all embedding layers, a string can be passed.
        embeddings_data: data to be embedded at layers specified in
          `embeddings_layer_names`. Numpy array (if the model has a single
          input) or list of Numpy arrays (if the model has multiple inputs).
          Learn more about embeddings [in this guide](
          https://www.tensorflow.org/programmers_guide/embedding).
        update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
          writes the losses and metrics to TensorBoard after each batch. The
          same applies for `'epoch'`. If using an integer, let's say `1000`, the
          callback will write the metrics and losses to TensorBoard every 1000
          samples. Note that writing too frequently to TensorBoard can slow down
          your training.
        profile_batch: Profile the batch to sample compute characteristics. By
          default, it will profile the second batch. Set profile_batch=0 to
          disable profiling.

    Raises:
        ValueError: If histogram_freq is set and no validation data is provided.

    @compatibility(eager)
    Using the `TensorBoard` callback will work when eager execution is enabled,
    with the restriction that outputting histogram summaries of weights and
    gradients is not supported. Consequently, `histogram_freq` will be ignored.
    @end_compatibility
    """

    def __init__(
        self,
        log_dir="./logs",
        histogram_freq=0,
        batch_size=32,
        write_graph=True,
        write_grads=False,
        write_images=False,
        embeddings_freq=0,
        embeddings_layer_names=None,
        embeddings_metadata=None,
        embeddings_data=None,
        update_freq="epoch",
        profile_batch=2,
    ):
        """Stores the callback configuration and resets its counters.

        The arguments are described in the class docstring. Only attributes
        are set here; the summary writer itself is created later by
        `_init_writer`.

        Under eager execution a non-zero `histogram_freq` is reset to 0 and a
        warning is logged. `update_freq="batch"` is stored as `1`; any other
        value is stored unchanged.
        """
        # Don't call super's init since it is an eager-only version.
        callbacks.Callback.__init__(self)
        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        if self.histogram_freq and tf.executing_eagerly():
            # The two literals are concatenated, so the first one must end
            # with a space to keep "eager execution" as two words.
            logging.warning(
                "Weight and gradient histograms not supported for eager "
                "execution, setting `histogram_freq` to `0`."
            )
            self.histogram_freq = 0
        # Merged histogram summary op; built in `set_model` (graph mode only).
        self.merged = None
        self.write_graph = write_graph
        self.write_grads = write_grads
        self.write_images = write_images
        self.batch_size = batch_size
        self._current_batch = 0
        self._total_batches_seen = 0
        self._total_val_batches_seen = 0
        self.embeddings_freq = embeddings_freq
        self.embeddings_layer_names = embeddings_layer_names
        self.embeddings_metadata = embeddings_metadata
        self.embeddings_data = embeddings_data
        if update_freq == "batch":
            self.update_freq = 1
        else:
            self.update_freq = update_freq
        # Sample counters used by `on_batch_end` to decide when to write.
        self._samples_seen = 0
        self._samples_seen_at_last_write = 0
        # TODO(fishx): Add a link to the full profiler tutorial.
        self._profile_batch = profile_batch
        # True when the profiler was successfully started by this callback.
        # We track the status here to make sure callbacks do not interfere with
        # each other. The callback will only stop the profiler it started.
        self._profiler_started = False

        # TensorBoard should only write summaries on the chief when in a
        # Multi-Worker setting.
        self._chief_worker_only = True

    def _init_writer(self, model):
        """Creates `self.writer`, the summary writer for `self.log_dir`.

        Under eager execution a V2 file writer is created and, if the model
        is not run eagerly and `write_graph` is set, the Keras backend graph
        is written to it. Otherwise a V1 `FileWriter` is created, which is
        handed the backend graph only when `write_graph` is set.

        Args:
            model: the Keras model; only its `run_eagerly` flag is read.
        """
        if not tf.executing_eagerly():
            writer_args = [self.log_dir]
            if self.write_graph:
                writer_args.append(backend.get_graph())
            self.writer = tf.compat.v1.summary.FileWriter(*writer_args)
            return

        self.writer = tf.summary.create_file_writer(self.log_dir)
        if model.run_eagerly or not self.write_graph:
            return
        with self.writer.as_default():
            tf.summary.graph(backend.get_graph())

    def _make_histogram_ops(self, model):
        """Defines histogram ops when histogram_freq > 0.

        For every layer of `self.model` this registers V1 summaries for its
        weights (histograms and, with `write_images`, images), for the
        gradients of its trainable weights (with `write_grads`) and for its
        output(s). Nothing is done when `histogram_freq` is falsy or when the
        merged summary op already exists.

        Args:
            model: the Keras model; its `optimizer` and `total_loss` are used
              to build the gradient summaries.
        """
        # Only build the summary ops once.
        if not self.histogram_freq or self.merged is not None:
            return

        def dense_values(grad):
            # IndexedSlices gradients are summarized through `.values`.
            if type(grad).__name__ == "IndexedSlices":
                return grad.values
            return grad

        for layer in self.model.layers:
            for weight in layer.weights:
                tag = weight.name.replace(":", "_")
                tf.compat.v1.summary.histogram(tag, weight)
                if not self.write_images:
                    continue

                image = tf.compat.v1.squeeze(weight)
                dims = tuple(image.shape)
                rank = len(dims)
                if rank == 2:
                    # Dense layer kernel: lay it out with the longer side
                    # as the second axis.
                    if dims[0] > dims[1]:
                        image = tf.compat.v1.transpose(image)
                        dims = tuple(image.shape)
                    image = tf.reshape(image, [1, dims[0], dims[1], 1])
                elif rank == 3:
                    # Convnet case.
                    if backend.image_data_format() == "channels_last":
                        # Switch to channels_first to display every kernel
                        # as a separate image.
                        image = tf.compat.v1.transpose(image, perm=[2, 0, 1])
                        dims = tuple(image.shape)
                    image = tf.reshape(image, [dims[0], dims[1], dims[2], 1])
                elif rank == 1:
                    # Bias case.
                    image = tf.reshape(image, [1, dims[0], 1, 1])
                else:
                    # Not possible to handle 3D convnets etc.
                    continue

                dims = tuple(image.shape)
                assert len(dims) == 4 and dims[-1] in [1, 3, 4]
                tf.compat.v1.summary.image(tag, image)

            if self.write_grads:
                for weight in layer.trainable_weights:
                    tag = weight.name.replace(":", "_")
                    raw_grads = model.optimizer.get_gradients(
                        model.total_loss, weight
                    )
                    grads = [dense_values(grad) for grad in raw_grads]
                    tf.compat.v1.summary.histogram(f"{tag}_grad", grads)

            if not hasattr(layer, "output"):
                continue
            if isinstance(layer.output, list):
                for i, output in enumerate(layer.output):
                    tf.compat.v1.summary.histogram(
                        f"{layer.name}_out_{i}", output
                    )
            else:
                tf.compat.v1.summary.histogram(
                    f"{layer.name}_out", layer.output
                )

    def set_model(self, model):
        """Sets Keras model and creates summary ops.

        Besides storing `model`, this creates the summary writer, builds the
        merged histogram summary op (graph mode only) and, when both
        `embeddings_freq` and `embeddings_data` are set, builds the
        assignment ops, saver and projector config used to visualize
        embeddings.

        Args:
            model: the Keras model this callback is attached to.

        Raises:
            ImportError: if embeddings are requested but
              `tensorboard.plugins.projector` cannot be imported.
        """

        self.model = model
        self._init_writer(model)
        # histogram summaries only enabled in graph mode
        if not tf.executing_eagerly():
            self._make_histogram_ops(model)
            self.merged = tf.compat.v1.summary.merge_all()

        # If both embedding_freq and embeddings_data are available, we will
        # visualize embeddings.
        if self.embeddings_freq and self.embeddings_data is not None:
            # Avoid circular dependency.
            from keras.engine import (
                training_utils_v1,
            )

            self.embeddings_data = training_utils_v1.standardize_input_data(
                self.embeddings_data, model.input_names
            )

            # If embedding_layer_names are not provided, get all of the
            # embedding layers from the model.
            embeddings_layer_names = self.embeddings_layer_names
            if not embeddings_layer_names:
                embeddings_layer_names = [
                    layer.name
                    for layer in self.model.layers
                    if type(layer).__name__ == "Embedding"
                ]

            self.assign_embeddings = []
            embeddings_vars = {}

            # Fed per batch in `on_epoch_end`: the offset of the batch within
            # the embedding variable and the number of rows in the batch.
            self.batch_id = batch_id = tf.compat.v1.placeholder(tf.int32)
            self.step = step = tf.compat.v1.placeholder(tf.int32)

            for layer in self.model.layers:
                if layer.name in embeddings_layer_names:
                    embedding_input = self.model.get_layer(layer.name).output
                    # Flatten everything but the first axis into one axis.
                    embedding_size = np.prod(embedding_input.shape[1:])
                    embedding_input = tf.reshape(
                        embedding_input, (step, int(embedding_size))
                    )
                    shape = (
                        self.embeddings_data[0].shape[0],
                        int(embedding_size),
                    )
                    embedding = tf.Variable(
                        tf.zeros(shape), name=layer.name + "_embedding"
                    )
                    embeddings_vars[layer.name] = embedding
                    # Op that writes one batch of layer outputs into the
                    # matching rows of the embedding variable.
                    batch = tf.compat.v1.assign(
                        embedding[batch_id : batch_id + step], embedding_input
                    )
                    self.assign_embeddings.append(batch)

            self.saver = tf.compat.v1.train.Saver(
                list(embeddings_vars.values())
            )

            # Create embeddings_metadata dictionary
            if isinstance(self.embeddings_metadata, str):
                # A single file name applies to every embedding layer.
                embeddings_metadata = {
                    layer_name: self.embeddings_metadata
                    for layer_name in embeddings_vars.keys()
                }
            else:
                # If embedding_metadata is already a dictionary
                embeddings_metadata = self.embeddings_metadata

            try:
                # isort: off
                from tensorboard.plugins import projector
            except ImportError as err:
                # Chain the original failure so its cause stays visible.
                raise ImportError(
                    "Failed to import TensorBoard. Please make sure that "
                    "TensorBoard integration is complete."
                ) from err

            # TODO(psv): Add integration tests to test embedding visualization
            # with TensorBoard callback. We are unable to write a unit test for
            # this because TensorBoard dependency assumes TensorFlow package is
            # installed.
            config = projector.ProjectorConfig()
            for layer_name, tensor in embeddings_vars.items():
                embedding = config.embeddings.add()
                embedding.tensor_name = tensor.name

                if (
                    embeddings_metadata is not None
                    and layer_name in embeddings_metadata
                ):
                    embedding.metadata_path = embeddings_metadata[layer_name]

            projector.visualize_embeddings(self.writer, config)

    def _fetch_callback(self, summary):
        """Writes a fetched summary and advances the validation batch count.

        Args:
            summary: the summary to add to the writer; it is recorded at the
              current value of `_total_val_batches_seen`.
        """
        val_step = self._total_val_batches_seen
        self.writer.add_summary(summary, val_step)
        self._total_val_batches_seen = val_step + 1

    def _write_custom_summaries(self, step, logs=None):
        """Writes metrics out as custom scalar summaries.

        V2 summary ops are used under eager execution; otherwise each value
        is wrapped in a V1 `Summary` proto and added to the `FileWriter`.
        The writer is flushed afterwards in both cases.

        Args:
            step: the global step to use for TensorBoard.
            logs: dict. Keys are scalar summary names, values are
                NumPy scalars.

        """

        def to_scalar(raw):
            # NumPy arrays are converted with `.item()`; anything else is
            # passed through unchanged.
            return raw.item() if isinstance(raw, np.ndarray) else raw

        entries = logs or {}
        if not tf.executing_eagerly():
            # use FileWriter from v1 summary
            for tag, raw in entries.items():
                scalar = to_scalar(raw)
                proto = tf.compat.v1.Summary()
                slot = proto.value.add()
                slot.simple_value = scalar
                slot.tag = tag
                self.writer.add_summary(proto, step)
        else:
            # use v2 summary ops
            with self.writer.as_default(), tf.summary.record_if(True):
                for tag, raw in entries.items():
                    tf.summary.scalar(tag, to_scalar(raw), step=step)
        self.writer.flush()

    def on_train_batch_begin(self, batch, logs=None):
        """Starts the profiler right before the batch selected for profiling.

        `_profile_batch` is compared against the running batch total plus
        one, so `profile_batch=2` starts the profiler once one batch has
        been seen.
        """
        last_unprofiled_total = self._profile_batch - 1
        if self._total_batches_seen != last_unprofiled_total:
            return
        self._start_profiler()

    def on_train_batch_end(self, batch, logs=None):
        """Forwards the end-of-training-batch event to `on_batch_end`."""
        outcome = self.on_batch_end(batch, logs)
        return outcome

    def on_test_begin(self, logs=None):
        """No-op override of the inherited hook: nothing is done here."""

    def on_test_end(self, logs=None):
        """No-op override of the inherited hook: nothing is done here."""

    def on_batch_end(self, batch, logs=None):
        """Writes scalar summaries for metrics on every training batch.

        The sample counter grows by `logs["size"]` (1 when absent). Unless
        `update_freq` is `"epoch"`, the batch metrics are written, prefixed
        with `batch_`, once at least `update_freq` samples have been seen
        since the last write. The profiler, if started by this callback, is
        stopped at the end of every batch.
        """
        logs = logs or {}
        self._samples_seen += logs.get("size", 1)
        pending_samples = self._samples_seen - self._samples_seen_at_last_write

        due = self.update_freq != "epoch" and pending_samples >= self.update_freq
        if due:
            # Don't output batch_size and batch number as TensorBoard
            # summaries.
            skipped = ("batch", "size", "num_steps")
            batch_logs = {}
            for key, value in logs.items():
                if key not in skipped:
                    batch_logs["batch_" + key] = value
            self._write_custom_summaries(self._total_batches_seen, batch_logs)
            self._samples_seen_at_last_write = self._samples_seen

        self._total_batches_seen += 1
        self._stop_profiler()

    def on_train_begin(self, logs=None):
        """No-op override of the inherited hook: nothing is done here."""

    def on_epoch_begin(self, epoch, logs=None):
        """Attaches the histogram summary op to the model's test function.

        Only acts when `histogram_freq` is set and `epoch` is a multiple of
        it. The merged summary op is then appended to the test function's
        fetches, with `_fetch_callback` registered to receive its result.
        """
        # Skip epochs for which no histogram summary is due.
        if not self.histogram_freq or epoch % self.histogram_freq != 0:
            return

        self.model._make_test_function()
        test_fn = self.model.test_function
        # Register the op only once.
        if self.merged not in test_fn.fetches:
            test_fn.fetches.append(self.merged)
            test_fn.fetch_callbacks[self.merged] = self._fetch_callback

    def on_epoch_end(self, epoch, logs=None):
        """Logs epoch-level scalar summaries and refreshes embeddings.

        The metrics in `logs` are written with an `epoch_` prefix. The step
        is `epoch` when `update_freq` is `"epoch"`, otherwise the number of
        samples seen so far. The histogram summary op is then detached from
        the model's test function and, every `embeddings_freq` epochs, the
        embedding variables are recomputed from `embeddings_data` and saved
        to a checkpoint in `log_dir`.

        Args:
            epoch: index of the epoch that just finished.
            logs: dict of metric results for this epoch; `None` is treated
              as an empty dict.

        Raises:
            ValueError: if `embeddings_freq` is set but `embeddings_data`
              was not provided.
        """
        # `logs` defaults to None, so guard before iterating over it.
        logs = logs or {}

        # don't output batch_size and
        # batch number as TensorBoard summaries
        epoch_logs = {
            ("epoch_" + k): v
            for k, v in logs.items()
            if k not in ["batch", "size", "num_steps"]
        }
        if self.update_freq == "epoch":
            step = epoch
        else:
            step = self._samples_seen
        self._write_custom_summaries(step, epoch_logs)

        # pop the histogram summary op after each epoch
        if self.histogram_freq:

            if self.merged in self.model.test_function.fetches:
                self.model.test_function.fetches.remove(self.merged)
            if self.merged in self.model.test_function.fetch_callbacks:
                self.model.test_function.fetch_callbacks.pop(self.merged)

        if self.embeddings_data is None and self.embeddings_freq:
            raise ValueError(
                "To visualize embeddings, embeddings_data must " "be provided."
            )

        if self.embeddings_freq and self.embeddings_data is not None:
            if epoch % self.embeddings_freq == 0:
                # We need a second forward-pass here because we're passing
                # the `embeddings_data` explicitly. This design allows to pass
                # arbitrary data as `embeddings_data` and results from the fact
                # that we need to know the size of the `tf.Variable`s which
                # hold the embeddings in `set_model`. At this point, however,
                # the `validation_data` is not yet set.

                embeddings_data = self.embeddings_data
                n_samples = embeddings_data[0].shape[0]
                i = 0
                sess = backend.get_session()
                while i < n_samples:
                    # The last batch may be shorter than `batch_size`.
                    batch_len = min(self.batch_size, n_samples - i)
                    batch = slice(i, i + batch_len)

                    if isinstance(self.model.input, list):
                        feed_dict = {
                            model_input: embeddings_data[idx][batch]
                            for idx, model_input in enumerate(self.model.input)
                        }
                    else:
                        feed_dict = {
                            self.model.input: embeddings_data[0][batch]
                        }

                    feed_dict.update({self.batch_id: i, self.step: batch_len})

                    # Only feed the learning phase when it is not a plain
                    # int.
                    learning_phase = backend.learning_phase()
                    if not isinstance(learning_phase, int):
                        feed_dict[learning_phase] = False

                    sess.run(self.assign_embeddings, feed_dict=feed_dict)

                    i += self.batch_size

                # Save once, after every batch has been assigned. Saving
                # inside the loop rewrote the same checkpoint (same prefix
                # and global step) for each batch.
                if n_samples > 0:
                    self.saver.save(
                        sess,
                        os.path.join(self.log_dir, "keras_embedding.ckpt"),
                        epoch,
                    )

    def on_train_end(self, logs=None):
        """Stops the profiler if this callback started it; closes the writer.

        Args:
            logs: unused.
        """
        try:
            self._stop_profiler()
        finally:
            # Close the writer even if stopping the profiler raises an error
            # that `_stop_profiler` does not handle.
            self.writer.close()

    def _start_profiler(self):
        """Starts the profiler, unless this callback already started it.

        An `AlreadyExistsError` from the profiler is logged rather than
        raised, and in that case the callback does not mark the profiler as
        started.
        """
        if not self._profiler_started:
            try:
                tf.profiler.experimental.start(logdir=self.log_dir)
            except tf.errors.AlreadyExistsError as err:
                # Profiler errors should not be fatal.
                logging.error("Failed to start profiler: %s", err.message)
            else:
                self._profiler_started = True

    def _stop_profiler(self):
        """Stops the profiler, but only if this callback started it.

        An `UnavailableError` from the profiler is logged rather than raised.
        The started flag is cleared whether or not stopping succeeded.
        """
        if self._profiler_started:
            try:
                tf.profiler.experimental.stop()
            except tf.errors.UnavailableError as err:
                # Profiler errors should not be fatal.
                logging.error("Failed to stop profiler: %s", err.message)
            finally:
                self._profiler_started = False
