# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting code to AST.

Adapted from Tangent.
"""

import ast
import inspect
import linecache
import re
import sys
import textwrap
import tokenize

import astunparse
import gast
import six

from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.util import tf_inspect


PY2_PREAMBLE = textwrap.dedent("""
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0

if sys.version_info >= (3, 9):
  astunparse = ast

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


_LEADING_WHITESPACE = re.compile(r'\s*')


def _unfold_continuations(code_string):
  """Removes any backslash line continuations from the code."""
  return code_string.replace('\\\n', '')


def dedent_block(code_string):
  """Dedents a code so that its first line starts at row zero."""

  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(six.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      block_level = len(block_indentation)
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code


def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except OSError as e:
    raise errors.InaccessibleSourceCodeError(
        f'Unable to locate the source code of {entity}. Note that functions'
        ' defined in certain environments, like the interactive Python shell,'
        ' do not expose their source code. If that is the case, you should'
        ' define them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        f' @tf.autograph.experimental.do_not_convert. Original error: {e}')

  source = dedent_block(original_source)

  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  return parse(source, preamble_len=len(future_features)), source


def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  for n in gast.walk(node):
    lineno = getattr(n, 'lineno', None)
    if lineno is not None:
      n.lineno = lineno - minl
    end_lineno = getattr(n, 'end_lineno', None)
    if end_lineno is not None:
      n.end_lineno = end_lineno - minl

  code_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col_offset = getattr(node, 'end_col_offset', None)
  if end_col_offset is not None:
    # This is only available in 3.8.
    code_lines[-1] = code_lines[-1][:end_col_offset]

  col_offset = getattr(node, 'col_offset', None)
  if col_offset is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
    if match is not None:
      col_offset = match.start(0)

  if col_offset is not None:
    code_lines[0] = code_lines[0][col_offset:]

  code_block = '\n'.join([c.rstrip() for c in code_lines])

  return node, code_block


def _arg_name(node):
  if node is None:
    return None
  if isinstance(node, gast.Name):
    return node.id
  assert isinstance(node, str)
  return node


def _node_matches_argspec(node, func):
  """Returns True is node fits the argspec of func."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  arg_spec = tf_inspect.getfullargspec(func)

  node_args = tuple(_arg_name(arg) for arg in node.args.args)
  if node_args != tuple(arg_spec.args):
    return False

  if arg_spec.varargs != _arg_name(node.args.vararg):
    return False

  if arg_spec.varkw != _arg_name(node.args.kwarg):
    return False

  node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs)
  if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
    return False

  return True


def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - an surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust that just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  candidates = []
  for ln in lambda_nodes:
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    lambda_codes = '\n'.join([unparse(l) for l in lambda_nodes])
    raise errors.UnsupportedLanguageElementError(
        f'could not parse the source code of {lam}:'
        f' no matching AST found among candidates:\n{lambda_codes}')

  # Attempt to narrow down selection by signature is multiple nodes are found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      f'could not parse the source code of {lam}: found multiple definitions'
      ' with identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      f' unique argument names. The matching definitions were:\n{matches}')


# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text
    preamble_len: Int, indicates leading nodes in the parsed AST which should be
      dropped.
    single_node: Bool, whether `src` is assumed to be represented by exactly one
      AST node.

  Returns:
    ast.AST
  """
  module_node = gast.parse(src)
  nodes = module_node.body
  if preamble_len:
    nodes = nodes[preamble_len:]
  if single_node:
    if len(nodes) != 1:
      raise ValueError('expected exactly one node, got {}'.format(nodes))
    return nodes[0]
  return nodes


def parse_expression(src):
  """Returns the AST of given identifier.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  src = STANDARD_PREAMBLE + src.strip()
  node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  if __debug__:
    if not isinstance(node, gast.Expr):
      raise ValueError(
          'expected exactly one node of type Expr, got {}'.format(node))
  return node.value


def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object.
    indentation: Unused, deprecated. The returning code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    code: The source code generated from the AST object
    source_mapping: A mapping between the user and AutoGraph generated code.
  """
  del indentation  # astunparse doesn't allow configuring it.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  codes = []
  if include_encoding_marker:
    codes.append('# coding=utf-8')
  for n in node:
    if isinstance(n, gast.AST):
      ast_n = gast.gast_to_ast(n)
    else:
      ast_n = n

    if astunparse is ast:
      ast.fix_missing_locations(ast_n)  # Only ast needs to call this.
    codes.append(astunparse.unparse(ast_n).strip())

  return '\n'.join(codes)
