Use a slower, but more robust method to parse the source of lambda functions. Fixes #39832.

This method is 10-30x slower in benchmarks, and scales poorly with file size. However, it can be sped up by moving processing to C++ or by detecting safe cases by looking at the surrounding code.

Unlike the regular approach from inspect which relied on a coarse mechanism based on line numbers alone, and which used brittle regexp / lexer searches, this approach parses the entire source file, and then searches for matching ast.Lambda nodes based on line number information and argument signature. The mechanism errors out when this resolution cannot be done precisely (e.g. `lamda x: lambda x: x`). This in turn guarantees that the parsed AST matches exactly what the interpreter sees.

A more permanent fix would address the issue in `inspect` which makes an incorrect assumption here: 53d2b715d1/Lib/inspect.py (L924)

PiperOrigin-RevId: 319972026
Change-Id: Ie3f522486860f512add9409c246a13b87462f4aa
This commit is contained in:
Dan Moldovan 2020-07-07 06:34:49 -07:00 committed by TensorFlower Gardener
parent d5c4d24f72
commit f621c65020
9 changed files with 390 additions and 199 deletions

View File

@ -645,6 +645,39 @@ to quickly diagnose whether the source code is available for a function.
#### Source code of lambda functions
##### Changes in TF 2.4
Key Point: When nesting lambda functions, use distinguishing argument names
to avoid parse errors.
The Python runtime exposes the source code of lambda functions, however it
may omit parts of the actual body, or include surrounding code. This may make it
impossible to parse the exact source code of the lambda function (see
https://github.com/tensorflow/tensorflow/issues/39832).
AutoGraph uses alternate methods to parse the source code more robustly, but
in rare cases it may be unable to distinguish between nested lambda functions
of identical signatures.
Example:
```
l = lambda x: lambda x: x + 1
```
AutoGraph raises an error for the code above because the parser cannot
distinguish between the two function signatures. To work around this limitation,
use distinct argument names:
```
l = lambda outer_x: lambda inner_x: inner_x + 1
```
##### TF 2.3 and older
In older versions of TensorFlow, the loading code for lambda functions is not
robust. Follow the guidance below to avoid errors.
Important: Declare lambda functions on single lines to make sure their source
code loads correctly.

View File

@ -25,7 +25,6 @@ import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.util import tf_inspect
class CleanCopier(object):
@ -347,54 +346,3 @@ def parallel_walk(node, other):
raise ValueError(
'inconsistent values for field {}: {} and {}'.format(
f, n_child, o_child))
class LambdaDefinitionMatcher(gast.NodeVisitor):
"""Finds lambda nodes that match a given lambda's signature."""
def __init__(self, fn):
self.fn = fn
self.matching_nodes = []
def _arg_name(self, node):
if node is None:
return None
if isinstance(node, gast.Name):
return node.id
assert isinstance(node, str)
return node
def _argspec_matches(self, node):
arg_spec = tf_inspect.getfullargspec(self.fn)
node_args = tuple(self._arg_name(arg) for arg in node.args.args)
if node_args != tuple(arg_spec.args):
return False
if arg_spec.varargs != self._arg_name(node.args.vararg):
return False
if arg_spec.varkw != self._arg_name(node.args.kwarg):
return False
node_kwonlyargs = tuple(self._arg_name(arg) for arg in node.args.kwonlyargs)
if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
return False
return True
def visit_Lambda(self, node):
self.generic_visit(node)
if self.fn.__name__ != '<lambda>':
return
if not self._argspec_matches(node):
return
self.matching_nodes.append(node)
def find_matching_definitions(node, f):
matcher = LambdaDefinitionMatcher(f)
matcher.visit(node)
return tuple(matcher.matching_nodes)

View File

@ -235,37 +235,6 @@ class AstUtilTest(test.TestCase):
parser.unparse(node.body, include_encoding_marker=False).strip(),
expected_bodies)
def test_find_matching_definitions_lambda(self):
node = parser.parse(
textwrap.dedent("""
f = lambda x: 1
"""))
f = lambda x: x
nodes = ast_util.find_matching_definitions(node, f)
self.assertLambdaNodes(nodes, ('1',))
def test_find_matching_definitions_lambda_multiple_matches(self):
node = parser.parse(
textwrap.dedent("""
f = lambda x: 1, lambda x: 2
"""))
f = lambda x: x
nodes = ast_util.find_matching_definitions(node, f)
self.assertLambdaNodes(nodes, ('1', '2'))
def test_find_matching_definitions_lambda_uses_arg_names(self):
node = parser.parse(
textwrap.dedent("""
f = lambda x: 1, lambda y: 2
"""))
f = lambda x: x
nodes = ast_util.find_matching_definitions(node, f)
self.assertLambdaNodes(nodes, ('1',))
f = lambda y: y
nodes = ast_util.find_matching_definitions(node, f)
self.assertLambdaNodes(nodes, ('2',))
if __name__ == '__main__':
test.main()

