diff --git a/tensorflow/python/autograph/converters/function_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py index 100a14e4494..cc1cbcb98e4 100644 --- a/tensorflow/python/autograph/converters/function_scopes.py +++ b/tensorflow/python/autograph/converters/function_scopes.py @@ -38,12 +38,13 @@ class FunctionBodyTransformer(converter.Base): def visit_Return(self, node): if node.value is None: return node + node = self.generic_visit(node) return templates.replace( 'return function_context_name.mark_return_value(value)', function_context_name=self.state[_Function].context_name, value=node.value) - def _function_scope_options(self): + def _function_scope_options(self, fn_scope): """Returns the options with which to create function scopes.""" # Top-level function receive the options that were directly requested. # All others receive the options corresponding to a recursive conversion. @@ -51,81 +52,79 @@ class FunctionBodyTransformer(converter.Base): # primarily because the FunctionScope context also creates a # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See # function_wrappers.py. - if self.state[_Function].level == 2: + if fn_scope.level == 2: return self.ctx.program.options return self.ctx.program.options.call_options() def visit_Lambda(self, node): - self.state[_Function].enter() - node = self.generic_visit(node) + with self.state[_Function] as fn_scope: + node = self.generic_visit(node) + + # Only wrap the top-level function. Theoretically, we can and should wrap + # everything, but that can lead to excessive boilerplate when lambdas are + # nested. + # TODO(mdan): Looks more closely for use cases that actually require this. + if fn_scope.level > 2: + return templates.replace_as_expression( + 'ag__.autograph_artifact(l)', l=node) + + scope = anno.getanno(node, anno.Static.SCOPE) + function_context_name = self.ctx.namer.new_symbol('lscope', + scope.referenced) + fn_scope.context_name = function_context_name + anno.setanno(node, 'function_context_name', function_context_name) + + template = """ + ag__.with_function_scope( + lambda function_context: body, function_context_name, options) + """ + node.body = templates.replace_as_expression( + template, + options=self._function_scope_options(fn_scope).to_ast(), + function_context=function_context_name, + function_context_name=gast.Constant(function_context_name, kind=None), + body=node.body) - # Only wrap the top-level function. Theoretically, we can and should wrap - # everything, but that can lead to excessive boilerplate when lambdas are - # nested. - # TODO(mdan): Looks more closely for use cases that actually require this. - if self.state[_Function].level > 2: - self.state[_Function].exit() return node - scope = anno.getanno(node, anno.Static.SCOPE) - function_context_name = self.ctx.namer.new_symbol('lscope', - scope.referenced) - self.state[_Function].context_name = function_context_name - anno.setanno(node, 'function_context_name', function_context_name) - - template = """ - ag__.with_function_scope( - lambda function_context: body, function_context_name, options) - """ - node.body = templates.replace_as_expression( - template, - options=self._function_scope_options().to_ast(), - function_context=function_context_name, - function_context_name=gast.Constant(function_context_name, kind=None), - body=node.body) - - self.state[_Function].exit() - return node - def visit_FunctionDef(self, node): - self.state[_Function].enter() - scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + with self.state[_Function] as fn_scope: + scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - function_context_name = self.ctx.namer.new_symbol('fscope', - scope.referenced) - self.state[_Function].context_name = function_context_name - anno.setanno(node, 'function_context_name', function_context_name) + function_context_name = self.ctx.namer.new_symbol('fscope', + scope.referenced) + fn_scope.context_name = function_context_name + anno.setanno(node, 'function_context_name', function_context_name) - node = self.generic_visit(node) + node = self.generic_visit(node) - docstring_node = None - if node.body: - first_statement = node.body[0] - if (isinstance(first_statement, gast.Expr) and - isinstance(first_statement.value, gast.Constant)): - docstring_node = first_statement - node.body = node.body[1:] + docstring_node = None + if node.body: + first_statement = node.body[0] + if (isinstance(first_statement, gast.Expr) and + isinstance(first_statement.value, gast.Constant)): + docstring_node = first_statement + node.body = node.body[1:] - template = """ - with ag__.FunctionScope( - function_name, context_name, options) as function_context: - body - """ - wrapped_body = templates.replace( - template, - function_name=gast.Constant(node.name, kind=None), - context_name=gast.Constant(function_context_name, kind=None), - options=self._function_scope_options().to_ast(), - function_context=function_context_name, - body=node.body) + template = """ + with ag__.FunctionScope( + function_name, context_name, options) as function_context: + body + """ + wrapped_body = templates.replace( + template, + function_name=gast.Constant(node.name, kind=None), + context_name=gast.Constant(function_context_name, kind=None), + options=self._function_scope_options(fn_scope).to_ast(), + function_context=function_context_name, + body=node.body) - if docstring_node is not None: - wrapped_body = [docstring_node] + wrapped_body + if docstring_node is not None: + wrapped_body = [docstring_node] + wrapped_body - node.body = wrapped_body + node.body = wrapped_body - self.state[_Function].exit() - return node + return node def transform(node, ctx): diff --git a/tensorflow/python/autograph/converters/function_scopes_test.py b/tensorflow/python/autograph/converters/function_scopes_test.py index 9c8939a6132..699766b9fbf 100644 --- a/tensorflow/python/autograph/converters/function_scopes_test.py +++ b/tensorflow/python/autograph/converters/function_scopes_test.py @@ -126,6 +126,15 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): self.assertNotIn('inner_fn', first.op.name) self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name) + def test_lambda_in_return_value(self): + + def test_fn(): + return lambda x: x + 1 + + with self.converted(test_fn, function_scopes, {}) as result: + result_l = result.test_fn() + self.assertTrue(result_l.fake_autograph_artifact) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 4b170159b8b..8afcbdfb6bd 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -96,6 +96,10 @@ class TestCase(test.TestCase): kwargs = {} return f(*args, **kwargs) + def fake_autograph_artifact(f): + setattr(f, 'fake_autograph_artifact', True) + return f + try: result, source, source_map = loader.load_ast( node, include_source_map=True) @@ -111,6 +115,7 @@ class TestCase(test.TestCase): fake_ag.Feature = converter.Feature fake_ag.utils = utils fake_ag.FunctionScope = function_wrappers.FunctionScope + fake_ag.autograph_artifact = fake_autograph_artifact result.ag__ = fake_ag result.ag_source_map__ = source_map for k, v in namespace.items():