Mark local lambda functions so that they're not being converted twice.
PiperOrigin-RevId: 304804225 Change-Id: I5fd3951de9f7b4ebac79676dc212a254f970505d
This commit is contained in:
parent
305a6e7251
commit
ca4745d2ad
@ -38,12 +38,13 @@ class FunctionBodyTransformer(converter.Base):
|
|||||||
def visit_Return(self, node):
|
def visit_Return(self, node):
|
||||||
if node.value is None:
|
if node.value is None:
|
||||||
return node
|
return node
|
||||||
|
node = self.generic_visit(node)
|
||||||
return templates.replace(
|
return templates.replace(
|
||||||
'return function_context_name.mark_return_value(value)',
|
'return function_context_name.mark_return_value(value)',
|
||||||
function_context_name=self.state[_Function].context_name,
|
function_context_name=self.state[_Function].context_name,
|
||||||
value=node.value)
|
value=node.value)
|
||||||
|
|
||||||
def _function_scope_options(self):
|
def _function_scope_options(self, fn_scope):
|
||||||
"""Returns the options with which to create function scopes."""
|
"""Returns the options with which to create function scopes."""
|
||||||
# Top-level function receive the options that were directly requested.
|
# Top-level function receive the options that were directly requested.
|
||||||
# All others receive the options corresponding to a recursive conversion.
|
# All others receive the options corresponding to a recursive conversion.
|
||||||
@ -51,81 +52,79 @@ class FunctionBodyTransformer(converter.Base):
|
|||||||
# primarily because the FunctionScope context also creates a
|
# primarily because the FunctionScope context also creates a
|
||||||
# ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
|
# ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
|
||||||
# function_wrappers.py.
|
# function_wrappers.py.
|
||||||
if self.state[_Function].level == 2:
|
if fn_scope.level == 2:
|
||||||
return self.ctx.program.options
|
return self.ctx.program.options
|
||||||
return self.ctx.program.options.call_options()
|
return self.ctx.program.options.call_options()
|
||||||
|
|
||||||
def visit_Lambda(self, node):
|
def visit_Lambda(self, node):
|
||||||
self.state[_Function].enter()
|
with self.state[_Function] as fn_scope:
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
|
|
||||||
|
# Only wrap the top-level function. Theoretically, we can and should wrap
|
||||||
|
# everything, but that can lead to excessive boilerplate when lambdas are
|
||||||
|
# nested.
|
||||||
|
# TODO(mdan): Looks more closely for use cases that actually require this.
|
||||||
|
if fn_scope.level > 2:
|
||||||
|
return templates.replace_as_expression(
|
||||||
|
'ag__.autograph_artifact(l)', l=node)
|
||||||
|
|
||||||
|
scope = anno.getanno(node, anno.Static.SCOPE)
|
||||||
|
function_context_name = self.ctx.namer.new_symbol('lscope',
|
||||||
|
scope.referenced)
|
||||||
|
fn_scope.context_name = function_context_name
|
||||||
|
anno.setanno(node, 'function_context_name', function_context_name)
|
||||||
|
|
||||||
|
template = """
|
||||||
|
ag__.with_function_scope(
|
||||||
|
lambda function_context: body, function_context_name, options)
|
||||||
|
"""
|
||||||
|
node.body = templates.replace_as_expression(
|
||||||
|
template,
|
||||||
|
options=self._function_scope_options(fn_scope).to_ast(),
|
||||||
|
function_context=function_context_name,
|
||||||
|
function_context_name=gast.Constant(function_context_name, kind=None),
|
||||||
|
body=node.body)
|
||||||
|
|
||||||
# Only wrap the top-level function. Theoretically, we can and should wrap
|
|
||||||
# everything, but that can lead to excessive boilerplate when lambdas are
|
|
||||||
# nested.
|
|
||||||
# TODO(mdan): Looks more closely for use cases that actually require this.
|
|
||||||
if self.state[_Function].level > 2:
|
|
||||||
self.state[_Function].exit()
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
scope = anno.getanno(node, anno.Static.SCOPE)
|
|
||||||
function_context_name = self.ctx.namer.new_symbol('lscope',
|
|
||||||
scope.referenced)
|
|
||||||
self.state[_Function].context_name = function_context_name
|
|
||||||
anno.setanno(node, 'function_context_name', function_context_name)
|
|
||||||
|
|
||||||
template = """
|
|
||||||
ag__.with_function_scope(
|
|
||||||
lambda function_context: body, function_context_name, options)
|
|
||||||
"""
|
|
||||||
node.body = templates.replace_as_expression(
|
|
||||||
template,
|
|
||||||
options=self._function_scope_options().to_ast(),
|
|
||||||
function_context=function_context_name,
|
|
||||||
function_context_name=gast.Constant(function_context_name, kind=None),
|
|
||||||
body=node.body)
|
|
||||||
|
|
||||||
self.state[_Function].exit()
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
self.state[_Function].enter()
|
with self.state[_Function] as fn_scope:
|
||||||
scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||||
|
|
||||||
function_context_name = self.ctx.namer.new_symbol('fscope',
|
function_context_name = self.ctx.namer.new_symbol('fscope',
|
||||||
scope.referenced)
|
scope.referenced)
|
||||||
self.state[_Function].context_name = function_context_name
|
fn_scope.context_name = function_context_name
|
||||||
anno.setanno(node, 'function_context_name', function_context_name)
|
anno.setanno(node, 'function_context_name', function_context_name)
|
||||||
|
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
|
|
||||||
docstring_node = None
|
docstring_node = None
|
||||||
if node.body:
|
if node.body:
|
||||||
first_statement = node.body[0]
|
first_statement = node.body[0]
|
||||||
if (isinstance(first_statement, gast.Expr) and
|
if (isinstance(first_statement, gast.Expr) and
|
||||||
isinstance(first_statement.value, gast.Constant)):
|
isinstance(first_statement.value, gast.Constant)):
|
||||||
docstring_node = first_statement
|
docstring_node = first_statement
|
||||||
node.body = node.body[1:]
|
node.body = node.body[1:]
|
||||||
|
|
||||||
template = """
|
template = """
|
||||||
with ag__.FunctionScope(
|
with ag__.FunctionScope(
|
||||||
function_name, context_name, options) as function_context:
|
function_name, context_name, options) as function_context:
|
||||||
body
|
body
|
||||||
"""
|
"""
|
||||||
wrapped_body = templates.replace(
|
wrapped_body = templates.replace(
|
||||||
template,
|
template,
|
||||||
function_name=gast.Constant(node.name, kind=None),
|
function_name=gast.Constant(node.name, kind=None),
|
||||||
context_name=gast.Constant(function_context_name, kind=None),
|
context_name=gast.Constant(function_context_name, kind=None),
|
||||||
options=self._function_scope_options().to_ast(),
|
options=self._function_scope_options(fn_scope).to_ast(),
|
||||||
function_context=function_context_name,
|
function_context=function_context_name,
|
||||||
body=node.body)
|
body=node.body)
|
||||||
|
|
||||||
if docstring_node is not None:
|
if docstring_node is not None:
|
||||||
wrapped_body = [docstring_node] + wrapped_body
|
wrapped_body = [docstring_node] + wrapped_body
|
||||||
|
|
||||||
node.body = wrapped_body
|
node.body = wrapped_body
|
||||||
|
|
||||||
self.state[_Function].exit()
|
return node
|
||||||
return node
|
|
||||||
|
|
||||||
|
|
||||||
def transform(node, ctx):
|
def transform(node, ctx):
|
||||||
|
@ -126,6 +126,15 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
|||||||
self.assertNotIn('inner_fn', first.op.name)
|
self.assertNotIn('inner_fn', first.op.name)
|
||||||
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
|
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
|
||||||
|
|
||||||
|
def test_lambda_in_return_value(self):
|
||||||
|
|
||||||
|
def test_fn():
|
||||||
|
return lambda x: x + 1
|
||||||
|
|
||||||
|
with self.converted(test_fn, function_scopes, {}) as result:
|
||||||
|
result_l = result.test_fn()
|
||||||
|
self.assertTrue(result_l.fake_autograph_artifact)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -96,6 +96,10 @@ class TestCase(test.TestCase):
|
|||||||
kwargs = {}
|
kwargs = {}
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
def fake_autograph_artifact(f):
|
||||||
|
setattr(f, 'fake_autograph_artifact', True)
|
||||||
|
return f
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result, source, source_map = loader.load_ast(
|
result, source, source_map = loader.load_ast(
|
||||||
node, include_source_map=True)
|
node, include_source_map=True)
|
||||||
@ -111,6 +115,7 @@ class TestCase(test.TestCase):
|
|||||||
fake_ag.Feature = converter.Feature
|
fake_ag.Feature = converter.Feature
|
||||||
fake_ag.utils = utils
|
fake_ag.utils = utils
|
||||||
fake_ag.FunctionScope = function_wrappers.FunctionScope
|
fake_ag.FunctionScope = function_wrappers.FunctionScope
|
||||||
|
fake_ag.autograph_artifact = fake_autograph_artifact
|
||||||
result.ag__ = fake_ag
|
result.ag__ = fake_ag
|
||||||
result.ag_source_map__ = source_map
|
result.ag_source_map__ = source_map
|
||||||
for k, v in namespace.items():
|
for k, v in namespace.items():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user