Mark local lambda functions so that they're not being converted twice.

PiperOrigin-RevId: 304804225
Change-Id: I5fd3951de9f7b4ebac79676dc212a254f970505d
This commit is contained in:
Dan Moldovan 2020-04-04 10:58:35 -07:00 committed by TensorFlower Gardener
parent 305a6e7251
commit ca4745d2ad
3 changed files with 74 additions and 61 deletions

View File

@ -38,12 +38,13 @@ class FunctionBodyTransformer(converter.Base):
def visit_Return(self, node):
if node.value is None:
return node
node = self.generic_visit(node)
return templates.replace(
'return function_context_name.mark_return_value(value)',
function_context_name=self.state[_Function].context_name,
value=node.value)
def _function_scope_options(self):
def _function_scope_options(self, fn_scope):
"""Returns the options with which to create function scopes."""
# Top-level function receive the options that were directly requested.
# All others receive the options corresponding to a recursive conversion.
@ -51,81 +52,79 @@ class FunctionBodyTransformer(converter.Base):
# primarily because the FunctionScope context also creates a
# ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
# function_wrappers.py.
if self.state[_Function].level == 2:
if fn_scope.level == 2:
return self.ctx.program.options
return self.ctx.program.options.call_options()
def visit_Lambda(self, node):
self.state[_Function].enter()
node = self.generic_visit(node)
with self.state[_Function] as fn_scope:
node = self.generic_visit(node)
# Only wrap the top-level function. Theoretically, we can and should wrap
# everything, but that can lead to excessive boilerplate when lambdas are
# nested.
# TODO(mdan): Looks more closely for use cases that actually require this.
if fn_scope.level > 2:
return templates.replace_as_expression(
'ag__.autograph_artifact(l)', l=node)
scope = anno.getanno(node, anno.Static.SCOPE)
function_context_name = self.ctx.namer.new_symbol('lscope',
scope.referenced)
fn_scope.context_name = function_context_name
anno.setanno(node, 'function_context_name', function_context_name)
template = """
ag__.with_function_scope(
lambda function_context: body, function_context_name, options)
"""
node.body = templates.replace_as_expression(
template,
options=self._function_scope_options(fn_scope).to_ast(),
function_context=function_context_name,
function_context_name=gast.Constant(function_context_name, kind=None),
body=node.body)
# Only wrap the top-level function. Theoretically, we can and should wrap
# everything, but that can lead to excessive boilerplate when lambdas are
# nested.
# TODO(mdan): Looks more closely for use cases that actually require this.
if self.state[_Function].level > 2:
self.state[_Function].exit()
return node
scope = anno.getanno(node, anno.Static.SCOPE)
function_context_name = self.ctx.namer.new_symbol('lscope',
scope.referenced)
self.state[_Function].context_name = function_context_name
anno.setanno(node, 'function_context_name', function_context_name)
template = """
ag__.with_function_scope(
lambda function_context: body, function_context_name, options)
"""
node.body = templates.replace_as_expression(
template,
options=self._function_scope_options().to_ast(),
function_context=function_context_name,
function_context_name=gast.Constant(function_context_name, kind=None),
body=node.body)
self.state[_Function].exit()
return node
def visit_FunctionDef(self, node):
self.state[_Function].enter()
scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
with self.state[_Function] as fn_scope:
scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
function_context_name = self.ctx.namer.new_symbol('fscope',
scope.referenced)
self.state[_Function].context_name = function_context_name
anno.setanno(node, 'function_context_name', function_context_name)
function_context_name = self.ctx.namer.new_symbol('fscope',
scope.referenced)
fn_scope.context_name = function_context_name
anno.setanno(node, 'function_context_name', function_context_name)
node = self.generic_visit(node)
node = self.generic_visit(node)
docstring_node = None
if node.body:
first_statement = node.body[0]
if (isinstance(first_statement, gast.Expr) and
isinstance(first_statement.value, gast.Constant)):
docstring_node = first_statement
node.body = node.body[1:]
docstring_node = None
if node.body:
first_statement = node.body[0]
if (isinstance(first_statement, gast.Expr) and
isinstance(first_statement.value, gast.Constant)):
docstring_node = first_statement
node.body = node.body[1:]
template = """
with ag__.FunctionScope(
function_name, context_name, options) as function_context:
body
"""
wrapped_body = templates.replace(
template,
function_name=gast.Constant(node.name, kind=None),
context_name=gast.Constant(function_context_name, kind=None),
options=self._function_scope_options().to_ast(),
function_context=function_context_name,
body=node.body)
template = """
with ag__.FunctionScope(
function_name, context_name, options) as function_context:
body
"""
wrapped_body = templates.replace(
template,
function_name=gast.Constant(node.name, kind=None),
context_name=gast.Constant(function_context_name, kind=None),
options=self._function_scope_options(fn_scope).to_ast(),
function_context=function_context_name,
body=node.body)
if docstring_node is not None:
wrapped_body = [docstring_node] + wrapped_body
if docstring_node is not None:
wrapped_body = [docstring_node] + wrapped_body
node.body = wrapped_body
node.body = wrapped_body
self.state[_Function].exit()
return node
return node
def transform(node, ctx):

View File

@ -126,6 +126,15 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
self.assertNotIn('inner_fn', first.op.name)
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
def test_lambda_in_return_value(self):
def test_fn():
return lambda x: x + 1
with self.converted(test_fn, function_scopes, {}) as result:
result_l = result.test_fn()
self.assertTrue(result_l.fake_autograph_artifact)
if __name__ == '__main__':
test.main()

View File

@ -96,6 +96,10 @@ class TestCase(test.TestCase):
kwargs = {}
return f(*args, **kwargs)
def fake_autograph_artifact(f):
setattr(f, 'fake_autograph_artifact', True)
return f
try:
result, source, source_map = loader.load_ast(
node, include_source_map=True)
@ -111,6 +115,7 @@ class TestCase(test.TestCase):
fake_ag.Feature = converter.Feature
fake_ag.utils = utils
fake_ag.FunctionScope = function_wrappers.FunctionScope
fake_ag.autograph_artifact = fake_autograph_artifact
result.ag__ = fake_ag
result.ag_source_map__ = source_map
for k, v in namespace.items():