Account more thoroughly for the hidden extra_test code in activity analysis. This avoids related referenced-before-assigned errors. Should fix #36462, but that's most practical to verify after the nightly build.
PiperOrigin-RevId: 297737367 Change-Id: I8ae0e462d1a337031ca5ee89888115b49388826c
This commit is contained in:
parent
cee4c00759
commit
a962580295
@ -143,7 +143,7 @@ class BreakTransformer(converter.Base):
|
||||
orelse=guarded_orelse)
|
||||
|
||||
new_for_node = node[1]
|
||||
anno.setanno(new_for_node, 'extra_test', extra_test)
|
||||
anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
|
||||
anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)
|
||||
|
||||
return node
|
||||
|
||||
@ -462,8 +462,8 @@ class ControlFlowTransformer(converter.Base):
|
||||
|
||||
opts = self._create_loop_options(node)
|
||||
|
||||
if anno.hasanno(node, 'extra_test'):
|
||||
extra_test = anno.getanno(node, 'extra_test')
|
||||
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
|
||||
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
|
||||
extra_test_name = self.ctx.namer.new_symbol(
|
||||
'extra_test', reserved_symbols)
|
||||
template = """
|
||||
|
||||
@ -487,8 +487,8 @@ class ControlFlowTransformer(converter.Base):
|
||||
state_functions = self._create_state_functions(
|
||||
composite_loop_vars, state_getter_name, state_setter_name)
|
||||
|
||||
if anno.hasanno(node, 'extra_test'):
|
||||
extra_test = anno.getanno(node, 'extra_test')
|
||||
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
|
||||
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
|
||||
extra_test_name = self.ctx.namer.new_symbol(
|
||||
'extra_test', reserved_symbols)
|
||||
template = """
|
||||
|
||||
@ -293,7 +293,7 @@ class ReturnStatementsTransformer(converter.Base):
|
||||
# Add the check for return to the loop condition.
|
||||
node.body = self._visit_statement_block(node, node.body)
|
||||
if self.state[_Block].return_used:
|
||||
extra_test = anno.getanno(node, 'extra_test', default=None)
|
||||
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
|
||||
if extra_test is not None:
|
||||
extra_test = templates.replace_as_expression(
|
||||
'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
|
||||
@ -303,7 +303,7 @@ class ReturnStatementsTransformer(converter.Base):
|
||||
extra_test = templates.replace_as_expression(
|
||||
'ag__.not_(control_var)',
|
||||
control_var=self.state[_Function].do_return_var_name)
|
||||
anno.setanno(node, 'extra_test', extra_test)
|
||||
anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
|
||||
|
||||
node.orelse = self._visit_statement_block(node, node.orelse)
|
||||
return node
|
||||
@ -356,16 +356,17 @@ class ReturnStatementsTransformer(converter.Base):
|
||||
if self.state[_Block].return_used:
|
||||
|
||||
if self.default_to_null_return:
|
||||
# TODO(mdan): Remove the (do_return_var_name,) below.
|
||||
# Currently, that line ensures the variable is both defined and alive
|
||||
# throughout the function.
|
||||
template = """
|
||||
do_return_var_name = False
|
||||
retval_var_name = ag__.UndefinedReturnValue()
|
||||
body
|
||||
# TODO(b/134753123) Remove the do_return_var_name tuple.
|
||||
(do_return_var_name,)
|
||||
return ag__.retval(retval_var_name)
|
||||
"""
|
||||
else:
|
||||
# TODO(b/134753123) Fix loops that return when do_return is not set.
|
||||
template = """
|
||||
body
|
||||
return retval_var_name
|
||||
|
||||
@ -59,6 +59,10 @@ class Basic(NoValue):
|
||||
DIRECTIVES = ('User directives associated with a statement or a variable.'
|
||||
' Typically, they affect the immediately-enclosing statement.')
|
||||
|
||||
EXTRA_LOOP_TEST = (
|
||||
'A special annotation containing additional test code to be executed in'
|
||||
' for loops.')
|
||||
|
||||
|
||||
class Static(NoValue):
|
||||
"""Container for static analysis annotation keys.
|
||||
|
||||
@ -48,6 +48,7 @@ from enum import Enum
|
||||
import gast
|
||||
# pylint:enable=g-bad-import-order
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
|
||||
|
||||
@ -843,6 +844,11 @@ class AstToCfg(gast.NodeVisitor):
|
||||
# However, the activity analysis accounts for this inconsistency,
|
||||
# so dataflow analysis produces the correct values.
|
||||
self.builder.enter_loop_section(node, node.iter)
|
||||
# Also include the "extra loop test" annotation, to capture things like the
|
||||
# control variable for return and break in for loops.
|
||||
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
|
||||
self._process_basic_statement(
|
||||
anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
|
||||
for stmt in node.body:
|
||||
self.visit(stmt)
|
||||
self.builder.exit_loop_section(node)
|
||||
|
||||
@ -539,6 +539,8 @@ class ActivityAnalyzer(transformer.Base):
|
||||
|
||||
self._enter_scope(False)
|
||||
self.visit(node.target)
|
||||
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
|
||||
self._process_statement(anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
|
||||
self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE)
|
||||
|
||||
node = self._process_parallel_blocks(node,
|
||||
|
||||
@ -118,11 +118,11 @@ class ReplaceTransformer(gast.NodeTransformer):
|
||||
self.replacements = replacements
|
||||
self.in_replacements = False
|
||||
self.preserved_annos = {
|
||||
anno.Basic.DIRECTIVES,
|
||||
anno.Basic.EXTRA_LOOP_TEST,
|
||||
anno.Basic.ORIGIN,
|
||||
anno.Basic.SKIP_PROCESSING,
|
||||
anno.Basic.DIRECTIVES,
|
||||
anno.Static.ORIG_DEFINITIONS,
|
||||
'extra_test',
|
||||
'function_context_name',
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user