Mark local lambda functions so that they're not being converted twice.
PiperOrigin-RevId: 304804225 Change-Id: I5fd3951de9f7b4ebac79676dc212a254f970505d
This commit is contained in:
parent
305a6e7251
commit
ca4745d2ad
@ -38,12 +38,13 @@ class FunctionBodyTransformer(converter.Base):
|
||||
def visit_Return(self, node):
|
||||
if node.value is None:
|
||||
return node
|
||||
node = self.generic_visit(node)
|
||||
return templates.replace(
|
||||
'return function_context_name.mark_return_value(value)',
|
||||
function_context_name=self.state[_Function].context_name,
|
||||
value=node.value)
|
||||
|
||||
def _function_scope_options(self):
|
||||
def _function_scope_options(self, fn_scope):
|
||||
"""Returns the options with which to create function scopes."""
|
||||
# Top-level function receive the options that were directly requested.
|
||||
# All others receive the options corresponding to a recursive conversion.
|
||||
@ -51,81 +52,79 @@ class FunctionBodyTransformer(converter.Base):
|
||||
# primarily because the FunctionScope context also creates a
|
||||
# ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
|
||||
# function_wrappers.py.
|
||||
if self.state[_Function].level == 2:
|
||||
if fn_scope.level == 2:
|
||||
return self.ctx.program.options
|
||||
return self.ctx.program.options.call_options()
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
self.state[_Function].enter()
|
||||
node = self.generic_visit(node)
|
||||
with self.state[_Function] as fn_scope:
|
||||
node = self.generic_visit(node)
|
||||
|
||||
# Only wrap the top-level function. Theoretically, we can and should wrap
|
||||
# everything, but that can lead to excessive boilerplate when lambdas are
|
||||
# nested.
|
||||
# TODO(mdan): Looks more closely for use cases that actually require this.
|
||||
if fn_scope.level > 2:
|
||||
return templates.replace_as_expression(
|
||||
'ag__.autograph_artifact(l)', l=node)
|
||||
|
||||
scope = anno.getanno(node, anno.Static.SCOPE)
|
||||
function_context_name = self.ctx.namer.new_symbol('lscope',
|
||||
scope.referenced)
|
||||
fn_scope.context_name = function_context_name
|
||||
anno.setanno(node, 'function_context_name', function_context_name)
|
||||
|
||||
template = """
|
||||
ag__.with_function_scope(
|
||||
lambda function_context: body, function_context_name, options)
|
||||
"""
|
||||
node.body = templates.replace_as_expression(
|
||||
template,
|
||||
options=self._function_scope_options(fn_scope).to_ast(),
|
||||
function_context=function_context_name,
|
||||
function_context_name=gast.Constant(function_context_name, kind=None),
|
||||
body=node.body)
|
||||
|
||||
# Only wrap the top-level function. Theoretically, we can and should wrap
|
||||
# everything, but that can lead to excessive boilerplate when lambdas are
|
||||
# nested.
|
||||
# TODO(mdan): Looks more closely for use cases that actually require this.
|
||||
if self.state[_Function].level > 2:
|
||||
self.state[_Function].exit()
|
||||
return node
|
||||
|
||||
scope = anno.getanno(node, anno.Static.SCOPE)
|
||||
function_context_name = self.ctx.namer.new_symbol('lscope',
|
||||
scope.referenced)
|
||||
self.state[_Function].context_name = function_context_name
|
||||
anno.setanno(node, 'function_context_name', function_context_name)
|
||||
|
||||
template = """
|
||||
ag__.with_function_scope(
|
||||
lambda function_context: body, function_context_name, options)
|
||||
"""
|
||||
node.body = templates.replace_as_expression(
|
||||
template,
|
||||
options=self._function_scope_options().to_ast(),
|
||||
function_context=function_context_name,
|
||||
function_context_name=gast.Constant(function_context_name, kind=None),
|
||||
body=node.body)
|
||||
|
||||
self.state[_Function].exit()
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
self.state[_Function].enter()
|
||||
scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
with self.state[_Function] as fn_scope:
|
||||
scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
|
||||
function_context_name = self.ctx.namer.new_symbol('fscope',
|
||||
scope.referenced)
|
||||
self.state[_Function].context_name = function_context_name
|
||||
anno.setanno(node, 'function_context_name', function_context_name)
|
||||
function_context_name = self.ctx.namer.new_symbol('fscope',
|
||||
scope.referenced)
|
||||
fn_scope.context_name = function_context_name
|
||||
anno.setanno(node, 'function_context_name', function_context_name)
|
||||
|
||||
node = self.generic_visit(node)
|
||||
node = self.generic_visit(node)
|
||||
|
||||
docstring_node = None
|
||||
if node.body:
|
||||
first_statement = node.body[0]
|
||||
if (isinstance(first_statement, gast.Expr) and
|
||||
isinstance(first_statement.value, gast.Constant)):
|
||||
docstring_node = first_statement
|
||||
node.body = node.body[1:]
|
||||
docstring_node = None
|
||||
if node.body:
|
||||
first_statement = node.body[0]
|
||||
if (isinstance(first_statement, gast.Expr) and
|
||||
isinstance(first_statement.value, gast.Constant)):
|
||||
docstring_node = first_statement
|
||||
node.body = node.body[1:]
|
||||
|
||||
template = """
|
||||
with ag__.FunctionScope(
|
||||
function_name, context_name, options) as function_context:
|
||||
body
|
||||
"""
|
||||
wrapped_body = templates.replace(
|
||||
template,
|
||||
function_name=gast.Constant(node.name, kind=None),
|
||||
context_name=gast.Constant(function_context_name, kind=None),
|
||||
options=self._function_scope_options().to_ast(),
|
||||
function_context=function_context_name,
|
||||
body=node.body)
|
||||
template = """
|
||||
with ag__.FunctionScope(
|
||||
function_name, context_name, options) as function_context:
|
||||
body
|
||||
"""
|
||||
wrapped_body = templates.replace(
|
||||
template,
|
||||
function_name=gast.Constant(node.name, kind=None),
|
||||
context_name=gast.Constant(function_context_name, kind=None),
|
||||
options=self._function_scope_options(fn_scope).to_ast(),
|
||||
function_context=function_context_name,
|
||||
body=node.body)
|
||||
|
||||
if docstring_node is not None:
|
||||
wrapped_body = [docstring_node] + wrapped_body
|
||||
if docstring_node is not None:
|
||||
wrapped_body = [docstring_node] + wrapped_body
|
||||
|
||||
node.body = wrapped_body
|
||||
node.body = wrapped_body
|
||||
|
||||
self.state[_Function].exit()
|
||||
return node
|
||||
return node
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
|
@ -126,6 +126,15 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
self.assertNotIn('inner_fn', first.op.name)
|
||||
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
|
||||
|
||||
def test_lambda_in_return_value(self):
|
||||
|
||||
def test_fn():
|
||||
return lambda x: x + 1
|
||||
|
||||
with self.converted(test_fn, function_scopes, {}) as result:
|
||||
result_l = result.test_fn()
|
||||
self.assertTrue(result_l.fake_autograph_artifact)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -96,6 +96,10 @@ class TestCase(test.TestCase):
|
||||
kwargs = {}
|
||||
return f(*args, **kwargs)
|
||||
|
||||
def fake_autograph_artifact(f):
|
||||
setattr(f, 'fake_autograph_artifact', True)
|
||||
return f
|
||||
|
||||
try:
|
||||
result, source, source_map = loader.load_ast(
|
||||
node, include_source_map=True)
|
||||
@ -111,6 +115,7 @@ class TestCase(test.TestCase):
|
||||
fake_ag.Feature = converter.Feature
|
||||
fake_ag.utils = utils
|
||||
fake_ag.FunctionScope = function_wrappers.FunctionScope
|
||||
fake_ag.autograph_artifact = fake_autograph_artifact
|
||||
result.ag__ = fake_ag
|
||||
result.ag_source_map__ = source_map
|
||||
for k, v in namespace.items():
|
||||
|
Loading…
Reference in New Issue
Block a user