diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py index 718c5bd3ca5..dc5511824c1 100644 --- a/tensorflow/python/autograph/converters/break_statements.py +++ b/tensorflow/python/autograph/converters/break_statements.py @@ -143,7 +143,7 @@ class BreakTransformer(converter.Base): orelse=guarded_orelse) new_for_node = node[1] - anno.setanno(new_for_node, 'extra_test', extra_test) + anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test) anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) return node diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index 4279631e1a6..10db16ef1bb 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -462,8 +462,8 @@ class ControlFlowTransformer(converter.Base): opts = self._create_loop_options(node) - if anno.hasanno(node, 'extra_test'): - extra_test = anno.getanno(node, 'extra_test') + if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): + extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST) extra_test_name = self.ctx.namer.new_symbol( 'extra_test', reserved_symbols) template = """ diff --git a/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py b/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py index 5b1f8bdbb7d..c70460a2413 100644 --- a/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py +++ b/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py @@ -487,8 +487,8 @@ class ControlFlowTransformer(converter.Base): state_functions = self._create_state_functions( composite_loop_vars, state_getter_name, state_setter_name) - if anno.hasanno(node, 'extra_test'): - extra_test = anno.getanno(node, 'extra_test') + if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): + extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST) extra_test_name = self.ctx.namer.new_symbol( 'extra_test', reserved_symbols) template = """ diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py index 89f72ce1863..b89d3c13fb7 100644 --- a/tensorflow/python/autograph/converters/return_statements.py +++ b/tensorflow/python/autograph/converters/return_statements.py @@ -293,7 +293,7 @@ class ReturnStatementsTransformer(converter.Base): # Add the check for return to the loop condition. node.body = self._visit_statement_block(node, node.body) if self.state[_Block].return_used: - extra_test = anno.getanno(node, 'extra_test', default=None) + extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None) if extra_test is not None: extra_test = templates.replace_as_expression( 'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)', @@ -303,7 +303,7 @@ class ReturnStatementsTransformer(converter.Base): extra_test = templates.replace_as_expression( 'ag__.not_(control_var)', control_var=self.state[_Function].do_return_var_name) - anno.setanno(node, 'extra_test', extra_test) + anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test) node.orelse = self._visit_statement_block(node, node.orelse) return node @@ -356,16 +356,17 @@ class ReturnStatementsTransformer(converter.Base): if self.state[_Block].return_used: if self.default_to_null_return: + # TODO(mdan): Remove the (do_return_var_name,) below. + # Currently, that line ensures the variable is both defined and alive + # throughout the function. template = """ do_return_var_name = False retval_var_name = ag__.UndefinedReturnValue() body - # TODO(b/134753123) Remove the do_return_var_name tuple. (do_return_var_name,) return ag__.retval(retval_var_name) """ else: - # TODO(b/134753123) Fix loops that return when do_return is not set. template = """ body return retval_var_name diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py index 2a815305475..6fd05f833f5 100644 --- a/tensorflow/python/autograph/pyct/anno.py +++ b/tensorflow/python/autograph/pyct/anno.py @@ -59,6 +59,10 @@ class Basic(NoValue): DIRECTIVES = ('User directives associated with a statement or a variable.' ' Typically, they affect the immediately-enclosing statement.') + EXTRA_LOOP_TEST = ( + 'A special annotation containing additional test code to be executed in' + ' for loops.') + class Static(NoValue): """Container for static analysis annotation keys. diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index 194c39802db..fccaa487543 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -48,6 +48,7 @@ from enum import Enum import gast # pylint:enable=g-bad-import-order +from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import parser @@ -843,6 +844,11 @@ class AstToCfg(gast.NodeVisitor): # However, the activity analysis accounts for this inconsistency, # so dataflow analysis produces the correct values. self.builder.enter_loop_section(node, node.iter) + # Also include the "extra loop test" annotation, to capture things like the + # control variable for return and break in for loops. + if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): + self._process_basic_statement( + anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)) for stmt in node.body: self.visit(stmt) self.builder.exit_loop_section(node) diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index 73131d6c0fa..39f2c5d9448 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -539,6 +539,8 @@ class ActivityAnalyzer(transformer.Base): self._enter_scope(False) self.visit(node.target) + if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): + self._process_statement(anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)) self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE) node = self._process_parallel_blocks(node, diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py index c55fee5b85a..b07424b8503 100644 --- a/tensorflow/python/autograph/pyct/templates.py +++ b/tensorflow/python/autograph/pyct/templates.py @@ -118,11 +118,11 @@ class ReplaceTransformer(gast.NodeTransformer): self.replacements = replacements self.in_replacements = False self.preserved_annos = { + anno.Basic.DIRECTIVES, + anno.Basic.EXTRA_LOOP_TEST, anno.Basic.ORIGIN, anno.Basic.SKIP_PROCESSING, - anno.Basic.DIRECTIVES, anno.Static.ORIG_DEFINITIONS, - 'extra_test', 'function_context_name', }