# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-import-not-at-top
"""Utilities related to model visualization."""

import os
import sys
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


try:
  # pydot-ng is a fork of pydot that is better maintained.
  import pydot_ng as pydot
except ImportError:
  # pydotplus is an improved version of pydot
  try:
    import pydotplus as pydot
  except ImportError:
    # Fall back on pydot if necessary.
    try:
      import pydot
    except ImportError:
      pydot = None


def check_pydot():
  """Reports whether pydot is importable and can render a graph.

  Returns:
    `False` when no pydot flavour could be imported, or when rendering an
    empty graph raises `OSError` or `pydot.InvocationException`; `True`
    when the empty graph renders without error.
  """
  if pydot is not None:
    try:
      # Rendering an empty graph exercises the whole pydot -> graphviz
      # toolchain, so a broken installation surfaces here.
      pydot.Dot.create(pydot.Dot())
      return True
    except (OSError, pydot.InvocationException):
      return False
  return False


def is_wrapped_model(layer):
  """Tells whether `layer` is a `Wrapper` around a `Functional` model.

  Args:
    layer: the object to inspect.

  Returns:
    `True` if `layer` is a `wrappers.Wrapper` instance whose `layer`
    attribute is a `functional.Functional` instance, `False` otherwise.
  """
  # Imported lazily inside the function, as elsewhere in this file.
  from tensorflow.python.keras.engine import functional
  from tensorflow.python.keras.layers import wrappers
  if not isinstance(layer, wrappers.Wrapper):
    return False
  return isinstance(layer.layer, functional.Functional)


def add_edge(dot, src, dst):
  """Adds a `src` -> `dst` edge to `dot` unless one is already present.

  Args:
    dot: graph object exposing `get_edge(src, dst)` and `add_edge(edge)`.
    src: name of the source node.
    dst: name of the destination node.
  """
  if dot.get_edge(src, dst):
    # An edge between these two nodes already exists; do not duplicate it.
    return
  dot.add_edge(pydot.Edge(src, dst))


