Use a slower but more robust method to parse the source of lambda functions. Fixes #39832.
This method is 10-30x slower in benchmarks and scales poorly with file size. However, it can be sped up by moving processing to C++, or by detecting safe cases from the surrounding code.
Unlike the regular approach in inspect, which relies on a coarse mechanism based on line numbers alone and on brittle regexp/lexer searches, this approach parses the entire source file and then searches for matching ast.Lambda nodes based on line number information and argument signature. The mechanism errors out when this resolution cannot be done precisely (e.g. `lambda x: lambda x: x`). This in turn guarantees that the parsed AST matches exactly what the interpreter sees.
A more permanent fix would address the issue in `inspect`, which makes an incorrect assumption here: 53d2b715d1/Lib/inspect.py (L924)
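The strategy can be sketched as follows (a minimal sketch with a hypothetical helper name, simplified to an exact line match; the actual implementation also accepts nodes whose multi-line span contains the definition line):

```
import ast
import inspect

def resolve_lambda_node(fn):
  """Finds the ast.Lambda node for fn by line number and signature."""
  source = inspect.getsource(inspect.getmodule(fn))
  def_line = fn.__code__.co_firstlineno
  arg_names = inspect.getfullargspec(fn).args
  candidates = [
      node for node in ast.walk(ast.parse(source))
      if isinstance(node, ast.Lambda)
      and node.lineno == def_line
      and [a.arg for a in node.args.args] == arg_names
  ]
  if len(candidates) != 1:
    raise ValueError('could not resolve the lambda unambiguously')
  return candidates[0]
```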
PiperOrigin-RevId: 319972026
Change-Id: Ie3f522486860f512add9409c246a13b87462f4aa
commit f621c65020 (parent: d5c4d24f72)
@@ -645,6 +645,39 @@ to quickly diagnose whether the source code is available for a function.

#### Source code of lambda functions

##### Changes in TF 2.4

Key Point: When nesting lambda functions, use distinguishing argument names
to avoid parse errors.

The Python runtime exposes the source code of lambda functions; however, it
may omit parts of the actual body or include surrounding code. This can make
it impossible to parse the exact source code of the lambda function (see
https://github.com/tensorflow/tensorflow/issues/39832).

AutoGraph uses alternative methods to parse the source code more robustly, but
in rare cases it may be unable to distinguish between nested lambda functions
with identical signatures.

Example:

```
l = lambda x: lambda x: x + 1
```

AutoGraph raises an error for the code above because the parser cannot
distinguish between the two function signatures. To work around this
limitation, use distinct argument names:

```
l = lambda outer_x: lambda inner_x: inner_x + 1
```

##### TF 2.3 and older

In older versions of TensorFlow, the loading code for lambda functions is not
robust. Follow the guidance below to avoid errors.

Important: Declare lambda functions on single lines to make sure their source
code loads correctly.
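For background, a standalone illustration of the runtime behavior described above (representative output; details vary across Python versions):

```
import inspect

l = (lambda x: x,)[0]

# getsource returns whole source lines, not the lambda expression itself,
# so the surrounding tuple syntax is included in the result.
print(inspect.getsource(l))  # -> "l = (lambda x: x,)[0]\n"
```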
@@ -25,7 +25,6 @@ import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.util import tf_inspect


class CleanCopier(object):
@@ -347,54 +346,3 @@ def parallel_walk(node, other):
      raise ValueError(
          'inconsistent values for field {}: {} and {}'.format(
              f, n_child, o_child))


class LambdaDefinitionMatcher(gast.NodeVisitor):
  """Finds lambda nodes that match a given lambda's signature."""

  def __init__(self, fn):
    self.fn = fn
    self.matching_nodes = []

  def _arg_name(self, node):
    if node is None:
      return None
    if isinstance(node, gast.Name):
      return node.id
    assert isinstance(node, str)
    return node

  def _argspec_matches(self, node):
    arg_spec = tf_inspect.getfullargspec(self.fn)

    node_args = tuple(self._arg_name(arg) for arg in node.args.args)
    if node_args != tuple(arg_spec.args):
      return False

    if arg_spec.varargs != self._arg_name(node.args.vararg):
      return False

    if arg_spec.varkw != self._arg_name(node.args.kwarg):
      return False

    node_kwonlyargs = tuple(self._arg_name(arg) for arg in node.args.kwonlyargs)
    if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
      return False

    return True

  def visit_Lambda(self, node):
    self.generic_visit(node)

    if self.fn.__name__ != '<lambda>':
      return
    if not self._argspec_matches(node):
      return

    self.matching_nodes.append(node)


def find_matching_definitions(node, f):
  matcher = LambdaDefinitionMatcher(f)
  matcher.visit(node)
  return tuple(matcher.matching_nodes)
@@ -235,37 +235,6 @@ class AstUtilTest(test.TestCase):
        parser.unparse(node.body, include_encoding_marker=False).strip(),
        expected_bodies)

  def test_find_matching_definitions_lambda(self):
    node = parser.parse(
        textwrap.dedent("""
        f = lambda x: 1
        """))
    f = lambda x: x
    nodes = ast_util.find_matching_definitions(node, f)
    self.assertLambdaNodes(nodes, ('1',))

  def test_find_matching_definitions_lambda_multiple_matches(self):
    node = parser.parse(
        textwrap.dedent("""
        f = lambda x: 1, lambda x: 2
        """))
    f = lambda x: x
    nodes = ast_util.find_matching_definitions(node, f)
    self.assertLambdaNodes(nodes, ('1', '2'))

  def test_find_matching_definitions_lambda_uses_arg_names(self):
    node = parser.parse(
        textwrap.dedent("""
        f = lambda x: 1, lambda y: 2
        """))
    f = lambda x: x
    nodes = ast_util.find_matching_definitions(node, f)
    self.assertLambdaNodes(nodes, ('1',))

    f = lambda y: y
    nodes = ast_util.find_matching_definitions(node, f)
    self.assertLambdaNodes(nodes, ('2',))


if __name__ == '__main__':
  test.main()
@@ -247,11 +247,19 @@ def resolve(node, source, context_filepath, context_lineno, context_col_offset):
  # TODO(mdan): Pull this to a separate utility.
  code_reader = six.StringIO(source)
  comments_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comments_map[srow] = tok_string.strip()[1:].strip()
  try:
    for token in tokenize.generate_tokens(code_reader.readline):
      tok_type, tok_string, loc, _, _ = token
      srow, _ = loc
      if tok_type == tokenize.COMMENT:
        comments_map[srow] = tok_string.strip()[1:].strip()
  except tokenize.TokenError:
    if isinstance(node, gast.Lambda):
      # Source code resolution in older Python versions is brittle for
      # lambda functions, and may contain garbage.
      pass
    else:
      raise

  source_lines = source.split('\n')
  visitor = OriginResolver(node, source_lines, comments_map,
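The `TokenError` guard in this hunk exists because tokenizing a truncated lambda snippet with trailing garbage fails before reaching the end of input. A standalone illustration, using the same string as the test below (behavior as on the Python versions this commit targets; newer releases may differ):

```
import io
import tokenize

# Trailing garbage from the enclosing expression leaves unbalanced
# parentheses, so the tokenizer cannot terminate the statement cleanly.
source = ' lambda: foo([], bar=1)), baz=2)()'
try:
  for _ in tokenize.generate_tokens(io.StringIO(source).readline):
    pass
except tokenize.TokenError as e:
  print('tokenize failed:', e)
```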
@@ -146,6 +146,19 @@ class OriginInfoTest(test.TestCase):
    self.assertEqual(ret_origin.source_code_line, ' return x # comment')
    self.assertEqual(ret_origin.comment, 'comment')

  def test_resolve_with_trailing_garbage(self):
    # This comment will be missed because the tokenizer fails to reach it.
    source = ' lambda: foo([], bar=1)), baz=2)()'
    clean_source = 'lambda: foo([], bar=1)'
    node = parser.parse(clean_source).value
    origin_info.resolve(node, source, 'test_file', 10, 10)

    def_origin = anno.getanno(node, anno.Basic.ORIGIN)
    self.assertEqual(def_origin.loc.lineno, 10)
    self.assertEqual(def_origin.loc.col_offset, 10)
    self.assertEqual(def_origin.source_code_line, source)
    self.assertIsNone(def_origin.comment)

  def test_resolve_entity(self):
    test_fn = basic_definitions.simple_function
    node, source = parser.parse_entity(
@@ -21,6 +21,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import re
import sys
import textwrap
@@ -32,6 +33,7 @@ import six

from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.util import tf_inspect


PY2_PREAMBLE = textwrap.dedent("""
@@ -39,11 +41,14 @@ from __future__ import division
from __future__ import print_function
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')
@@ -124,69 +129,6 @@ def dedent_block(code_string):
  return new_code


def _attempt_to_parse_normal_source(source, future_features):
  return parse(source, preamble_len=len(future_features)), source


def _attempt_to_parse_lambda_source(source, original_source,
                                    future_features, try_fallback=True):
  """Parsing function specialized on dealing with lambdas.

  Lambda functions only hold the raw code lines which defined
  them, which may include surrounding tokens and may be syntactically
  invalid out of context. For example:

      l = (
          lambda x: x,)[0]

  will have the dedented source "lambda x: x,)[0]".
  This function makes an attempt to strip away the garbage by looking at the
  information in the syntax error.

  Args:
    source: the processed source code of `entity`.
    original_source: the source code of `entity`, as it was reported
      by `inspect.getsource`.
    future_features: see `parse`.
    try_fallback: whether to attempt to remove extra code from `source` before
      one more attempt to parse it.

  Returns:
    Same as `parse`.
  """

  try:
    return parse(source, preamble_len=len(future_features)), source

  # Note: the ValueError may be raised by parse.
  except (SyntaxError, ValueError) as e:
    def fail():
      raise errors.UnsupportedLanguageElementError(
          'could not parse the source code:'
          '\n\n{}\n'
          'This error may be avoided by creating the lambda in a standalone'
          ' statement.\n'.format(original_source))

    if not try_fallback:
      fail()

    lines = source.split('\n')
    lineno, offset = e.lineno, e.offset  # 1-based

    # Give up if there's nothing we can chip away.
    if len(lines) == lineno and len(lines[-1]) == offset:
      fail()

    # Drop all lines following the error location
    # TODO(mdan): What's with the pylint errors?
    lines = lines[:lineno]  # pylint:disable=invalid-slice-index
    # Drop all characters following the error location
    lines[-1] = lines[-1][:offset - 1]  # pylint:disable=invalid-slice-index
    source = '\n'.join(lines)

    return _attempt_to_parse_lambda_source(
        source, original_source, future_features, try_fallback=False)


def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.
@@ -200,6 +142,9 @@ def parse_entity(entity, future_features):
    gast.AST, Text: the parsed AST node; the source code that was parsed to
      generate the AST (including any prefixes that this function may have
      added).
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except (IOError, OSError) as e:
@@ -217,11 +162,155 @@ def parse_entity(entity, future_features):
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  if inspect_utils.islambda(entity):
    return _attempt_to_parse_lambda_source(
        source, original_source, future_features)
  else:
    return _attempt_to_parse_normal_source(source, future_features)
  return parse(source, preamble_len=len(future_features)), source


def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  for n in gast.walk(node):
    lineno = getattr(n, 'lineno', None)
    if lineno is not None:
      n.lineno = lineno - minl
    end_lineno = getattr(n, 'end_lineno', None)
    if end_lineno is not None:
      n.end_lineno = end_lineno - minl

  code_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col_offset = getattr(node, 'end_col_offset', None)
  if end_col_offset is not None:
    # This is only available in 3.8.
    code_lines[-1] = code_lines[-1][:end_col_offset]
  code_block = '\n'.join(lines[minl - 1:maxl])

  col_offset = getattr(node, 'col_offset', None)
  if col_offset is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
    if match is not None:
      col_offset = match.start(0)

  if col_offset is not None:
    code_lines[0] = code_lines[0][col_offset:]

  code_block = '\n'.join(code_lines)

  return node, code_block
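A standalone look at the position attributes the cleanup above relies on (plain `ast` here, while the function itself receives `gast` nodes): Python 3.8+ records end positions, which allow precise trimming of trailing context, whereas older versions expose only the start position, hence the regex fallback.

```
import ast

tree = ast.parse('l = (lambda x: x,)[0]')
lam = next(n for n in ast.walk(tree) if isinstance(n, ast.Lambda))

# On Python 3.8+, end_lineno/end_col_offset are populated; on older
# versions, getattr returns None and only lineno/col_offset are usable.
print(lam.lineno, lam.col_offset)
print(getattr(lam, 'end_lineno', None), getattr(lam, 'end_col_offset', None))
```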
def _arg_name(node):
  if node is None:
    return None
  if isinstance(node, gast.Name):
    return node.id
  assert isinstance(node, str)
  return node


def _node_matches_argspec(node, func):
  """Returns True if node fits the argspec of func."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  arg_spec = tf_inspect.getfullargspec(func)

  node_args = tuple(_arg_name(arg) for arg in node.args.args)
  if node_args != tuple(arg_spec.args):
    return False

  if arg_spec.varargs != _arg_name(node.args.vararg):
    return False

  if arg_spec.varkw != _arg_name(node.args.kwarg):
    return False

  node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs)
  if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
    return False

  return True
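For illustration, the same signature checks expressed with plain `inspect` and `ast` (the real code goes through `tf_inspect` and `gast` for Python 2 compatibility; `g` and the argument names here are arbitrary):

```
import ast
import inspect

fn = lambda x, *rest, **kw: x
spec = inspect.getfullargspec(fn)
node = ast.parse('g = lambda x, *rest, **kw: x').body[0].value

# The three checks performed above, in plain ast/inspect terms.
assert tuple(a.arg for a in node.args.args) == tuple(spec.args)  # ('x',)
assert node.args.vararg.arg == spec.varargs                      # 'rest'
assert node.args.kwarg.arg == spec.varkw                         # 'kw'
```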
def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
      generate the AST (including any prefixes that this function may have
      added).
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  def_line = lam.__code__.co_firstlineno
  source = inspect.getsource(mod)
  lines = source.split('\n')

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  candidates = []
  for ln in lambda_nodes:
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    raise errors.UnsupportedLanguageElementError(
        'could not parse the source code of {}:'
        ' no matching AST found'.format(lam))

  # Attempt to narrow down the selection by signature if multiple nodes are
  # found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if we could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      'could not parse the source code of {}: found multiple definitions with'
      ' identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      ' unique argument names.\n{}'.format(lam, matches))


# TODO(mdan): This should take futures as input instead.
@@ -18,14 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import textwrap

import gast

from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test


# Version notice: these tests pass in Python 3.7. They will fail in 3.8, where
# the parser is able to clean up the trailing garbage.
# TODO(mdan): Update the tests to work in 3.8 as well.


class ParserTest(test.TestCase):

  def test_parse_entity(self):
@@ -36,6 +43,171 @@ class ParserTest(test.TestCase):
    node, _ = parser.parse_entity(f, future_features=())
    self.assertEqual('f', node.name)

  def test_parse_lambda(self):

    l = lambda x: x + 1

    node, source = parser.parse_entity(l, future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda x: (x + 1))')
    self.assertEqual(source, 'lambda x: x + 1')

  def test_parse_lambda_prefix_cleanup(self):

    lambda_lam = lambda x: x + 1

    node, source = parser.parse_entity(lambda_lam, future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda x: (x + 1))')
    self.assertEqual(source, 'lambda x: x + 1')

  def test_parse_lambda_resolution_by_location(self):

    _ = lambda x: x + 1
    l = lambda x: x + 1
    _ = lambda x: x + 1

    node, source = parser.parse_entity(l, future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda x: (x + 1))')
    self.assertEqual(source, 'lambda x: x + 1')

  def test_parse_lambda_resolution_by_signature(self):

    l = lambda x: lambda x, y: x + y

    node, source = parser.parse_entity(l, future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda x: (lambda x, y: (x + y)))')
    self.assertEqual(source, 'lambda x: lambda x, y: x + y')

    node, source = parser.parse_entity(l(0), future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda x, y: (x + y))')
    self.assertEqual(source, 'lambda x, y: x + y')

  def test_parse_lambda_resolution_ambiguous(self):

    l = lambda x: lambda x: 2 * x

    expected_exception_text = re.compile(r'found multiple definitions'
                                         r'.+'
                                         r'\(lambda x: \(lambda x'
                                         r'.+'
                                         r'\(lambda x: \(2', re.DOTALL)

    with self.assertRaisesRegex(
        errors.UnsupportedLanguageElementError,
        expected_exception_text):
      parser.parse_entity(l, future_features=())

    with self.assertRaisesRegex(
        errors.UnsupportedLanguageElementError,
        expected_exception_text):
      parser.parse_entity(l(0), future_features=())

  def test_parse_lambda_multiline(self):

    l = (
        lambda x: lambda y: x + y # pylint:disable=g-long-lambda
        - 1)

    node, source = parser.parse_entity(l, future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda x: (lambda y: ((x + y) - 1)))')
    self.assertEqual(
        source,
        'lambda x: lambda y: x + y # pylint:disable=g-long-lambda\n'
        ' - 1)')

    node, source = parser.parse_entity(l(0), future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda y: ((x + y) - 1))')
    self.assertEqual(
        source,
        'lambda y: x + y # pylint:disable=g-long-lambda\n'
        ' - 1)')

  def test_parse_lambda_in_expression(self):

    l = (
        lambda x: lambda y: x + y + 1,
        lambda x: lambda y: x + y + 2,
    )

    node, source = parser.parse_entity(l[0], future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda x: (lambda y: ((x + y) + 1)))')
    self.assertEqual(source, 'lambda x: lambda y: x + y + 1,')

    node, source = parser.parse_entity(l[0](0), future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda y: ((x + y) + 1))')
    self.assertEqual(source, 'lambda y: x + y + 1,')

    node, source = parser.parse_entity(l[1], future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda x: (lambda y: ((x + y) + 2)))')
    self.assertEqual(source, 'lambda x: lambda y: x + y + 2,')

    node, source = parser.parse_entity(l[1](0), future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        '(lambda y: ((x + y) + 2))')
    self.assertEqual(source, 'lambda y: x + y + 2,')

  def test_parse_lambda_complex_body(self):

    l = lambda x: ( # pylint:disable=g-long-lambda
        x.y(
            [],
            x.z,
            (),
            x[0:2],
        ),
        x.u,
        'abc',
        1,
    )

    node, source = parser.parse_entity(l, future_features=())
    self.assertEqual(
        parser.unparse(node, include_encoding_marker=False),
        "(lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1))")
    self.assertEqual(source, ('lambda x: ( # pylint:disable=g-long-lambda\n'
                              ' x.y(\n'
                              ' [],\n'
                              ' x.z,\n'
                              ' (),\n'
                              ' x[0:2],\n'
                              ' ),\n'
                              ' x.u,\n'
                              ' \'abc\',\n'
                              ' 1,'))

  def test_parse_lambda_function_call_definition(self):

    def do_parse_and_test(lam, **unused_kwargs):
      node, source = parser.parse_entity(lam, future_features=())

      self.assertEqual(
          parser.unparse(node, include_encoding_marker=False),
          '(lambda x: x)')
      self.assertEqual(source, 'lambda x: x, named_arg=1)')

    do_parse_and_test( # Intentional line break
        lambda x: x, named_arg=1)

  def test_parse_entity_print_function(self):

    def f(x):
@@ -47,7 +219,7 @@ class ParserTest(test.TestCase):
  def test_parse_comments(self):

    def f():
# unindented comment
# unindented comment
      pass

    node, _ = parser.parse_entity(f, future_features=())
@@ -23,7 +23,6 @@ import types

import gast

from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import loader
@@ -303,25 +302,6 @@ class FunctionTranspiler(object):
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    # In general, the output of inspect.getsource is inexact for lambdas
    # because it uses regex matching to adjust the exact location around
    # the line number that CPython records. Then, the entire containing line
    # is returned, which we may have trouble disambiguating.
    # For example:
    # x, y = lambda: 1, lambda: 2
    is_lambda = fn.__name__ == '<lambda>'
    if is_lambda:
      nodes = ast_util.find_matching_definitions(node, fn)
      if len(nodes) != 1:
        raise ValueError(
            'Unable to identify source code of lambda function {}.'
            ' It was defined in this code:\n'
            '{}\n'
            'This code must contain a single distinguishable lambda.'
            ' To avoid this problem, define each lambda in a separate'
            ' expression.'.format(fn, source))
      node, = nodes

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
@@ -338,7 +318,7 @@ class FunctionTranspiler(object):
    node = self._erase_arg_defaults(node)
    node = self.transform_ast(node, context)

    if is_lambda:
      if isinstance(node, gast.Lambda):
        node = gast.Assign(
            targets=[
                gast.Name(
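The hunk above is truncated mid-statement, but the idea — wrapping a transformed lambda in an assignment to a synthetic name so the generated module can be executed and the function retrieved — can be sketched with plain `ast` on Python 3.8+ (`new_fn` is an illustrative name, not the one the transpiler actually generates):

```
import ast

lam = ast.parse('lambda x: x + 1', mode='eval').body

# Bind the lambda to a synthetic name so the module can be compiled and
# the resulting function looked up by name afterwards.
module = ast.Module(
    body=[ast.Assign(targets=[ast.Name(id='new_fn', ctx=ast.Store())],
                     value=lam)],
    type_ignores=[])
ast.fix_missing_locations(module)

ns = {}
exec(compile(module, '<synthetic>', 'exec'), ns)
print(ns['new_fn'](1))  # 2
```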
@@ -136,27 +136,6 @@ class FunctionTranspilerTest(test.TestCase):

    self.assertEqual(f(1), 1 - 1)

  def test_multiple_lambdas_indistinguishable_definitions(self):
    a, b = 1, 2
    f, _ = (lambda x: a * x, lambda x: b * x)

    tr = TestTranspiler()
    with self.assertRaises(ValueError):
      tr.transform_function(f, object(), None, {})

  def test_lambda_code_with_removable_garbage(self):
    # pylint:disable=g-long-lambda
    f = ( # intentional wrap
        lambda x: (
            x # intentional wrap
            + 1),)[0]
    # pylint:enable=g-long-lambda

    tr = TestTranspiler()
    f, _, _ = tr.transform_function(f, object(), None, {})

    self.assertEqual(f(1), 1 - 1)

  def test_nested_functions(self):
    b = 2