diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 9f2604dec94..2909cf3f8bc 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -21,6 +21,7 @@ from __future__ import print_function import contextlib import imp import inspect +import os import sys import six @@ -100,6 +101,7 @@ class TestCase(test.TestCase): def setUp(self): # AutoGraph tests must run in graph mode to properly test control flow. + os.environ['AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS'] = '1' self.graph = ops.Graph().as_default() self.graph.__enter__() diff --git a/tensorflow/python/autograph/g3doc/reference/limitations.md b/tensorflow/python/autograph/g3doc/reference/limitations.md index 70ce5fc7dec..70e3b3a552e 100644 --- a/tensorflow/python/autograph/g3doc/reference/limitations.md +++ b/tensorflow/python/autograph/g3doc/reference/limitations.md @@ -66,22 +66,48 @@ else: pass ``` -Similarly, variables may not be defined inside a TensorFlow loop, unless they -are local to the loop. A variable is local to the loop if (1) it's not used -after the loop and (2) the value from a previour iteration is not used in the -next iteration: +Similarly, variables must usually be defined before a TensorFlow loop. + +The most common example that is not allowed is a loop which initializes some +accumulator variable in the first iteration: ``` del x -while tf.random.uniform(()) > 0.5: # Error -- x must be defined before the loop +for i in tf.range(100): # Error -- x must be defined before the loop + if i == 0: + x = tf.constant(1) + else: + x = x + 1 +tf.print(x) +``` + +When the variable is only used inside the loop and does not depend on previous +iterations, then it's ok to only be initialized inside the loop. + +``` +del x +while tf.random.uniform(()) > 0.5: # Okay -- x is not used after the loop + x = tf.constant(1) +``` + +* New in TF 2.4 * + +As long as it doesn't depend on previous iterations, the variable may also be +used after the loop, however in that case the loop must execute at least one +iteration, and will raise a runtime error otherwise. + +``` +del x +for i in tf.range(10): # Okay -- x does not depend on previous iterations x = tf.constant(1) tf.print(x) ``` ``` del x -while tf.random.uniform(()) > 0.5: # Okay -- x is local to the loop +while tf.constant(False): # Error -- loop must initialize x! x = tf.constant(1) +tf.print(x) ``` Avoid these limitations by defining a default value before the control flow @@ -98,6 +124,34 @@ Note: `None` values and undefined symbols are allowed in Eager control flow, because Eager execution uses Python control flow, rather than TensorFlow control flow ops. +#### Special case: creating Tensors in a loop + +* New in TF 2.4 * + +A very common use-case is to run a training loop that creates some outputs: + +``` +for i in tf.range(num_steps): + outputs = train(next(data_iterator)) +``` + +Often times these outputs can be nested structures of Tensors, which makes them +impractical to initialize ahead of the loop. + +To help with this use-case, AutoGraph lets you run such loops, under certain +conditions: + + * outputs must be a Tensor, Python numeric, or a structure of these + * outputs must not depend on the value from a previous iteration; in other + words, the outputs may only appear to the left of an assignment operation + * the loop must run at least one iteration + +If the type of outputs is not recognized, then the usual +"outputs must be defined before the loop" is raised at graph construction. + +AutoGraph also inserts a `tf.Assert` statement that raises a runtime error +if the loop did not execute at least one iteration. + ### Indirect modifications and hidden side effects in TensorFlow control flow Key Point: We recommend using a functional programming style, immutable Python diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index 03f67d67fee..3418450e813 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -60,6 +60,7 @@ from __future__ import division from __future__ import print_function import functools +import os import traceback import numpy as np @@ -79,6 +80,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -99,19 +101,70 @@ INEFFICIENT_UNROLL_MIN_OPS = 1 # datasets. Before it can be used though, we need to standardize the interface. -def _verify_loop_init_vars(values, symbol_names): - """Ensures that all values in the state are defined when entering a loop.""" - for name, value in zip(symbol_names, values): - if value is None: - raise ValueError("'{}' may not be None before the loop.".format(name)) - if isinstance(value, variables.UndefinedReturnValue): - # Assumption: the loop will only capture the variable which tracks the - # return value if the loop contained a return statement. - # TODO(mdan): This should be checked at the place where return occurs. - raise ValueError( - 'return statements are not supported within a TensorFlow loop.') - if isinstance(value, variables.Undefined): - raise ValueError("'{}' must be defined before the loop.".format(name)) +def _is_none_or_undef(value): + """Tests whether a value is None or undefined. + + AutoGraph represents undefined symbols using special objects of type Undefined + or UndefinedReturnValue. + + Args: + value: value to test + Returns: + Boolean + """ + return ((value is None) + or isinstance(value, variables.UndefinedReturnValue) + or isinstance(value, variables.Undefined)) + + +def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None): + """Ensures that all values in the state are valid to use in a TF loop. + + The init_vars may contain placeholder values derived from first_iter_vars. + + Args: + init_vars: initial loop variables (as taken before entering the loop) + symbol_names: corresponding names of the initial loop variables + first_iter_vars: loop variables after one iteration of the loop + """ + if not symbol_names: + return + if first_iter_vars is None: + first_iter_vars = (None,) * len(symbol_names) + + assert len(symbol_names) == len(init_vars) + assert len(symbol_names) == len(first_iter_vars) + for name, val, fi_val in zip(symbol_names, init_vars, first_iter_vars): + if isinstance(val, variables.UndefinedReturnValue): + if fi_val: + raise ValueError( + 'the return value from a TensorFlow loop may only be a {}; got {}' + .format(LEGAL_LOOP_TYPES, type(fi_val))) + else: + # TODO(mdan): This can be handled by removing the return value. + raise NotImplementedError( + 'a return statement cannot be placed inside this TensorFlow loop;' + ' this may happen if a return statement depends on a' + ' static Python condition such as a hyperparameter') + + error_msg = None + if val is None: + error_msg = "'{}' may not be None before the loop".format(name) + elif isinstance(val, variables.Undefined): + error_msg = "'{}' must be defined before the loop".format(name) + + # This only happens when we could not infer a placeholder for the + # variable. The canonical case when that happens is when _placeholder_value + # couldnot infer a placeholder for it. That means it's of an unknown type + # or it's still undefined after staging one iteration. + if error_msg is not None: + if fi_val: + error_msg += (", unless it's a {}; got {}".format( + LEGAL_LOOP_TYPES, type(fi_val))) + else: + # TODO(mdan): This can be handled by removing the loop var. + error_msg += '.' + raise ValueError(error_msg) def _is_subshape(left, right): @@ -876,21 +929,134 @@ def _shape_invariants_mapping_to_positional_list(mapping, keys): return tuple(result) +# Textual description of what a legal TF loop variable is. This description +# summarizes types that _placeholder_value below can handle. Keep the two +# together and in sync. +LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof' + + +def _placeholder_value(like, original): + if isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)): + return original + if isinstance(like, (int, float, bool)): + return type(like)(0) + if tensor_util.is_tensor(like): + return array_ops.zeros(like.shape, like.dtype) + elif isinstance(like, (list, tuple, dict)): + return nest.map_structure(_placeholder_value, like) + return original + + +def _try_handling_undefineds( + body, get_state, set_state, init_vars, nulls, symbol_names): + """Makes a best-effort attempt to substitute undefineds with placeholders. + + Note: this substitution requires two things to happen: + 1. the types of loop variables could be inferred (usually by staging one + iteration) + 2. these types could be replaced by placeholders (e.g. zero values, for + tensors. + + Args: + body: a function representing the loop body. See while_stmt. + get_state: state getter for the loop statement. See while_stmt. + set_state: state getter for the loop statement. See while_stmt. + init_vars: loop variables before entering the loop. See while_stmt. + nulls: list of boolean flags indicating whether the corresponding loop + var is None or undefined. + symbol_names: list of loop variable names. See while_stmt. + Returns: + A tuple (success, new_init_vars). success is a boolean flag indicating + whether types could be successfully inferred (step 1 above). new_init_vars + contains the loop vars, with None or undefined values replaced by + placeholders, where possible (step 2 above). + """ + state_modified = False + + if not os.getenv('AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS', ''): + _verify_loop_init_vars(init_vars, symbol_names) + return False, init_vars + + try: + # Stage an iteration of the loop body in a temporary graph. + with func_graph.FuncGraph('tmp').as_default(): + # This call to set_state helps report nicer error messages when symbols + # are inconsistently used. + set_state(init_vars) + state_modified = True + + body() + first_iter_vars = get_state() + except (UnboundLocalError, TypeError, ValueError, KeyError): + # Fall back to the old functionality. It will likely result in an input + # validation failure. + first_iter_vars = None + finally: + if state_modified: + set_state(init_vars) + + if first_iter_vars is not None: + # Note: the actual placeholder value doesn't matter, because as the staging + # proved, it will be replaced by an actual value before being read. + init_vars = tuple( + (_placeholder_value(iv, v) if n else v) + for v, n, iv in zip(init_vars, nulls, first_iter_vars)) + success = True + else: + success = False + + # This check runs regardless, in case we captured non-Tensor inputs. + _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars) + + return success, init_vars + + +def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars): + """Creates an error message asking for the loop to iterate at least once.""" + var_names = [] + for sn, n, v in zip(symbol_names, nulls, init_vars): + if not n: + continue + if isinstance(v, variables.UndefinedReturnValue): + var_names.append('the function return value') + else: + var_names.append(sn) + var_names = ', '.join(var_names) + return 'loop must iterate at least once to initialize {}'.format(var_names) + + def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts): """Overload of while_stmt that stages a TF while_stmt.""" init_vars = get_state() - _verify_loop_init_vars(init_vars, symbol_names) + orig_init_vars = init_vars + + nulls = tuple(_is_none_or_undef(v) for v in init_vars) + if any(nulls): + require_one_iteration, init_vars = _try_handling_undefineds( + body, get_state, set_state, init_vars, nulls, symbol_names) + else: + require_one_iteration = False def aug_test(*loop_vars): + if require_one_iteration: + loop_vars = loop_vars[1:] + set_state(loop_vars) return test() def aug_body(*loop_vars): + if require_one_iteration: + loop_vars = loop_vars[1:] + set_state(loop_vars) body() new_loop_vars = get_state() _verify_tf_loop_vars( init_vars, loop_vars, new_loop_vars, symbol_names, opts) + + if require_one_iteration: + new_loop_vars = (True,) + new_loop_vars + return new_loop_vars if 'shape_invariants' in opts: @@ -904,8 +1070,23 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts): # This enforces consistency across versions. while_loop_opts['return_same_structure'] = True + if require_one_iteration: + aug_init_vars = (False,) + init_vars + else: + aug_init_vars = init_vars + final_loop_vars = control_flow_ops.while_loop( - aug_test, aug_body, init_vars, **while_loop_opts) + aug_test, aug_body, aug_init_vars, **while_loop_opts) + + if require_one_iteration: + with ops.control_dependencies([ + control_flow_ops.Assert(final_loop_vars[0], [ + _runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars) + ]) + ]): + final_loop_vars = tuple( + array_ops.identity(v) for v in final_loop_vars[1:]) + set_state(final_loop_vars) diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py index 5f0629a163f..553643956f6 100644 --- a/tensorflow/python/autograph/operators/control_flow_test.py +++ b/tensorflow/python/autograph/operators/control_flow_test.py @@ -86,7 +86,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('s',), opts={'iterate_names': 'i'}) + self.assertEqual(s, (1234,)) + self.assertOpCreated('StatelessWhile') def test_range_tensor_explicit_limit_delta(self): def body(i): @@ -106,7 +108,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('s',), opts={'iterate_names': 'i'}) + self.assertEqual(s, (-171207,)) + self.assertOpCreated('StatelessWhile') def test_range_tensor_explicit_limit_negative_delta(self): def body(i): @@ -126,7 +130,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('s',), opts={'iterate_names': 'i'}) + self.assertEqual(s, (171207,)) + self.assertOpCreated('StatelessWhile') def test_range_tensor_random_delta(self): def body(i): @@ -147,7 +153,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('s',), opts={'iterate_names': 'i'}) + self.assertEqual(s, (1234,)) + self.assertOpCreated('StatelessWhile') def test_range_tensor_random_negative_delta(self): def body(i): @@ -168,7 +176,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('s',), opts={'iterate_names': 'i'}) + self.assertEqual(s, (171207,)) + self.assertOpCreated('StatelessWhile') def test_tensor_with_extra_test_object_vars(self): class MutableObject(object): @@ -194,7 +204,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('state.field_1', 'state.field_2'), opts={}) + self.assertEqual((state.field_1, state.field_2), (6, 6)) + self.assertOpCreated('StatelessWhile') def test_python(self): def body(i): @@ -214,7 +226,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('s',), opts={}) + self.assertEqual(s, 1234) + self.assertNoOpsCreated() def test_python_generator_with_extra_test(self): def new_generator(): @@ -247,6 +261,8 @@ class ForLoopTest(testing.AutoGraphTestCase): self.assertEqual(next(gen), 4) + self.assertNoOpsCreated() + def test_python_generator_with_extra_test_no_iterations(self): def new_generator(): for i in range(5): @@ -275,6 +291,8 @@ class ForLoopTest(testing.AutoGraphTestCase): self.assertEqual(next(gen), 0) + self.assertNoOpsCreated() + def test_tf_dataset(self): def body(i): nonlocal s @@ -293,7 +311,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('s',), opts={}) + self.assertEqual(s, (1234,)) + self.assertOpCreated('ScanDataset') def test_dataset_with_extra_test(self): def body(i): @@ -313,7 +333,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('s',), opts={}) + self.assertEqual(s, (12,)) + self.assertOpCreated('ScanDataset') def test_dataset_with_extra_test_collection_vars(self): def body(i): @@ -335,7 +357,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('l[0]', 's'), opts={}) + self.assertEqual((l[0], s), (3, 3)) + self.assertOpCreated('ScanDataset') def test_dataset_with_extra_test_iteration_limiting(self): def body(it): @@ -356,7 +380,9 @@ class ForLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('i',), opts={}) + self.assertEqual(i, (3,)) + self.assertOpCreated('ScanDataset') def test_tf_dataset_no_loop_vars(self): def body(i): @@ -374,6 +400,7 @@ class ForLoopTest(testing.AutoGraphTestCase): opts={}) self.assertEqual(v.read_value(), 1234) + self.assertOpCreated('ScanDataset') def test_tf_iterator(self): def body(i): @@ -395,6 +422,7 @@ class ForLoopTest(testing.AutoGraphTestCase): opts={}) self.assertEqual(s, 1234) + self.assertOpCreated('IteratorGetNextAsOptional') def test_tf_iterator_shape_invariants(self): def body(i): @@ -416,6 +444,7 @@ class ForLoopTest(testing.AutoGraphTestCase): opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]}) self.assertAllEqual(s, [0, 1, 2, 3, 4]) + self.assertOpCreated('IteratorGetNextAsOptional') def test_tf_iterator_no_loop_vars(self): def body(i): @@ -433,6 +462,7 @@ class ForLoopTest(testing.AutoGraphTestCase): opts={}) self.assertEqual(v.read_value(), 1234) + self.assertOpCreated('IteratorGetNextAsOptional') def test_tf_ragged_tensor(self): def body(i): @@ -454,6 +484,7 @@ class ForLoopTest(testing.AutoGraphTestCase): opts={}) self.assertEqual(s, (123,)) + self.assertOpCreated('StatelessWhile') def test_tf_ragged_tensor_higher_dimensional(self): def body(i): @@ -479,6 +510,7 @@ class ForLoopTest(testing.AutoGraphTestCase): opts={}) self.assertEqual(s, (12,)) + self.assertOpCreated('StatelessWhile') def test_tf_ragged_tensor_no_loop_vars(self): v = self.variable('v', 0, dtypes.int32) @@ -497,6 +529,7 @@ class ForLoopTest(testing.AutoGraphTestCase): # Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row). self.assertEqual(v.read_value(), 123) + self.assertOpCreated('While') def _basic_loop(self, init_value, body_fn): def body(i): @@ -540,6 +573,7 @@ class ForLoopTest(testing.AutoGraphTestCase): class WhileLoopTest(testing.AutoGraphTestCase): def test_tensor(self): + def body(): nonlocal i, s s = s * 10 + i @@ -559,8 +593,38 @@ class WhileLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('i', 's'), opts={}) + self.assertEqual(i, 5) self.assertEqual(s, 1234) + self.assertOpCreated('StatelessWhile') + + def test_tensor_creating_variable(self): + + def body(): + nonlocal i, s + i = constant_op.constant(2) + s = i ** 5 + + def set_state(loop_vars): + nonlocal i, s + i, s = loop_vars + + i = variable_operators.Undefined('i') + s = constant_op.constant(0) + control_flow.while_stmt( + test=lambda: math_ops.equal(s, 0), + body=body, + get_state=lambda: (i, s), + set_state=set_state, + symbol_names=('i', 's'), + opts={}) + + self.assertEqual(i, 2) + self.assertEqual(s, 32) + self.assertOpCreated('StatelessWhile') + # Check that the temporary staging of the body did not create extra ops. + # Node naming is inconsistent between V1 and V2. + self.assertGraphContains(r'(while/)?pow$', 1) def test_tensor_with_side_effecting_condition(self): v = self.variable('v', 0, dtypes.int32) @@ -589,6 +653,7 @@ class WhileLoopTest(testing.AutoGraphTestCase): self.assertEqual(i, (5,)) self.assertEqual(v, (12345,)) + self.assertOpCreated('While') def test_tensor_with_python_state(self): class MutableObject(object): @@ -613,8 +678,10 @@ class WhileLoopTest(testing.AutoGraphTestCase): set_state=set_state, symbol_names=('i', 'state.field'), opts={}) + self.assertEqual(i, 5) self.assertEqual(state.field, 1234) + self.assertOpCreated('StatelessWhile') def test_python(self): def body(): @@ -632,7 +699,9 @@ class WhileLoopTest(testing.AutoGraphTestCase): set_state=None, symbol_names=('i', 's'), opts={}) + self.assertEqual(s, 1234) + self.assertNoOpsCreated() def test_python_with_tensor_state(self): def body(): @@ -650,8 +719,10 @@ class WhileLoopTest(testing.AutoGraphTestCase): set_state=None, symbol_names=('i', 's'), opts={}) + self.assertEqual(i, 5) self.assertEqual(s, 1234) + self.assertOpsNotCreated(('While', 'StatelessWhile')) def test_python_while_infinite(self): if not __debug__: @@ -732,6 +803,7 @@ class WhileLoopTest(testing.AutoGraphTestCase): r'.* Large unrolled loop.*Add.*', out_capturer.getvalue())) def _basic_loop(self, init_value, body_fn): + def body(): nonlocal i, s s = body_fn(i, s) @@ -802,6 +874,7 @@ class IfStmtTest(testing.AutoGraphTestCase): self.assertEqual(test_fn(constant_op.constant(True)), 1) self.assertEqual(test_fn(constant_op.constant(False)), -1) + self.assertOpCreated('StatelessIf') def test_tensor_no_outputs(self): @@ -831,6 +904,7 @@ class IfStmtTest(testing.AutoGraphTestCase): self.assertIsNone(test_fn(constant_op.constant(True))) self.assertIsNone(test_fn(constant_op.constant(False))) + self.assertOpCreated('StatelessIf') def test_tensor_multiple_returns(self): @@ -862,6 +936,7 @@ class IfStmtTest(testing.AutoGraphTestCase): self.assertEqual(test_fn(constant_op.constant(True)), (1, 2)) self.assertEqual(test_fn(constant_op.constant(False)), (-1, -2)) + self.assertOpCreated('StatelessIf') def test_python(self): @@ -887,6 +962,7 @@ class IfStmtTest(testing.AutoGraphTestCase): self.assertEqual(test_fn(True), 1) self.assertEqual(test_fn(False), -1) + self.assertNoOpsCreated() def test_python_multiple_returns(self): @@ -914,6 +990,7 @@ class IfStmtTest(testing.AutoGraphTestCase): self.assertEqual(test_fn(True), (1, 2)) self.assertEqual(test_fn(False), (-1, -2)) + self.assertNoOpsCreated() def _basic_cond(self, body_fn, else_fn): def body(): diff --git a/tensorflow/python/autograph/utils/testing.py b/tensorflow/python/autograph/utils/testing.py index f4238bea397..1da82db66c8 100644 --- a/tensorflow/python/autograph/utils/testing.py +++ b/tensorflow/python/autograph/utils/testing.py @@ -18,10 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import re import types import unittest from tensorflow.python.eager import def_function +from tensorflow.python.framework import op_callbacks from tensorflow.python.framework import ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -50,6 +53,10 @@ class AutoGraphTestCase(test.TestCase): baz_actual, value_actual = test_fn() self.assertEqual(baz_actual, value_actual) + + Only assertions that require evaluation outside the function are lifted + outside the function scope. The rest execute inline, at function creation + time. """ def __new__(cls, *args): @@ -65,18 +72,31 @@ class AutoGraphTestCase(test.TestCase): return obj + def _op_callback( + self, op_type, inputs, attrs, outputs, op_name=None, graph=None): + self.trace_log.append(op_type) + def _run_as_tf_function(self, fn): def wrapper(self): @def_function.function(autograph=False) # Testing autograph itself. def fn_wrapper(): self.assertions = [] + self.graph_assertions = [] + self.trace_log = [] fn() targets = [args for _, args in self.assertions] return targets - actuals = self.evaluate(fn_wrapper()) - for (_, args), value in zip(self.assertions, actuals): - args[:] = value + + tensors = fn_wrapper() + + for assertion in self.graph_assertions: + assertion(fn_wrapper.get_concrete_function().graph) + + actuals = self.evaluate(tensors) + for (assertion, _), values in zip(self.assertions, actuals): + assertion(*values) + return wrapper def variable(self, name, value, dtype): @@ -88,12 +108,39 @@ class AutoGraphTestCase(test.TestCase): def setUp(self): super().setUp() + os.environ['AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS'] = '1' self.variables = {} + self.trace_log = [] + op_callbacks.add_op_callback(self._op_callback) def tearDown(self): - for fn, args in self.assertions: - fn(*args) + op_callbacks.remove_op_callback(self._op_callback) + self.trace_log = None + self.variables = None super().tearDown() + def assertGraphContains(self, op_regex, n): + def assertion(graph): + matches = [] + for node in graph.as_graph_def().node: + if re.match(op_regex, node.name): + matches.append(node) + for fn in graph.as_graph_def().library.function: + for node_def in fn.node_def: + if re.match(op_regex, node_def.name): + matches.append(node_def) + self.assertLen(matches, n) + + self.graph_assertions.append(assertion) + + def assertOpCreated(self, op_type): + self.assertIn(op_type, self.trace_log) + + def assertOpsNotCreated(self, op_types): + self.assertEmpty(set(op_types) & set(self.trace_log)) + + def assertNoOpsCreated(self): + self.assertEmpty(self.trace_log) + def assertEqual(self, *args): self.assertions.append((super().assertEqual, list(args)))