@keras_export('keras.utils.model_to_dot')
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False):
  """Convert a Keras model to dot format.

  Args:
    model: A Keras model instance.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: whether to expand nested models into clusters.
    dpi: Dots per inch. Only applied to the top-level graph
        (i.e. when `subgraph=False`).
    subgraph: whether to return a `pydot.Cluster` instance.

  Returns:
    A `pydot.Dot` instance representing the Keras model or
    a `pydot.Cluster` instance representing nested model if
    `subgraph=True`. Returns `None` (after printing the installation hint)
    when pydot/graphviz are unavailable and the module
    `IPython.core.magics.namespace` has been loaded.

  Raises:
    ImportError: if graphviz or pydot are not available (and the IPython
      module mentioned above is not loaded).
  """
  from tensorflow.python.keras.layers import wrappers
  from tensorflow.python.keras.engine import sequential
  from tensorflow.python.keras.engine import functional

  if not check_pydot():
    # Adjacent string literals are concatenated into ONE string. There must
    # be no comma between them, otherwise `message` becomes a tuple.
    message = (
        'You must install pydot (`pip install pydot`) '
        'and install graphviz '
        '(see instructions at https://graphviz.gitlab.io/download/) '
        'for plot_model/model_to_dot to work.')
    if 'IPython.core.magics.namespace' in sys.modules:
      # We don't raise an exception here in order to avoid crashing notebook
      # tests where graphviz is not available.
      print(message)
      return
    else:
      raise ImportError(message)

  def format_dtype(dtype):
    """Renders a layer dtype for a node label ('?' when unknown)."""
    if dtype is None:
      return '?'
    return str(dtype)

  def format_shape(shape):
    """Renders a layer input/output shape for a node label."""
    return str(shape).replace(str(None), 'None')

  if subgraph:
    # Nested models are drawn as a dashed, left-labelled cluster.
    dot = pydot.Cluster(style='dashed', graph_name=model.name)
    dot.set('label', model.name)
    dot.set('labeljust', 'l')
  else:
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set('dpi', dpi)
    dot.set_node_defaults(shape='record')

  # First/last pydot node of every expanded sub-model, used below to attach
  # edges that cross a cluster boundary.
  #   sub_n_* : nested models that are direct layers, keyed by `layer.name`.
  #   sub_w_* : nested models held by a Wrapper, keyed by `layer.layer.name`.
  sub_n_first_node = {}
  sub_n_last_node = {}
  sub_w_first_node = {}
  sub_w_last_node = {}

  layers = model.layers
  if not model._is_graph_network:
    # No layer graph to walk: represent the whole model as a single node.
    node = pydot.Node(str(id(model)), label=model.name)
    dot.add_node(node)
    return dot
  elif isinstance(model, sequential.Sequential):
    if not model.built:
      model.build()
    # Deliberately bypass Sequential's own `layers` property and use the
    # parent class' one instead.
    layers = super(sequential.Sequential, model).layers

  # Create graph nodes.
  for layer in layers:
    layer_id = str(id(layer))

    # Append a wrapped layer's label to node's label, if it exists.
    layer_name = layer.name
    class_name = layer.__class__.__name__

    if isinstance(layer, wrappers.Wrapper):
      if expand_nested and isinstance(layer.layer,
                                      functional.Functional):
        submodel_wrapper = model_to_dot(
            layer.layer,
            show_shapes,
            show_dtype,
            show_layer_names,
            rankdir,
            expand_nested,
            subgraph=True)
        # sub_w : submodel_wrapper
        sub_w_nodes = submodel_wrapper.get_nodes()
        sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
        sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
        dot.add_subgraph(submodel_wrapper)
      else:
        layer_name = '{}({})'.format(layer_name, layer.layer.name)
        child_class_name = layer.layer.__class__.__name__
        class_name = '{}({})'.format(class_name, child_class_name)

    if expand_nested and isinstance(layer, functional.Functional):
      submodel_not_wrapper = model_to_dot(
          layer,
          show_shapes,
          show_dtype,
          show_layer_names,
          rankdir,
          expand_nested,
          subgraph=True)
      # sub_n : submodel_not_wrapper
      sub_n_nodes = submodel_not_wrapper.get_nodes()
      sub_n_first_node[layer.name] = sub_n_nodes[0]
      sub_n_last_node[layer.name] = sub_n_nodes[-1]
      dot.add_subgraph(submodel_not_wrapper)

    # Create node's label.
    if show_layer_names:
      label = '{}: {}'.format(layer_name, class_name)
    else:
      label = class_name

    # Rebuild the label as a table including the layer's dtype.
    if show_dtype:
      label = '%s|%s' % (label, format_dtype(layer.dtype))

    # Rebuild the label as a table including input/output shapes.
    if show_shapes:
      try:
        outputlabels = format_shape(layer.output_shape)
      except AttributeError:
        outputlabels = '?'
      if hasattr(layer, 'input_shape'):
        inputlabels = format_shape(layer.input_shape)
      elif hasattr(layer, 'input_shapes'):
        inputlabels = ', '.join(
            [format_shape(ishape) for ishape in layer.input_shapes])
      else:
        inputlabels = '?'
      label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                     inputlabels,
                                                     outputlabels)

    # An expanded nested model is drawn as a cluster (added above), so it
    # gets no node of its own.
    if not expand_nested or not isinstance(
        layer, functional.Functional):
      node = pydot.Node(layer_id, label=label)
      dot.add_node(node)

  # Connect nodes with edges.
  for layer in layers:
    layer_id = str(id(layer))
    for i, inbound_node in enumerate(layer._inbound_nodes):
      node_key = layer.name + '_ib-' + str(i)
      # Only inbound nodes registered in this model's `_network_nodes`
      # produce edges.
      if node_key in model._network_nodes:
        for inbound_layer in nest.flatten(inbound_node.inbound_layers):
          inbound_layer_id = str(id(inbound_layer))
          if not expand_nested:
            assert dot.get_node(inbound_layer_id)
            assert dot.get_node(layer_id)
            add_edge(dot, inbound_layer_id, layer_id)
          else:
            # if inbound_layer is not Model or wrapped Model
            if (not isinstance(inbound_layer,
                               functional.Functional) and
                not is_wrapped_model(inbound_layer)):
              # if current layer is not Model or wrapped Model
              if (not isinstance(layer, functional.Functional) and
                  not is_wrapped_model(layer)):
                assert dot.get_node(inbound_layer_id)
                assert dot.get_node(layer_id)
                add_edge(dot, inbound_layer_id, layer_id)
              # if current layer is Model
              elif isinstance(layer, functional.Functional):
                add_edge(dot, inbound_layer_id,
                         sub_n_first_node[layer.name].get_name())
              # if current layer is wrapped Model
              elif is_wrapped_model(layer):
                # Link into the wrapper's own node, then from the wrapper
                # into the first node of the model it wraps.
                add_edge(dot, inbound_layer_id, layer_id)
                name = sub_w_first_node[layer.layer.name].get_name()
                add_edge(dot, layer_id, name)
            # if inbound_layer is Model
            elif isinstance(inbound_layer, functional.Functional):
              name = sub_n_last_node[inbound_layer.name].get_name()
              if isinstance(layer, functional.Functional):
                output_name = sub_n_first_node[layer.name].get_name()
                add_edge(dot, name, output_name)
              else:
                add_edge(dot, name, layer_id)
            # if inbound_layer is wrapped Model
            elif is_wrapped_model(inbound_layer):
              inbound_layer_name = inbound_layer.layer.name
              add_edge(dot,
                       sub_w_last_node[inbound_layer_name].get_name(),
                       layer_id)
  return dot


@keras_export('keras.utils.plot_model')
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_dtype=False,
               show_layer_names=True,
               rankdir='TB',
               expand_nested=False,
               dpi=96):
  """Renders a Keras model as a graph and writes the image to disk.

  Example:

  ```python
  input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
  x = tf.keras.layers.Embedding(
      output_dim=512, input_dim=10000, input_length=100)(input)
  x = tf.keras.layers.LSTM(32)(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
  model = tf.keras.Model(inputs=[input], outputs=[output])
  dot_img_file = '/tmp/model_1.png'
  tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
  ```

  Args:
    model: A Keras model instance.
    to_file: File name of the plot image. The output format is taken from
        its extension; 'png' is used when there is no extension.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: Whether to expand nested models into clusters.
    dpi: Dots per inch.

  Returns:
    An `IPython.display.Image` for the written file when IPython can be
    imported and the output format is not 'pdf'; otherwise `None`. `None`
    is also returned, without writing anything, when `model_to_dot` yields
    no graph.
  """
  graph = model_to_dot(
      model,
      show_shapes=show_shapes,
      show_dtype=show_dtype,
      show_layer_names=show_layer_names,
      rankdir=rankdir,
      expand_nested=expand_nested,
      dpi=dpi)
  to_file = path_to_string(to_file)
  if graph is None:
    return None

  # Derive the image format from the file suffix (leading dot stripped).
  suffix = os.path.splitext(to_file)[1]
  image_format = suffix[1:] if suffix else 'png'

  # Save image to disk.
  graph.write(to_file, format=image_format)

  if image_format == 'pdf':
    return None

  # We cannot easily tell whether we are running inside a notebook, so the
  # Image is returned whenever IPython is importable.
  try:
    from IPython import display
    return display.Image(filename=to_file)
  except ImportError:
    return None
