Account more thoroughly for the hidden extra_test code in activity analysis. This avoids related referenced-before-assigned errors. Should fix #36462, but that's most practical to verify after the nightly build.
PiperOrigin-RevId: 297737367 Change-Id: I8ae0e462d1a337031ca5ee89888115b49388826c
This commit is contained in:
		
							parent
							
								
									cee4c00759
								
							
						
					
					
						commit
						a962580295
					
				| @ -143,7 +143,7 @@ class BreakTransformer(converter.Base): | |||||||
|           orelse=guarded_orelse) |           orelse=guarded_orelse) | ||||||
| 
 | 
 | ||||||
|       new_for_node = node[1] |       new_for_node = node[1] | ||||||
|       anno.setanno(new_for_node, 'extra_test', extra_test) |       anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test) | ||||||
|       anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) |       anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) | ||||||
| 
 | 
 | ||||||
|     return node |     return node | ||||||
|  | |||||||
| @ -462,8 +462,8 @@ class ControlFlowTransformer(converter.Base): | |||||||
| 
 | 
 | ||||||
|     opts = self._create_loop_options(node) |     opts = self._create_loop_options(node) | ||||||
| 
 | 
 | ||||||
|     if anno.hasanno(node, 'extra_test'): |     if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): | ||||||
|       extra_test = anno.getanno(node, 'extra_test') |       extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST) | ||||||
|       extra_test_name = self.ctx.namer.new_symbol( |       extra_test_name = self.ctx.namer.new_symbol( | ||||||
|           'extra_test', reserved_symbols) |           'extra_test', reserved_symbols) | ||||||
|       template = """ |       template = """ | ||||||
|  | |||||||
| @ -487,8 +487,8 @@ class ControlFlowTransformer(converter.Base): | |||||||
|     state_functions = self._create_state_functions( |     state_functions = self._create_state_functions( | ||||||
|         composite_loop_vars, state_getter_name, state_setter_name) |         composite_loop_vars, state_getter_name, state_setter_name) | ||||||
| 
 | 
 | ||||||
