Account more thoroughly for the hidden extra_test code in activity analysis. This avoids related referenced-before-assigned errors. Should fix #36462, but that's most practical to verify after the nightly build.

PiperOrigin-RevId: 297737367
Change-Id: I8ae0e462d1a337031ca5ee89888115b49388826c
This commit is contained in:
Dan Moldovan 2020-02-27 18:14:26 -08:00 committed by TensorFlower Gardener
parent cee4c00759
commit a962580295
8 changed files with 24 additions and 11 deletions

View File

@ -143,7 +143,7 @@ class BreakTransformer(converter.Base):
orelse=guarded_orelse) orelse=guarded_orelse)
new_for_node = node[1] new_for_node = node[1]
anno.setanno(new_for_node, 'extra_test', extra_test) anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)
return node return node

View File

@ -462,8 +462,8 @@ class ControlFlowTransformer(converter.Base):
opts = self._create_loop_options(node) opts = self._create_loop_options(node)
if anno.hasanno(node, 'extra_test'): if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
extra_test = anno.getanno(node, 'extra_test') extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
extra_test_name = self.ctx.namer.new_symbol( extra_test_name = self.ctx.namer.new_symbol(
'extra_test', reserved_symbols) 'extra_test', reserved_symbols)
template = """ template = """

View File

@ -487,8 +487,8 @@ class ControlFlowTransformer(converter.Base):
state_functions = self._create_state_functions( state_functions = self._create_state_functions(
composite_loop_vars, state_getter_name, state_setter_name) composite_loop_vars, state_getter_name, state_setter_name)
if anno.hasanno(node, 'extra_test'): if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
extra_test = anno.getanno(node, 'extra_test') extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
extra_test_name = self.ctx.namer.new_symbol( extra_test_name = self.ctx.namer.new_symbol(
'extra_test', reserved_symbols) 'extra_test', reserved_symbols)
template = """ template = """

View File

@ -293,7 +293,7 @@ class ReturnStatementsTransformer(converter.Base):
# Add the check for return to the loop condition. # Add the check for return to the loop condition.
node.body = self._visit_statement_block(node, node.body) node.body = self._visit_statement_block(node, node.body)
if self.state[_Block].return_used: if self.state[_Block].return_used:
extra_test = anno.getanno(node, 'extra_test', default=None) extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
if extra_test is not None: if extra_test is not None:
extra_test = templates.replace_as_expression( extra_test = templates.replace_as_expression(
'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)', 'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
@ -303,7 +303,7 @@ class ReturnStatementsTransformer(converter.Base):
extra_test = templates.replace_as_expression( extra_test = templates.replace_as_expression(
'ag__.not_(control_var)', 'ag__.not_(control_var)',
control_var=self.state[_Function].do_return_var_name) control_var=self.state[_Function].do_return_var_name)
anno.setanno(node, 'extra_test', extra_test) anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
node.orelse = self._visit_statement_block(node, node.orelse) node.orelse = self._visit_statement_block(node, node.orelse)
return node return node
@ -356,16 +356,17 @@ class ReturnStatementsTransformer(converter.Base):
if self.state[_Block].return_used: if self.state[_Block].return_used:
if self.default_to_null_return: if self.default_to_null_return:
# TODO(mdan): Remove the (do_return_var_name,) below.
# Currently, that line ensures the variable is both defined and alive
# throughout the function.
template = """ template = """
do_return_var_name = False do_return_var_name = False
retval_var_name = ag__.UndefinedReturnValue() retval_var_name = ag__.UndefinedReturnValue()
body body
# TODO(b/134753123) Remove the do_return_var_name tuple.
(do_return_var_name,) (do_return_var_name,)
return ag__.retval(retval_var_name) return ag__.retval(retval_var_name)
""" """
else: else:
# TODO(b/134753123) Fix loops that return when do_return is not set.
template = """ template = """
body body
return retval_var_name return retval_var_name

View File

@ -59,6 +59,10 @@ class Basic(NoValue):
DIRECTIVES = ('User directives associated with a statement or a variable.' DIRECTIVES = ('User directives associated with a statement or a variable.'
' Typically, they affect the immediately-enclosing statement.') ' Typically, they affect the immediately-enclosing statement.')
EXTRA_LOOP_TEST = (
'A special annotation containing additional test code to be executed in'
' for loops.')
class Static(NoValue): class Static(NoValue):
"""Container for static analysis annotation keys. """Container for static analysis annotation keys.

View File

@ -48,6 +48,7 @@ from enum import Enum
import gast import gast
# pylint:enable=g-bad-import-order # pylint:enable=g-bad-import-order
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import parser
@ -843,6 +844,11 @@ class AstToCfg(gast.NodeVisitor):
# However, the activity analysis accounts for this inconsistency, # However, the activity analysis accounts for this inconsistency,
# so dataflow analysis produces the correct values. # so dataflow analysis produces the correct values.
self.builder.enter_loop_section(node, node.iter) self.builder.enter_loop_section(node, node.iter)
# Also include the "extra loop test" annotation, to capture things like the
# control variable for return and break in for loops.
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
self._process_basic_statement(
anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
for stmt in node.body: for stmt in node.body:
self.visit(stmt) self.visit(stmt)
self.builder.exit_loop_section(node) self.builder.exit_loop_section(node)

View File

@ -539,6 +539,8 @@ class ActivityAnalyzer(transformer.Base):
self._enter_scope(False) self._enter_scope(False)
self.visit(node.target) self.visit(node.target)
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
self._process_statement(anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE) self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE)
node = self._process_parallel_blocks(node, node = self._process_parallel_blocks(node,

View File

@ -118,11 +118,11 @@ class ReplaceTransformer(gast.NodeTransformer):
self.replacements = replacements self.replacements = replacements
self.in_replacements = False self.in_replacements = False
self.preserved_annos = { self.preserved_annos = {
anno.Basic.DIRECTIVES,
anno.Basic.EXTRA_LOOP_TEST,
anno.Basic.ORIGIN, anno.Basic.ORIGIN,
anno.Basic.SKIP_PROCESSING, anno.Basic.SKIP_PROCESSING,
anno.Basic.DIRECTIVES,
anno.Static.ORIG_DEFINITIONS, anno.Static.ORIG_DEFINITIONS,
'extra_test',
'function_context_name', 'function_context_name',
} }