# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AST node annotation support.

Adapted from Tangent.
"""

import enum

# pylint:disable=g-bad-import-order

import gast
# pylint:enable=g-bad-import-order


# TODO(mdan): Shorten the names.
# These names are heavily used, and fully spelled-out references such as
# anno.Static.DEFINITIONS are long and verbose.
# TODO(mdan): Replace the attr-dict mechanism with a more typed solution.


class NoValue(enum.Enum):
  """Base class for different types of AST annotations.

  Members of subclasses act as annotation keys. The helper methods forward to
  the module-level `getanno`, `setanno` and `hasanno` functions, passing the
  member itself as the key. `repr` of a member is just its name.
  """

  def __repr__(self):
    return '{}'.format(self.name)

  def exists(self, node):
    """Returns whether `node` carries an annotation under this key."""
    return hasanno(node, self)

  def add_to(self, node, value):
    """Stores `value` on `node` under this key."""
    setanno(node, self, value)

  def of(self, node, default=None):
    """Returns this key's annotation on `node`; `default` (None) if absent."""
    return getanno(node, self, default=default)


class Basic(NoValue):
  """Container for basic annotation keys.

  Each member is used as an annotation key on AST nodes (the `NoValue` helper
  methods pass the member itself, not its value, as the key). The enum values
  are used strictly for documentation purposes.
  """

  QN = 'Qualified name, as it appeared in the code. See qual_names.py.'
  SKIP_PROCESSING = (
      'This node should be preserved as is and not processed any further.')
  INDENT_BLOCK_REMAINDER = (
      'When a node is annotated with this, the remainder of the block should'
      ' be indented below it. The annotation contains a tuple'
      ' (new_body, name_map), where `new_body` is the new indented block and'
      ' `name_map` allows renaming symbols.')
  ORIGIN = ('Information about the source code that converted code originated'
            ' from. See origin_information.py.')
  DIRECTIVES = ('User directives associated with a statement or a variable.'
                ' Typically, they affect the immediately-enclosing statement.')

  EXTRA_LOOP_TEST = (
      'A special annotation containing additional test code to be executed in'
      ' for loops.')


class Static(NoValue):
  """Container for static analysis annotation keys.

  Each member is used as an annotation key on AST nodes (the `NoValue` helper
  methods pass the member itself, not its value, as the key). The enum values
  are used strictly for documentation purposes.
  """

  # Symbols
  # These flags are boolean.
  IS_PARAM = 'Symbol is a parameter to the function being analyzed.'

  # Scopes
  # Scopes are represented by objects of type activity.Scope.
  SCOPE = 'The scope for the annotated node. See activity.py.'
  # TODO(mdan): Drop these in favor of accessing the child's SCOPE.
  ARGS_SCOPE = 'The scope for the argument list of a function call.'
  COND_SCOPE = 'The scope for the test node of a conditional statement.'
  BODY_SCOPE = (
      'The scope for the main body of a statement (True branch for if '
      'statements, main body for loops).')
  ORELSE_SCOPE = (
      'The scope for the orelse body of a statement (False branch for if '
      'statements, orelse body for loops).')

  # Static analysis annotations.
  DEFINITIONS = (
      'Reaching definition information. See reaching_definitions.py.')
  ORIG_DEFINITIONS = (
      'The value of DEFINITIONS that applied to the original code before any'
      ' conversion.')
  # NOTE(review): the name says "_IN" but the description says "when exiting
  # the node"; confirm against reaching_fndefs.py which one is intended.
  DEFINED_FNS_IN = (
      'Local function definitions that may exist when exiting the node. See'
      ' reaching_fndefs.py')
  DEFINED_VARS_IN = (
      'Symbols defined when entering the node. See reaching_definitions.py.')
  LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
  LIVE_VARS_IN = ('Symbols live when entering the node. See liveness.py.')
  TYPES = 'Static type information. See type_inference.py.'
  CLOSURE_TYPES = 'Types of closure symbols at each detected call site.'
  VALUE = 'Static value information. See type_inference.py.'


# Sentinel default for `getanno`. Because it is a unique object, `default is
# FAIL` reliably means "no default was supplied", which is distinct from an
# explicit default of None. In that case a missing annotation raises.
FAIL = object()


def keys(node, field_name='___pyct_anno'):
  """Returns the annotation keys attached to `node`.

  Args:
    node: the object whose annotations are inspected, normally an AST node.
    field_name: name of the attribute that holds the annotation dict.

  Returns:
    A frozenset of the keys stored under `field_name`. If `node` has no such
    attribute, returns an empty frozenset.
  """
  try:
    store = getattr(node, field_name)
  except AttributeError:
    # Nothing was ever annotated on this node.
    return frozenset()
  return frozenset(store.keys())


