# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts function definitions and lambdas by adding necessary boilerplate."""

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos


class _Function(object):

  def __init__(self):
    self.context_name = None


class FunctionTransformer(converter.Base):
  """Wraps function bodies around autograph-specific boilerplate."""

  def _function_scope_options(self, fn_scope):
    """Returns the options with which to create function scopes."""
    # Top-level function receive the options that were directly requested.
    # All others receive the options corresponding to a recursive conversion.
    # Note: this mainly controls the user_requested flag, which is important
    # primarily because the FunctionScope context also creates a
    # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
    # function_wrappers.py.
    if fn_scope.level == 2:
      return self.ctx.user.options
    return self.ctx.user.options.call_options()

  def visit_Lambda(self, node):
    with self.state[_Function] as fn_scope:
      node = self.generic_visit(node)

      # TODO(mdan): Fix the tests so that we can always add this decorator.
      if fn_scope.level > 2:
        return templates.replace_as_expression(
            'ag__.autograph_artifact(l)', l=node)

      scope = anno.getanno(node, anno.Static.SCOPE)
      function_context_name = self.ctx.namer.new_symbol('lscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      node.body = templates.replace_as_expression(
          template,
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          function_context_name=gast.Constant(function_context_name, kind=None),
          body=node.body)

      return node

  def visit_FunctionDef(self, node):
    with self.state[_Function] as fn_scope:
      scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

      function_context_name = self.ctx.namer.new_symbol('fscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      node = self.generic_visit(node)

      if fn_scope.level <= 2:
        # Top-level functions lose their decorator because the conversion is
        # always just-in-time and by the time it happens the decorators are
        # already set to be applied.
        node.decorator_list = []
      else:
        # TODO(mdan): Fix the tests so that we can always add this decorator.
        # Inner functions are converted already, so we insert a decorator to
        # prevent double conversion. Double conversion would work too, but this
        # saves the overhead.
        node.decorator_list.append(
            parser.parse_expression('ag__.autograph_artifact'))

      docstring_node = None
      if node.body:
        first_statement = node.body[0]
        if (isinstance(first_statement, gast.Expr) and
            isinstance(first_statement.value, gast.Constant)):
          docstring_node = first_statement
          node.body = node.body[1:]

      template = """
        with ag__.FunctionScope(
            function_name, context_name, options) as function_context:
          body
      """
      wrapped_body = templates.replace(
          template,
          function_name=gast.Constant(node.name, kind=None),
          context_name=gast.Constant(function_context_name, kind=None),
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          body=node.body)

      if docstring_node is not None:
        wrapped_body = [docstring_node] + wrapped_body

      node.body = wrapped_body

      return node


def transform(node, ctx):
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)

  return FunctionTransformer(ctx).visit(node)