View File

@ -247,11 +247,19 @@ def resolve(node, source, context_filepath, context_lineno, context_col_offset):
# TODO(mdan): Pull this to a separate utility.
code_reader = six.StringIO(source)
comments_map = {}
for token in tokenize.generate_tokens(code_reader.readline):
tok_type, tok_string, loc, _, _ = token
srow, _ = loc
if tok_type == tokenize.COMMENT:
comments_map[srow] = tok_string.strip()[1:].strip()
try:
for token in tokenize.generate_tokens(code_reader.readline):
tok_type, tok_string, loc, _, _ = token
srow, _ = loc
if tok_type == tokenize.COMMENT:
comments_map[srow] = tok_string.strip()[1:].strip()
except tokenize.TokenError:
if isinstance(node, gast.Lambda):
# Source code resolution in older Python versions is brittle for
# lambda functions, and may contain garbage.
pass
else:
raise
source_lines = source.split('\n')
visitor = OriginResolver(node, source_lines, comments_map,

View File

@ -146,6 +146,19 @@ class OriginInfoTest(test.TestCase):
self.assertEqual(ret_origin.source_code_line, ' return x # comment')
self.assertEqual(ret_origin.comment, 'comment')
def test_resolve_with_trailing_garbage(self):
# This comment will be missed because the tokenizer fails to reach it.
source = ' lambda: foo([], bar=1)), baz=2)()'
clean_source = 'lambda: foo([], bar=1)'
node = parser.parse(clean_source).value
origin_info.resolve(node, source, 'test_file', 10, 10)
def_origin = anno.getanno(node, anno.Basic.ORIGIN)
self.assertEqual(def_origin.loc.lineno, 10)
self.assertEqual(def_origin.loc.col_offset, 10)
self.assertEqual(def_origin.source_code_line, source)
self.assertIsNone(def_origin.comment)
def test_resolve_entity(self):
test_fn = basic_definitions.simple_function
node, source = parser.parse_entity(

View File

@ -21,6 +21,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import re
import sys
import textwrap
@ -32,6 +33,7 @@ import six
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.util import tf_inspect
PY2_PREAMBLE = textwrap.dedent("""
@ -39,11 +41,14 @@ from __future__ import division
from __future__ import print_function
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0
if sys.version_info >= (3,):
STANDARD_PREAMBLE = PY3_PREAMBLE
MAX_SIZE = sys.maxsize
else:
STANDARD_PREAMBLE = PY2_PREAMBLE
MAX_SIZE = sys.maxint
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')
@ -124,69 +129,6 @@ def dedent_block(code_string):
return new_code
def _attempt_to_parse_normal_source(source, future_features):
return parse(source, preamble_len=len(future_features)), source
def _attempt_to_parse_lambda_source(source, original_source,
future_features, try_fallback=True):
"""Parsing function specialized on dealing with lambdas.
Lambda functions, only hold the raw code lines which defined
them, which may include surrounding tokens and may be syntactically
invalid out of context. For example:
l = (
lambda x: x,)[0]
will have the dedented source "lambda x: x,)[0]"
This function makes an attempt to stip away the garbage by looking at the
information in the syntax error.
Args:
source: the processed source code of `entity`.
original_source: the source code of `entity`, as it was reported
by `inspect.getsource`.
future_features: see `parse`.
try_fallback: whether to attempt to remove extra code from `source` before
one more attempt to parse it.
Returns:
Same as `parse`.
"""
try:
return parse(source, preamble_len=len(future_features)), source
# Note: the ValueError may be raised by parse.
except (SyntaxError, ValueError) as e:
def fail():
raise errors.UnsupportedLanguageElementError(
'could not parse the source code:'
'\n\n{}\n'
'This error may be avoided by creating the lambda in a standalone'
' statement.\n'.format(original_source))
if not try_fallback:
fail()
lines = source.split('\n')
lineno, offset = e.lineno, e.offset # 1-based
# Give up if there's nothing we can chip away.
if len(lines) == lineno and len(lines[-1]) == offset:
fail()
# Drop all lines following the error location
# TODO(mdan): What's with the pylint errors?
lines = lines[:lineno] # pylint:disable=invalid-slice-index
# Drop all characters following the error location
lines[-1] = lines[-1][:offset - 1] # pylint:disable=invalid-slice-index
source = '\n'.join(lines)
return _attempt_to_parse_lambda_source(
source, original_source, future_features, try_fallback=False)
def parse_entity(entity, future_features):
"""Returns the AST and source code of given entity.
@ -200,6 +142,9 @@ def parse_entity(entity, future_features):
gast.AST, Text: the parsed AST node; the source code that was parsed to
generate the AST (including any prefixes that this function may have added).
"""
if inspect_utils.islambda(entity):
return _parse_lambda(entity)
try:
original_source = inspect_utils.getimmediatesource(entity)
except (IOError, OSError) as e:
@ -217,11 +162,155 @@ def parse_entity(entity, future_features):
'from __future__ import {}'.format(name) for name in future_features)
source = '\n'.join(future_statements + (source,))
if inspect_utils.islambda(entity):
return _attempt_to_parse_lambda_source(
source, original_source, future_features)
else:
return _attempt_to_parse_normal_source(source, future_features)
return parse(source, preamble_len=len(future_features)), source
def _without_context(node, lines, minl, maxl):
"""Returns a clean node and source code without indenting and context."""
for n in gast.walk(node):
lineno = getattr(n, 'lineno', None)
if lineno is not None:
n.lineno = lineno - minl
end_lineno = getattr(n, 'end_lineno', None)
if end_lineno is not None:
n.end_lineno = end_lineno - minl
code_lines = lines[minl - 1:maxl]
# Attempt to clean up surrounding context code.
end_col_offset = getattr(node, 'end_col_offset', None)
if end_col_offset is not None:
# This is only available in 3.8.
code_lines[-1] = code_lines[-1][:end_col_offset]
code_block = '\n'.join(lines[minl - 1:maxl])
col_offset = getattr(node, 'col_offset', None)
if col_offset is None:
# Older Python: try to find the "lambda" token. This is brittle.
match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
if match is not None:
col_offset = match.start(0)
if col_offset is not None:
code_lines[0] = code_lines[0][col_offset:]
code_block = '\n'.join(code_lines)
return node, code_block
def _arg_name(node):
if node is None:
return None
if isinstance(node, gast.Name):
return node.id
assert isinstance(node, str)
return node
def _node_matches_argspec(node, func):
"""Returns True is node fits the argspec of func."""
# TODO(mdan): Use just inspect once support for Python 2 is dropped.
arg_spec = tf_inspect.getfullargspec(func)
node_args = tuple(_arg_name(arg) for arg in node.args.args)
if node_args != tuple(arg_spec.args):
return False
if arg_spec.varargs != _arg_name(node.args.vararg):
return False
if arg_spec.varkw != _arg_name(node.args.kwarg):
return False
node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs)
if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
return False
return True
def _parse_lambda(lam):
"""Returns the AST and source code of given lambda function.
Args:
lam: types.LambdaType, Python function/method/class
Returns:
gast.AST, Text: the parsed AST node; the source code that was parsed to
generate the AST (including any prefixes that this function may have added).
"""
# TODO(mdan): Use a fast path if the definition is not multi-line.
# We could detect that the lambda is in a multi-line expression by looking
# at the surrounding code - an surrounding set of parentheses indicates a
# potential multi-line definition.
mod = inspect.getmodule(lam)
def_line = lam.__code__.co_firstlineno
source = inspect.getsource(mod)
lines = source.split('\n')
# Narrow down to the last node starting before our definition node.
all_nodes = parse(source, preamble_len=0, single_node=False)
search_nodes = []
for node in all_nodes:
# Also include nodes without a line number, for safety. This is defensive -
# we don't know whether such nodes might exist, and if they do, whether
# they are not safe to skip.
# TODO(mdan): Replace this check with an assertion or skip such nodes.
if getattr(node, 'lineno', def_line) <= def_line:
search_nodes.append(node)
else:
# Found a node starting past our lambda - can stop the search.
break
# Extract all lambda nodes from the shortlist.
lambda_nodes = []
for node in search_nodes:
lambda_nodes.extend(
n for n in gast.walk(node) if isinstance(n, gast.Lambda))
# Filter down to lambda nodes which span our actual lambda.
candidates = []
for ln in lambda_nodes:
minl, maxl = MAX_SIZE, 0
for n in gast.walk(ln):
minl = min(minl, getattr(n, 'lineno', minl))
lineno = getattr(n, 'lineno', maxl)
end_lineno = getattr(n, 'end_lineno', None)
if end_lineno is not None:
# end_lineno is more precise, but lineno should almost always work too.
lineno = end_lineno
maxl = max(maxl, lineno)
if minl <= def_line <= maxl:
candidates.append((ln, minl, maxl))
# Happy path: exactly one node found.
if len(candidates) == 1:
(node, minl, maxl), = candidates # pylint:disable=unbalanced-tuple-unpacking
return _without_context(node, lines, minl, maxl)
elif not candidates:
raise errors.UnsupportedLanguageElementError(
'could not parse the source code of {}:'
' no matching AST found'.format(lam))
# Attempt to narrow down selection by signature is multiple nodes are found.
matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
if len(matches) == 1:
(node, minl, maxl), = matches
return _without_context(node, lines, minl, maxl)
# Give up if could not narrow down to a single node.
matches = '\n'.join(
'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
for i, (node, _, _) in enumerate(matches))
raise errors.UnsupportedLanguageElementError(
'could not parse the source code of {}: found multiple definitions with'
' identical signatures at the location. This error'
' may be avoided by defining each lambda on a single line and with'
' unique argument names.\n{}'.format(lam, matches))
# TODO(mdan): This should take futures as input instead.

View File

@ -18,14 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import textwrap
import gast
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
# Version notice: these tests pass in Python 3.7. They will fail in 3.8, where
# the parser is able to clean up the trailing garbage.
# TODO(mdan): Update the tests to work in 3.8 as well.
class ParserTest(test.TestCase):
def test_parse_entity(self):
@ -36,6 +43,171 @@ class ParserTest(test.TestCase):
node, _ = parser.parse_entity(f, future_features=())
self.assertEqual('f', node.name)
def test_parse_lambda(self):
l = lambda x: x + 1
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (x + 1))')
self.assertEqual(source, 'lambda x: x + 1')
def test_parse_lambda_prefix_cleanup(self):
lambda_lam = lambda x: x + 1
node, source = parser.parse_entity(lambda_lam, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (x + 1))')
self.assertEqual(source, 'lambda x: x + 1')
def test_parse_lambda_resolution_by_location(self):
_ = lambda x: x + 1
l = lambda x: x + 1
_ = lambda x: x + 1
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (x + 1))')
self.assertEqual(source, 'lambda x: x + 1')
def test_parse_lambda_resolution_by_signature(self):
l = lambda x: lambda x, y: x + y
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (lambda x, y: (x + y)))')
self.assertEqual(source, 'lambda x: lambda x, y: x + y')
node, source = parser.parse_entity(l(0), future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x, y: (x + y))')
self.assertEqual(source, 'lambda x, y: x + y')
def test_parse_lambda_resolution_ambiguous(self):
l = lambda x: lambda x: 2 * x
expected_exception_text = re.compile(r'found multiple definitions'
r'.+'
r'\(lambda x: \(lambda x'
r'.+'
r'\(lambda x: \(2', re.DOTALL)
with self.assertRaisesRegex(
errors.UnsupportedLanguageElementError,
expected_exception_text):
parser.parse_entity(l, future_features=())
with self.assertRaisesRegex(
errors.UnsupportedLanguageElementError,
expected_exception_text):
parser.parse_entity(l(0), future_features=())
def test_parse_lambda_multiline(self):
l = (
lambda x: lambda y: x + y # pylint:disable=g-long-lambda
- 1)
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (lambda y: ((x + y) - 1)))')
self.assertEqual(
source,
'lambda x: lambda y: x + y # pylint:disable=g-long-lambda\n'
' - 1)')
node, source = parser.parse_entity(l(0), future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda y: ((x + y) - 1))')
self.assertEqual(
source,
'lambda y: x + y # pylint:disable=g-long-lambda\n'
' - 1)')
def test_parse_lambda_in_expression(self):
l = (
lambda x: lambda y: x + y + 1,
lambda x: lambda y: x + y + 2,
)
node, source = parser.parse_entity(l[0], future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (lambda y: ((x + y) + 1)))')
self.assertEqual(source, 'lambda x: lambda y: x + y + 1,')
node, source = parser.parse_entity(l[0](0), future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda y: ((x + y) + 1))')
self.assertEqual(source, 'lambda y: x + y + 1,')
node, source = parser.parse_entity(l[1], future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: (lambda y: ((x + y) + 2)))')
self.assertEqual(source, 'lambda x: lambda y: x + y + 2,')
node, source = parser.parse_entity(l[1](0), future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda y: ((x + y) + 2))')
self.assertEqual(source, 'lambda y: x + y + 2,')
def test_parse_lambda_complex_body(self):
l = lambda x: ( # pylint:disable=g-long-lambda
x.y(
[],
x.z,
(),
x[0:2],
),
x.u,
'abc',
1,
)
node, source = parser.parse_entity(l, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
"(lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1))")
self.assertEqual(source, ('lambda x: ( # pylint:disable=g-long-lambda\n'
' x.y(\n'
' [],\n'
' x.z,\n'
' (),\n'
' x[0:2],\n'
' ),\n'
' x.u,\n'
' \'abc\',\n'
' 1,'))
def test_parse_lambda_function_call_definition(self):
def do_parse_and_test(lam, **unused_kwargs):
node, source = parser.parse_entity(lam, future_features=())
self.assertEqual(
parser.unparse(node, include_encoding_marker=False),
'(lambda x: x)')
self.assertEqual(source, 'lambda x: x, named_arg=1)')
do_parse_and_test( # Intentional line break
lambda x: x, named_arg=1)
def test_parse_entity_print_function(self):
def f(x):
@ -47,7 +219,7 @@ class ParserTest(test.TestCase):
def test_parse_comments(self):
def f():
# unindented comment
# unindented comment
pass
node, _ = parser.parse_entity(f, future_features=())

View File

@ -23,7 +23,6 @@ import types
import gast
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import loader
@ -303,25 +302,6 @@ class FunctionTranspiler(object):
node, source = parser.parse_entity(fn, future_features=future_features)
logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)
# In general, the output of inspect.getsource is inexact for lambdas
# because it uses regex matching to adjust the exact location around
# the line number that CPython records. Then, the entire containing line
# is returned, which we may have trouble disambiguating.
# For example:
# x, y = lambda: 1, lambda: 2
is_lambda = fn.__name__ == '<lambda>'
if is_lambda:
nodes = ast_util.find_matching_definitions(node, fn)
if len(nodes) != 1:
raise ValueError(
'Unable to identify source code of lambda function {}.'
' It was defined in this code:\n'
'{}\n'
'This code must contain a single distinguishable lambda.'
' To avoid this problem, define each lambda in a separate'
' expression.'.format(fn, source))
node, = nodes
origin_info.resolve_entity(node, source, fn)
namespace = inspect_utils.getnamespace(fn)
@ -338,7 +318,7 @@ class FunctionTranspiler(object):
node = self._erase_arg_defaults(node)
node = self.transform_ast(node, context)
if is_lambda:
if isinstance(node, gast.Lambda):
node = gast.Assign(
targets=[
gast.Name(

View File

@ -136,27 +136,6 @@ class FunctionTranspilerTest(test.TestCase):
self.assertEqual(f(1), 1 - 1)
def test_multiple_lambdas_indistinguishable_definitions(self):
a, b = 1, 2
f, _ = (lambda x: a * x, lambda x: b * x)
tr = TestTranspiler()
with self.assertRaises(ValueError):
tr.transform_function(f, object(), None, {})
def test_lambda_code_with_removable_garbage(self):
# pylint:disable=g-long-lambda
f = ( # intentional wrap
lambda x: (
x # intentional wrap
+ 1),)[0]
# pylint:enable=g-long-lambda
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 1 - 1)
def test_nested_functions(self):
b = 2