# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Overloads all variable read operations."""

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import templates


class VariableAccessTransformer(converter.Base):
  """Rewrites basic symbol reads.

  This transformer rewrites variable reads with a "read" operator which allows
  tracking activity.

  Example:

  For a basic statement:

      a = b + c

  This is translated to:

      a = ld(b) + ld(c)

  Augmented assignment operations also introduce a `ld` operator:

      a += b

  The assignment target also receives an operator to properly represent the
  read:

      a = ld(a)
      a += ld(b)

  Deleting a plain variable is rewritten into an assignment of an
  `ag__.Undefined` marker:

      del a

  becomes:

      a = ag__.Undefined('a')

  Composite targets such as `del a[0]` stay in a `del` statement; reads nested
  inside them are still wrapped like any other read.
  """

  def visit_Name(self, node):
    """Wraps an original variable read in `ag__.ld`.

    Args:
      node: gast.Name

    Returns:
      An `ag__.ld(node)` expression when `node` is a load annotated with
      `anno.Static.ORIG_DEFINITIONS`; otherwise `node` unchanged.
    """
    # Only the loads which existed in the original code are overloaded.
    if not anno.hasanno(node, anno.Static.ORIG_DEFINITIONS):
      return node
    # Stores and deletes are not reads, so they are left alone.
    if isinstance(node.ctx, gast.Load):
      node = templates.replace_as_expression('ag__.ld(var_)', var_=node)
    return node

  def visit_Delete(self, node):
    """Rewrites `del name` into an assignment of an `ag__.Undefined` marker.

    Args:
      node: gast.Delete

    Returns:
      `node` itself when none of its targets is a plain name. Otherwise a list
      of statements: one `name = ag__.Undefined('name')` assignment per
      plain-name target, followed by a `del` of the remaining composite
      targets, if there are any.
    """
    node = self.generic_visit(node)

    # Don't rewrite composites like `del a[0]`.
    name_targets = [t for t in node.targets if isinstance(t, gast.Name)]
    if not name_targets:
      return node

    template = """
      var_ = ag__.Undefined(var_name)
    """
    results = []
    for tgt in name_targets:
      results.extend(templates.replace(
          template, var_=tgt, var_name=gast.Constant(tgt.id, kind=None)))

    composite_targets = [
        t for t in node.targets if not isinstance(t, gast.Name)]
    if composite_targets:
      results.append(gast.Delete(targets=composite_targets))

    return results

  def visit_AugAssign(self, node):
    """Makes the implicit read of an augmented assignment's target explicit.

    Args:
      node: gast.AugAssign

    Returns:
      For a plain-name target, a list of two statements:
      `var = ag__.ld(var)` followed by the visited augmented assignment.
      For any other target, the visited `node`.
    """
    # Visit children first in both branches. The statements returned by the
    # template are spliced into the parent without being visited again, so
    # without this the right-hand side (e.g. `b` in `a += b`) would never be
    # wrapped in `ag__.ld`. The target itself is a store, so `visit_Name`
    # leaves a plain-name target untouched.
    node = self.generic_visit(node)
    if isinstance(node.target, gast.Name):
      template = """
        var_ = ag__.ld(var_)
        original
      """
      node = templates.replace(template, var_=node.target, original=node)
    return node


def transform(node, ctx):
  """Applies `VariableAccessTransformer` to an AST.

  Args:
    node: the AST node to rewrite.
    ctx: the converter context used to construct the transformer.

  Returns:
    Whatever the transformer's `visit` returns for `node`.
  """
  access_transformer = VariableAccessTransformer(ctx)
  return access_transformer.visit(node)
