Replace safeguard against aliasing cond variables with a dynamic check. It is more verbose but more robust in the presence of exception-driven control flow.

PiperOrigin-RevId: 310169792
Change-Id: If4c918c38daea1aa270b7d89d4ea9372707766f0
This commit is contained in:
Dan Moldovan 2020-05-06 09:44:36 -07:00 committed by TensorFlower Gardener
parent 6bc4819e08
commit 8181d041c8
5 changed files with 59 additions and 24 deletions

View File

@ -70,18 +70,33 @@ class ControlFlowTransformer(converter.Base):
return_stmt = templates.replace(template, retvals=returns)
if aliased_orig_names:
alias_declarations = []
for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
template = """
try:
aliased_new_name = aliased_orig_name
except NameError:
aliased_new_name = ag__.Undefined(symbol_name)
"""
alias_declarations.extend(
templates.replace(
template,
aliased_new_name=new_name,
aliased_orig_name=old_name,
symbol_name=gast.Constant(str(old_name), kind=None)))
template = """
def body_name():
aliased_new_names, = aliased_orig_names,
alias_declarations
body
return_stmt
"""
return templates.replace(
template,
alias_declarations=alias_declarations,
body_name=body_name,
body=body,
aliased_orig_names=aliased_orig_names,
aliased_new_names=aliased_new_names,
return_stmt=return_stmt)
else:
template = """
@ -132,13 +147,8 @@ class ControlFlowTransformer(converter.Base):
return 'no variables'
return ', '.join(map(str, symbol_set))
def _determine_aliased_symbols(self, scope, node_defined_in, block):
if block:
block_live_in = set(anno.getanno(block[0], anno.Static.LIVE_VARS_IN))
else:
block_live_in = set()
modified_live = scope.modified & node_defined_in & block_live_in
def _determine_aliased_symbols(self, scope, node_defined_in):
modified_live = scope.modified & node_defined_in
# Composite symbols are handled elsewhere, see _create_state_functions
return {
s for s in modified_live
@ -221,9 +231,9 @@ class ControlFlowTransformer(converter.Base):
# that happens in the call to generic_visit below, because the conversion
# generates nodes that lack static analysis annotations.
need_alias_in_body = self._determine_aliased_symbols(
body_scope, defined_in, node.body)
body_scope, defined_in)
need_alias_in_orelse = self._determine_aliased_symbols(
orelse_scope, defined_in, node.orelse)
orelse_scope, defined_in)
node = self.generic_visit(node)

View File

@ -57,18 +57,33 @@ class ControlFlowTransformer(converter.Base):
return_stmt = templates.replace(template, retvals=returns)
if aliased_orig_names:
alias_declarations = []
for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
template = """
try:
aliased_new_name = aliased_orig_name
except NameError:
aliased_new_name = ag__.Undefined(symbol_name)
"""
alias_declarations.extend(
templates.replace(
template,
aliased_new_name=new_name,
aliased_orig_name=old_name,
symbol_name=gast.Constant(str(old_name), kind=None)))
template = """
def body_name():
aliased_new_names, = aliased_orig_names,
alias_declarations
body
return_stmt
"""
return templates.replace(
template,
alias_declarations=alias_declarations,
body_name=body_name,
body=body,
aliased_orig_names=aliased_orig_names,
aliased_new_names=aliased_new_names,
return_stmt=return_stmt)
else:
template = """
@ -119,13 +134,8 @@ class ControlFlowTransformer(converter.Base):
return 'no variables'
return ', '.join(map(str, symbol_set))
def _determine_aliased_symbols(self, scope, node_defined_in, block):
if block:
block_live_in = set(anno.getanno(block[0], anno.Static.LIVE_VARS_IN))
else:
block_live_in = set()
modified_live = scope.modified & node_defined_in & block_live_in
def _determine_aliased_symbols(self, scope, node_defined_in):
modified_live = scope.modified & node_defined_in
# Composite symbols are handled elsewhere see _create_state_functions
return {s for s in modified_live if not s.is_composite()}
@ -196,9 +206,9 @@ class ControlFlowTransformer(converter.Base):
# that happens in the call to generic_visit below, because the conversion
# generates nodes that lack static analysis annotations.
need_alias_in_body = self._determine_aliased_symbols(
body_scope, defined_in, node.body)
body_scope, defined_in)
need_alias_in_orelse = self._determine_aliased_symbols(
orelse_scope, defined_in, node.orelse)
orelse_scope, defined_in)
node = self.generic_visit(node)

View File

@ -121,6 +121,12 @@ class SymbolRenamer(gast.NodeTransformer):
# Renaming attributes is not supported.
return self.generic_visit(node)
def visit_FunctionDef(self, node):
qn = qual_names.QN(node.name)
if qn in self.name_map:
node.name = str(self.name_map[qn])
return self.generic_visit(node)
def rename_symbols(node, name_map):
"""Renames symbols in an AST. Requires qual_names annotations."""

View File

@ -90,6 +90,14 @@ class AstUtilTest(test.TestCase):
self.assertIs(anno.getanno(node, 'foo'), orig_anno)
def test_rename_symbols_function(self):
node = parser.parse('def f():\n pass')
node = ast_util.rename_symbols(node,
{qual_names.QN('f'): qual_names.QN('f1')})
source = parser.unparse(node, include_encoding_marker=False)
self.assertEqual(source.strip(), 'def f1():\n pass')
def test_copy_clean(self):
node = parser.parse(
textwrap.dedent("""

View File

@ -29,6 +29,7 @@ notable exception:
raise (i.e. a function call in the middle of a block does not return or jump
to any except or finally block)
TODO(mdan): Consider adding the edges above. They'd only add ~O(n) edges.
TODO(mdan): Alternatively, consider adding an edge from try to all its excepts.
"""
# TODO(mdan): The notion of 'statements' below is inaccurate.