def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
  """Returns the annotation stored on `node` under `key`.

  Args:
    node: the annotated object, normally an AST node.
    key: the annotation key to look up.
    default: value returned when the annotation is absent. If left as the
      `FAIL` sentinel, a missing annotation is not masked. The lookup error
      propagates to the caller: AttributeError when `node` has no annotation
      attribute at all, or the store's missing-key error (KeyError for the
      dict created by `setanno`) when only `key` is absent.
    field_name: name of the attribute that holds the annotation dict.

  Returns:
    The annotation value. If a default was supplied and the annotation is
    missing, returns `default` instead.
  """
  if default is not FAIL:
    # A fallback was supplied, so return it rather than failing the lookup.
    if not hasattr(node, field_name) or key not in getattr(node, field_name):
      return default
  return getattr(node, field_name)[key]


def hasanno(node, key, field_name='___pyct_anno'):
  """Tells whether `node` carries an annotation under `key`.

  Args:
    node: the object to inspect, normally an AST node.
    key: the annotation key to test for.
    field_name: name of the attribute that holds the annotation dict.

  Returns:
    True if `node` has the `field_name` attribute and it contains `key`,
    otherwise False.
  """
  if not hasattr(node, field_name):
    return False
  return key in getattr(node, field_name)


def setanno(node, key, value, field_name='___pyct_anno'):
  """Attaches `value` to `node` under annotation `key`.

  The annotation dict is created on first use and stored on `node` as the
  `field_name` attribute. On first use the attribute name is also appended
  to the node's `_fields`, so an existing annotation under `key` is replaced.

  Args:
    node: the object to annotate, normally an AST node (must expose `_fields`).
    key: the annotation key.
    value: the annotation value.
    field_name: name of the attribute that holds the annotation dict.
  """
  try:
    store = getattr(node, field_name)
  except AttributeError:
    store = {}
  setattr(node, field_name, store)
  store[key] = value

  # Registering the attribute as a field is what lets the annotations survive
  # gast_to_ast() and ast_to_gast().
  if field_name in node._fields:
    return
  node._fields += (field_name,)


def delanno(node, key, field_name='___pyct_anno'):
  """Removes the annotation stored on `node` under `key`.

  When the last annotation is removed, the annotation attribute itself is
  deleted and `field_name` is filtered out of the node's `_fields`.

  Args:
    node: the annotated object, normally an AST node.
    key: the annotation key to remove.
    field_name: name of the attribute that holds the annotation dict.

  Raises:
    AttributeError: if `node` has no `field_name` attribute.
    KeyError: if `key` is absent (for a dict annotation store).
  """
  store = getattr(node, field_name)
  del store[key]
  if store:
    # Other annotations remain, so keep the store and its field entry.
    return
  delattr(node, field_name)
  node._fields = tuple(name for name in node._fields if name != field_name)


def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
  """Copies the annotation under `key` from `from_node` to `to_node`.

  Does nothing if `from_node` has no annotation under `key`. The value is
  shared, not duplicated, and any existing value on `to_node` under `key` is
  overwritten.

  Args:
    from_node: the node to read the annotation from.
    to_node: the node to write the annotation to.
    key: the annotation key to copy.
    field_name: name of the attribute that holds the annotation dict.
  """
  if not hasanno(from_node, key, field_name=field_name):
    return
  value = getanno(from_node, key, field_name=field_name)
  setanno(to_node, key, value, field_name=field_name)


def dup(node, copy_map, field_name='___pyct_anno'):
  """Recursively copies annotations in an AST tree.

  For every node yielded by `gast.walk(node)`, each annotation stored under a
  source key of `copy_map` is also stored under the matching destination key.
  The source annotations are left in place, and values are shared rather than
  deep-copied.

  On each node, all source values are read before any destination is
  written. The result therefore does not depend on the iteration order of
  `copy_map`: chained mappings (e.g. {a: b, b: c}) and swaps (e.g.
  {a: b, b: a}) copy the values that existed before `dup` was called.

  Args:
    node: ast.AST
    copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
        key. All annotations with the source key will be copied to identical
        annotations with the destination key.
    field_name: str
  """
  for n in gast.walk(node):
    # Snapshot first: if a destination key is also a source key, writing
    # during the read loop would let one copy feed the next.
    pending = [(dst, getanno(n, src, field_name=field_name))
               for src, dst in copy_map.items()
               if hasanno(n, src, field_name)]
    for dst, value in pending:
      setanno(n, dst, value, field_name)
