Account more thoroughly for the hidden extra_test code in activity analysis. This avoids related referenced-before-assigned errors. Should fix #36462, but that's most practical to verify after the nightly build.

PiperOrigin-RevId: 297737367
Change-Id: I8ae0e462d1a337031ca5ee89888115b49388826c
This commit is contained in:
Dan Moldovan 2020-02-27 18:14:26 -08:00 committed by TensorFlower Gardener
parent cee4c00759
commit a962580295
8 changed files with 24 additions and 11 deletions

View File

@ -143,7 +143,7 @@ class BreakTransformer(converter.Base):
orelse=guarded_orelse)
new_for_node = node[1]
anno.setanno(new_for_node, 'extra_test', extra_test)
anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)
return node

View File

@ -462,8 +462,8 @@ class ControlFlowTransformer(converter.Base):
opts = self._create_loop_options(node)
if anno.hasanno(node, 'extra_test'):
extra_test = anno.getanno(node, 'extra_test')
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
extra_test_name = self.ctx.namer.new_symbol(
'extra_test', reserved_symbols)
template = """

View File

@ -487,8 +487,8 @@ class ControlFlowTransformer(converter.Base):
state_functions = self._create_state_functions(
composite_loop_vars, state_getter_name, state_setter_name)
if anno.hasanno(node, 'extra_test'):
extra_test = anno.getanno(node, 'extra_test')
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
extra_test_name = self.ctx.namer.new_symbol(
'extra_test', reserved_symbols)
template = """

View File

@ -293,7 +293,7 @@ class ReturnStatementsTransformer(converter.Base):
# Add the check for return to the loop condition.
node.body = self._visit_statement_block(node, node.body)
if self.state[_Block].return_used:
extra_test = anno.getanno(node, 'extra_test', default=None)
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
if extra_test is not None:
extra_test = templates.replace_as_expression(
'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
@ -303,7 +303,7 @@ class ReturnStatementsTransformer(converter.Base):
extra_test = templates.replace_as_expression(
'ag__.not_(control_var)',
control_var=self.state[_Function].do_return_var_name)
anno.setanno(node, 'extra_test', extra_test)
anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
node.orelse = self._visit_statement_block(node, node.orelse)
return node
@ -356,16 +356,17 @@ class ReturnStatementsTransformer(converter.Base):
if self.state[_Block].return_used:
if self.default_to_null_return:
# TODO(mdan): Remove the (do_return_var_name,) below.
# Currently, that line ensures the variable is both defined and alive
# throughout the function.
template = """
do_return_var_name = False
retval_var_name = ag__.UndefinedReturnValue()
body
# TODO(b/134753123) Remove the do_return_var_name tuple.
(do_return_var_name,)
return ag__.retval(retval_var_name)
"""
else:
# TODO(b/134753123) Fix loops that return when do_return is not set.
template = """
body
return retval_var_name

View File

@ -59,6 +59,10 @@ class Basic(NoValue):
DIRECTIVES = ('User directives associated with a statement or a variable.'
' Typically, they affect the immediately-enclosing statement.')
EXTRA_LOOP_TEST = (
'A special annotation containing additional test code to be executed in'
' for loops.')
class Static(NoValue):
"""Container for static analysis annotation keys.

View File

@ -48,6 +48,7 @@ from enum import Enum
import gast
# pylint:enable=g-bad-import-order
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
@ -843,6 +844,11 @@ class AstToCfg(gast.NodeVisitor):
# However, the activity analysis accounts for this inconsistency,
# so dataflow analysis produces the correct values.
self.builder.enter_loop_section(node, node.iter)
# Also include the "extra loop test" annotation, to capture things like the
# control variable for return and break in for loops.
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
self._process_basic_statement(
anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
for stmt in node.body:
self.visit(stmt)
self.builder.exit_loop_section(node)

View File

@ -539,6 +539,8 @@ class ActivityAnalyzer(transformer.Base):
self._enter_scope(False)
self.visit(node.target)
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
self._process_statement(anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE)
node = self._process_parallel_blocks(node,

View File

@ -118,11 +118,11 @@ class ReplaceTransformer(gast.NodeTransformer):
self.replacements = replacements
self.in_replacements = False
self.preserved_annos = {
anno.Basic.DIRECTIVES,
anno.Basic.EXTRA_LOOP_TEST,
anno.Basic.ORIGIN,
anno.Basic.SKIP_PROCESSING,
anno.Basic.DIRECTIVES,
anno.Static.ORIG_DEFINITIONS,
'extra_test',
'function_context_name',
}