|     if anno.hasanno(node, 'extra_test'): |     if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): | ||||||
|       extra_test = anno.getanno(node, 'extra_test') |       extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST) | ||||||
|       extra_test_name = self.ctx.namer.new_symbol( |       extra_test_name = self.ctx.namer.new_symbol( | ||||||
|           'extra_test', reserved_symbols) |           'extra_test', reserved_symbols) | ||||||
|       template = """ |       template = """ | ||||||
|  | |||||||
| @ -293,7 +293,7 @@ class ReturnStatementsTransformer(converter.Base): | |||||||
|     # Add the check for return to the loop condition. |     # Add the check for return to the loop condition. | ||||||
|     node.body = self._visit_statement_block(node, node.body) |     node.body = self._visit_statement_block(node, node.body) | ||||||
|     if self.state[_Block].return_used: |     if self.state[_Block].return_used: | ||||||
|       extra_test = anno.getanno(node, 'extra_test', default=None) |       extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None) | ||||||
|       if extra_test is not None: |       if extra_test is not None: | ||||||
|         extra_test = templates.replace_as_expression( |         extra_test = templates.replace_as_expression( | ||||||
|             'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)', |             'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)', | ||||||
| @ -303,7 +303,7 @@ class ReturnStatementsTransformer(converter.Base): | |||||||
|         extra_test = templates.replace_as_expression( |         extra_test = templates.replace_as_expression( | ||||||
|             'ag__.not_(control_var)', |             'ag__.not_(control_var)', | ||||||
|             control_var=self.state[_Function].do_return_var_name) |             control_var=self.state[_Function].do_return_var_name) | ||||||
|       anno.setanno(node, 'extra_test', extra_test) |       anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test) | ||||||
| 
 | 
 | ||||||
|     node.orelse = self._visit_statement_block(node, node.orelse) |     node.orelse = self._visit_statement_block(node, node.orelse) | ||||||
|     return node |     return node | ||||||
| @ -356,16 +356,17 @@ class ReturnStatementsTransformer(converter.Base): | |||||||
|     if self.state[_Block].return_used: |     if self.state[_Block].return_used: | ||||||
| 
 | 
 | ||||||
|       if self.default_to_null_return: |       if self.default_to_null_return: | ||||||
|  |         # TODO(mdan): Remove the (do_return_var_name,) below. | ||||||
|  |         # Currently, that line ensures the variable is both defined and alive | ||||||
|  |         # throughout the function. | ||||||
|         template = """ |         template = """ | ||||||
|           do_return_var_name = False |           do_return_var_name = False | ||||||
|           retval_var_name = ag__.UndefinedReturnValue() |           retval_var_name = ag__.UndefinedReturnValue() | ||||||
|           body |           body | ||||||
|           # TODO(b/134753123) Remove the do_return_var_name tuple. |  | ||||||
|           (do_return_var_name,) |           (do_return_var_name,) | ||||||
|           return ag__.retval(retval_var_name) |           return ag__.retval(retval_var_name) | ||||||
|         """ |         """ | ||||||
|       else: |       else: | ||||||
|         # TODO(b/134753123) Fix loops that return when do_return is not set. |  | ||||||
|         template = """ |         template = """ | ||||||
|           body |           body | ||||||
|           return retval_var_name |           return retval_var_name | ||||||
|  | |||||||
| @ -59,6 +59,10 @@ class Basic(NoValue): | |||||||
|   DIRECTIVES = ('User directives associated with a statement or a variable.' |   DIRECTIVES = ('User directives associated with a statement or a variable.' | ||||||
|                 ' Typically, they affect the immediately-enclosing statement.') |                 ' Typically, they affect the immediately-enclosing statement.') | ||||||
| 
 | 
 | ||||||
|  |   EXTRA_LOOP_TEST = ( | ||||||
|  |       'A special annotation containing additional test code to be executed in' | ||||||
|  |       ' for loops.') | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class Static(NoValue): | class Static(NoValue): | ||||||
|   """Container for static analysis annotation keys. |   """Container for static analysis annotation keys. | ||||||
|  | |||||||
| @ -48,6 +48,7 @@ from enum import Enum | |||||||
| import gast | import gast | ||||||
| # pylint:enable=g-bad-import-order | # pylint:enable=g-bad-import-order | ||||||
| 
 | 
 | ||||||
|  | from tensorflow.python.autograph.pyct import anno | ||||||
| from tensorflow.python.autograph.pyct import parser | from tensorflow.python.autograph.pyct import parser | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -843,6 +844,11 @@ class AstToCfg(gast.NodeVisitor): | |||||||
|     # However, the activity analysis accounts for this inconsistency, |     # However, the activity analysis accounts for this inconsistency, | ||||||
|     # so dataflow analysis produces the correct values. |     # so dataflow analysis produces the correct values. | ||||||
|     self.builder.enter_loop_section(node, node.iter) |     self.builder.enter_loop_section(node, node.iter) | ||||||
|  |     # Also include the "extra loop test" annotation, to capture things like the | ||||||
|  |     # control variable for return and break in for loops. | ||||||
|  |     if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): | ||||||
|  |       self._process_basic_statement( | ||||||
|  |           anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)) | ||||||
|     for stmt in node.body: |     for stmt in node.body: | ||||||
|       self.visit(stmt) |       self.visit(stmt) | ||||||
|     self.builder.exit_loop_section(node) |     self.builder.exit_loop_section(node) | ||||||
|  | |||||||
| @ -539,6 +539,8 @@ class ActivityAnalyzer(transformer.Base): | |||||||
| 
 | 
 | ||||||
|     self._enter_scope(False) |     self._enter_scope(False) | ||||||
|     self.visit(node.target) |     self.visit(node.target) | ||||||
|  |     if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): | ||||||
|  |       self._process_statement(anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)) | ||||||
|     self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE) |     self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE) | ||||||
| 
 | 
 | ||||||
|     node = self._process_parallel_blocks(node, |     node = self._process_parallel_blocks(node, | ||||||
|  | |||||||
| @ -118,11 +118,11 @@ class ReplaceTransformer(gast.NodeTransformer): | |||||||
|     self.replacements = replacements |     self.replacements = replacements | ||||||
|     self.in_replacements = False |     self.in_replacements = False | ||||||
|     self.preserved_annos = { |     self.preserved_annos = { | ||||||
|  |         anno.Basic.DIRECTIVES, | ||||||
|  |         anno.Basic.EXTRA_LOOP_TEST, | ||||||
|         anno.Basic.ORIGIN, |         anno.Basic.ORIGIN, | ||||||
|         anno.Basic.SKIP_PROCESSING, |         anno.Basic.SKIP_PROCESSING, | ||||||
|         anno.Basic.DIRECTIVES, |  | ||||||
|         anno.Static.ORIG_DEFINITIONS, |         anno.Static.ORIG_DEFINITIONS, | ||||||
|         'extra_test', |  | ||||||
|         'function_context_name', |         'function_context_name', | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user