Replace safeguard against aliasing cond variables with a dynamic check. It is more verbose but more robust in the presence of exception-driven control flow.
PiperOrigin-RevId: 310169792 Change-Id: If4c918c38daea1aa270b7d89d4ea9372707766f0
This commit is contained in:
parent
6bc4819e08
commit
8181d041c8
tensorflow/python/autograph
@ -70,18 +70,33 @@ class ControlFlowTransformer(converter.Base):
|
||||
return_stmt = templates.replace(template, retvals=returns)
|
||||
|
||||
if aliased_orig_names:
|
||||
alias_declarations = []
|
||||
for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
|
||||
template = """
|
||||
try:
|
||||
aliased_new_name = aliased_orig_name
|
||||
except NameError:
|
||||
aliased_new_name = ag__.Undefined(symbol_name)
|
||||
"""
|
||||
|
||||
alias_declarations.extend(
|
||||
templates.replace(
|
||||
template,
|
||||
aliased_new_name=new_name,
|
||||
aliased_orig_name=old_name,
|
||||
symbol_name=gast.Constant(str(old_name), kind=None)))
|
||||
|
||||
template = """
|
||||
def body_name():
|
||||
aliased_new_names, = aliased_orig_names,
|
||||
alias_declarations
|
||||
body
|
||||
return_stmt
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
alias_declarations=alias_declarations,
|
||||
body_name=body_name,
|
||||
body=body,
|
||||
aliased_orig_names=aliased_orig_names,
|
||||
aliased_new_names=aliased_new_names,
|
||||
return_stmt=return_stmt)
|
||||
else:
|
||||
template = """
|
||||
@ -132,13 +147,8 @@ class ControlFlowTransformer(converter.Base):
|
||||
return 'no variables'
|
||||
return ', '.join(map(str, symbol_set))
|
||||
|
||||
def _determine_aliased_symbols(self, scope, node_defined_in, block):
|
||||
if block:
|
||||
block_live_in = set(anno.getanno(block[0], anno.Static.LIVE_VARS_IN))
|
||||
else:
|
||||
block_live_in = set()
|
||||
|
||||
modified_live = scope.modified & node_defined_in & block_live_in
|
||||
def _determine_aliased_symbols(self, scope, node_defined_in):
|
||||
modified_live = scope.modified & node_defined_in
|
||||
# Composite symbols are handled elsewhere, see _create_state_functions
|
||||
return {
|
||||
s for s in modified_live
|
||||
@ -221,9 +231,9 @@ class ControlFlowTransformer(converter.Base):
|
||||
# that happens in the call to generic_visit below, because the conversion
|
||||
# generates nodes that lack static analysis annotations.
|
||||
need_alias_in_body = self._determine_aliased_symbols(
|
||||
body_scope, defined_in, node.body)
|
||||
body_scope, defined_in)
|
||||
need_alias_in_orelse = self._determine_aliased_symbols(
|
||||
orelse_scope, defined_in, node.orelse)
|
||||
orelse_scope, defined_in)
|
||||
|
||||
node = self.generic_visit(node)
|
||||
|
||||
|
@ -57,18 +57,33 @@ class ControlFlowTransformer(converter.Base):
|
||||
return_stmt = templates.replace(template, retvals=returns)
|
||||
|
||||
if aliased_orig_names:
|
||||
alias_declarations = []
|
||||
for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
|
||||
template = """
|
||||
try:
|
||||
aliased_new_name = aliased_orig_name
|
||||
except NameError:
|
||||
aliased_new_name = ag__.Undefined(symbol_name)
|
||||
"""
|
||||
|
||||
alias_declarations.extend(
|
||||
templates.replace(
|
||||
template,
|
||||
aliased_new_name=new_name,
|
||||
aliased_orig_name=old_name,
|
||||
symbol_name=gast.Constant(str(old_name), kind=None)))
|
||||
|
||||
template = """
|
||||
def body_name():
|
||||
aliased_new_names, = aliased_orig_names,
|
||||
alias_declarations
|
||||
body
|
||||
return_stmt
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
alias_declarations=alias_declarations,
|
||||
body_name=body_name,
|
||||
body=body,
|
||||
aliased_orig_names=aliased_orig_names,
|
||||
aliased_new_names=aliased_new_names,
|
||||
return_stmt=return_stmt)
|
||||
else:
|
||||
template = """
|
||||
@ -119,13 +134,8 @@ class ControlFlowTransformer(converter.Base):
|
||||
return 'no variables'
|
||||
return ', '.join(map(str, symbol_set))
|
||||
|
||||
def _determine_aliased_symbols(self, scope, node_defined_in, block):
|
||||
if block:
|
||||
block_live_in = set(anno.getanno(block[0], anno.Static.LIVE_VARS_IN))
|
||||
else:
|
||||
block_live_in = set()
|
||||
|
||||
modified_live = scope.modified & node_defined_in & block_live_in
|
||||
def _determine_aliased_symbols(self, scope, node_defined_in):
|
||||
modified_live = scope.modified & node_defined_in
|
||||
# Composite symbols are handled elsewhere see _create_state_functions
|
||||
return {s for s in modified_live if not s.is_composite()}
|
||||
|
||||
@ -196,9 +206,9 @@ class ControlFlowTransformer(converter.Base):
|
||||
# that happens in the call to generic_visit below, because the conversion
|
||||
# generates nodes that lack static analysis annotations.
|
||||
need_alias_in_body = self._determine_aliased_symbols(
|
||||
body_scope, defined_in, node.body)
|
||||
body_scope, defined_in)
|
||||
need_alias_in_orelse = self._determine_aliased_symbols(
|
||||
orelse_scope, defined_in, node.orelse)
|
||||
orelse_scope, defined_in)
|
||||
|
||||
node = self.generic_visit(node)
|
||||
|
||||
|
@ -121,6 +121,12 @@ class SymbolRenamer(gast.NodeTransformer):
|
||||
# Renaming attributes is not supported.
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
qn = qual_names.QN(node.name)
|
||||
if qn in self.name_map:
|
||||
node.name = str(self.name_map[qn])
|
||||
return self.generic_visit(node)
|
||||
|
||||
|
||||
def rename_symbols(node, name_map):
|
||||
"""Renames symbols in an AST. Requires qual_names annotations."""
|
||||
|
@ -90,6 +90,14 @@ class AstUtilTest(test.TestCase):
|
||||
|
||||
self.assertIs(anno.getanno(node, 'foo'), orig_anno)
|
||||
|
||||
def test_rename_symbols_function(self):
|
||||
node = parser.parse('def f():\n pass')
|
||||
node = ast_util.rename_symbols(node,
|
||||
{qual_names.QN('f'): qual_names.QN('f1')})
|
||||
|
||||
source = parser.unparse(node, include_encoding_marker=False)
|
||||
self.assertEqual(source.strip(), 'def f1():\n pass')
|
||||
|
||||
def test_copy_clean(self):
|
||||
node = parser.parse(
|
||||
textwrap.dedent("""
|
||||
|
@ -29,6 +29,7 @@ notable exception:
|
||||
raise (i.e. a function call in the middle of a block does not return or jump
|
||||
to any except or finally block)
|
||||
TODO(mdan): Consider adding the edges above. They'd only add ~O(n) edges.
|
||||
TODO(mdan): Alternatively, consider adding an edge from try to all its excepts.
|
||||
"""
|
||||
|
||||
# TODO(mdan): The notion of 'statements' below is inaccurate.
|
||||
|
Loading…
Reference in New Issue
Block a user