Allow creating symbols in TF loops. This is on a best-effort basis, and only works in cases when the symbol being created doesn't depend on previous iterations.

This change only adds the feature, but does not enable it yet. It will be enabled separately.

PiperOrigin-RevId: 325282159
Change-Id: I29fd9792454c6dac4189d0756d517f9ff0390700
This commit is contained in:
Dan Moldovan 2020-08-06 12:13:57 -07:00 committed by TensorFlower Gardener
parent b9a5452924
commit c0374a95df
5 changed files with 387 additions and 26 deletions

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import contextlib
import imp
import inspect
import os
import sys
import six
@ -100,6 +101,7 @@ class TestCase(test.TestCase):
def setUp(self):
# AutoGraph tests must run in graph mode to properly test control flow.
os.environ['AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS'] = '1'
self.graph = ops.Graph().as_default()
self.graph.__enter__()

View File

@ -66,22 +66,48 @@ else:
pass
```
Similarly, variables may not be defined inside a TensorFlow loop, unless they
are local to the loop. A variable is local to the loop if (1) it's not used
after the loop and (2) the value from a previour iteration is not used in the
next iteration:
Similarly, variables must usually be defined before a TensorFlow loop.
The most common example that is not allowed is a loop which initializes some
accumulator variable in the first iteration:
```
del x
while tf.random.uniform(()) > 0.5: # Error -- x must be defined before the loop
for i in tf.range(100): # Error -- x must be defined before the loop
if i == 0:
x = tf.constant(1)
else:
x = x + 1
tf.print(x)
```
When the variable is only used inside the loop and does not depend on previous
iterations, then it's ok to only be initialized inside the loop.
```
del x
while tf.random.uniform(()) > 0.5: # Okay -- x is not used after the loop
x = tf.constant(1)
```
* New in TF 2.4 *
As long as it doesn't depend on previous iterations, the variable may also be
used after the loop, however in that case the loop must execute at least one
iteration, and will raise a runtime error otherwise.
```
del x
for i in tf.range(10): # Okay -- x does not depend on previous iterations
x = tf.constant(1)
tf.print(x)
```
```
del x
while tf.random.uniform(()) > 0.5: # Okay -- x is local to the loop
while tf.constant(False): # Error -- loop must initialize x!
x = tf.constant(1)
tf.print(x)
```
Avoid these limitations by defining a default value before the control flow
@ -98,6 +124,34 @@ Note: `None` values and undefined symbols are allowed in Eager control flow,
because Eager execution uses Python control flow, rather than TensorFlow
control flow ops.
#### Special case: creating Tensors in a loop
* New in TF 2.4 *
A very common use-case is to run a training loop that creates some outputs:
```
for i in tf.range(num_steps):
outputs = train(next(data_iterator))
```
Often times these outputs can be nested structures of Tensors, which makes them
impractical to initialize ahead of the loop.
To help with this use-case, AutoGraph lets you run such loops, under certain
conditions:
* outputs must be a Tensor, Python numeric, or a structure of these
* outputs must not depend on the value from a previous iteration; in other
words, the outputs may only appear to the left of an assignment operation
* the loop must run at least one iteration
If the type of outputs is not recognized, then the usual
"outputs must be defined before the loop" is raised at graph construction.
AutoGraph also inserts a `tf.Assert` statement that raises a runtime error
if the loop did not execute at least one iteration.
### Indirect modifications and hidden side effects in TensorFlow control flow
Key Point: We recommend using a functional programming style, immutable Python

View File

@ -60,6 +60,7 @@ from __future__ import division
from __future__ import print_function
import functools
import os
import traceback
import numpy as np
@ -79,6 +80,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
@ -99,19 +101,70 @@ INEFFICIENT_UNROLL_MIN_OPS = 1
# datasets. Before it can be used though, we need to standardize the interface.
def _verify_loop_init_vars(values, symbol_names):
"""Ensures that all values in the state are defined when entering a loop."""
for name, value in zip(symbol_names, values):
if value is None:
raise ValueError("'{}' may not be None before the loop.".format(name))
if isinstance(value, variables.UndefinedReturnValue):
# Assumption: the loop will only capture the variable which tracks the
# return value if the loop contained a return statement.
# TODO(mdan): This should be checked at the place where return occurs.
raise ValueError(
'return statements are not supported within a TensorFlow loop.')
if isinstance(value, variables.Undefined):
raise ValueError("'{}' must be defined before the loop.".format(name))
def _is_none_or_undef(value):
"""Tests whether a value is None or undefined.
AutoGraph represents undefined symbols using special objects of type Undefined
or UndefinedReturnValue.
Args:
value: value to test
Returns:
Boolean
"""
return ((value is None)
or isinstance(value, variables.UndefinedReturnValue)
or isinstance(value, variables.Undefined))
def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None):
"""Ensures that all values in the state are valid to use in a TF loop.
The init_vars may contain placeholder values derived from first_iter_vars.
Args:
init_vars: initial loop variables (as taken before entering the loop)
symbol_names: corresponding names of the initial loop variables
first_iter_vars: loop variables after one iteration of the loop
"""
if not symbol_names:
return
if first_iter_vars is None:
first_iter_vars = (None,) * len(symbol_names)
assert len(symbol_names) == len(init_vars)
assert len(symbol_names) == len(first_iter_vars)
for name, val, fi_val in zip(symbol_names, init_vars, first_iter_vars):
if isinstance(val, variables.UndefinedReturnValue):
if fi_val:
raise ValueError(
'the return value from a TensorFlow loop may only be a {}; got {}'
.format(LEGAL_LOOP_TYPES, type(fi_val)))
else:
# TODO(mdan): This can be handled by removing the return value.
raise NotImplementedError(
'a return statement cannot be placed inside this TensorFlow loop;'
' this may happen if a return statement depends on a'
' static Python condition such as a hyperparameter')
error_msg = None
if val is None:
error_msg = "'{}' may not be None before the loop".format(name)
elif isinstance(val, variables.Undefined):
error_msg = "'{}' must be defined before the loop".format(name)
# This only happens when we could not infer a placeholder for the
# variable. The canonical case when that happens is when _placeholder_value
# couldnot infer a placeholder for it. That means it's of an unknown type
# or it's still undefined after staging one iteration.
if error_msg is not None:
if fi_val:
error_msg += (", unless it's a {}; got {}".format(
LEGAL_LOOP_TYPES, type(fi_val)))
else:
# TODO(mdan): This can be handled by removing the loop var.
error_msg += '.'
raise ValueError(error_msg)
def _is_subshape(left, right):
@ -876,21 +929,134 @@ def _shape_invariants_mapping_to_positional_list(mapping, keys):
return tuple(result)
# Textual description of what a legal TF loop variable is. This description
# summarizes types that _placeholder_value below can handle. Keep the two
# together and in sync.
LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof'
def _placeholder_value(like, original):
if isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
return original
if isinstance(like, (int, float, bool)):
return type(like)(0)
if tensor_util.is_tensor(like):
return array_ops.zeros(like.shape, like.dtype)
elif isinstance(like, (list, tuple, dict)):
return nest.map_structure(_placeholder_value, like)
return original
def _try_handling_undefineds(
body, get_state, set_state, init_vars, nulls, symbol_names):
"""Makes a best-effort attempt to substitute undefineds with placeholders.
Note: this substitution requires two things to happen:
1. the types of loop variables could be inferred (usually by staging one
iteration)
2. these types could be replaced by placeholders (e.g. zero values, for
tensors.
Args:
body: a function representing the loop body. See while_stmt.
get_state: state getter for the loop statement. See while_stmt.
set_state: state getter for the loop statement. See while_stmt.
init_vars: loop variables before entering the loop. See while_stmt.
nulls: list of boolean flags indicating whether the corresponding loop
var is None or undefined.
symbol_names: list of loop variable names. See while_stmt.
Returns:
A tuple (success, new_init_vars). success is a boolean flag indicating
whether types could be successfully inferred (step 1 above). new_init_vars
contains the loop vars, with None or undefined values replaced by
placeholders, where possible (step 2 above).
"""
state_modified = False
if not os.getenv('AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS', ''):
_verify_loop_init_vars(init_vars, symbol_names)
return False, init_vars
try:
# Stage an iteration of the loop body in a temporary graph.
with func_graph.FuncGraph('tmp').as_default():
# This call to set_state helps report nicer error messages when symbols
# are inconsistently used.
set_state(init_vars)
state_modified = True
body()
first_iter_vars = get_state()
except (UnboundLocalError, TypeError, ValueError, KeyError):
# Fall back to the old functionality. It will likely result in an input
# validation failure.
first_iter_vars = None
finally:
if state_modified:
set_state(init_vars)
if first_iter_vars is not None:
# Note: the actual placeholder value doesn't matter, because as the staging
# proved, it will be replaced by an actual value before being read.
init_vars = tuple(
(_placeholder_value(iv, v) if n else v)
for v, n, iv in zip(init_vars, nulls, first_iter_vars))
success = True
else:
success = False
# This check runs regardless, in case we captured non-Tensor inputs.
_verify_loop_init_vars(init_vars, symbol_names, first_iter_vars)
return success, init_vars
def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars):
"""Creates an error message asking for the loop to iterate at least once."""
var_names = []
for sn, n, v in zip(symbol_names, nulls, init_vars):
if not n:
continue
if isinstance(v, variables.UndefinedReturnValue):
var_names.append('the function return value')
else:
var_names.append(sn)
var_names = ', '.join(var_names)
return 'loop must iterate at least once to initialize {}'.format(var_names)
def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
"""Overload of while_stmt that stages a TF while_stmt."""
init_vars = get_state()
_verify_loop_init_vars(init_vars, symbol_names)
orig_init_vars = init_vars
nulls = tuple(_is_none_or_undef(v) for v in init_vars)
if any(nulls):
require_one_iteration, init_vars = _try_handling_undefineds(
body, get_state, set_state, init_vars, nulls, symbol_names)
else:
require_one_iteration = False
def aug_test(*loop_vars):
if require_one_iteration:
loop_vars = loop_vars[1:]
set_state(loop_vars)
return test()
def aug_body(*loop_vars):
if require_one_iteration:
loop_vars = loop_vars[1:]
set_state(loop_vars)
body()
new_loop_vars = get_state()
_verify_tf_loop_vars(
init_vars, loop_vars, new_loop_vars, symbol_names, opts)
if require_one_iteration:
new_loop_vars = (True,) + new_loop_vars
return new_loop_vars
if 'shape_invariants' in opts:
@ -904,8 +1070,23 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
# This enforces consistency across versions.
while_loop_opts['return_same_structure'] = True
if require_one_iteration:
aug_init_vars = (False,) + init_vars
else:
aug_init_vars = init_vars
final_loop_vars = control_flow_ops.while_loop(
aug_test, aug_body, init_vars, **while_loop_opts)
aug_test, aug_body, aug_init_vars, **while_loop_opts)
if require_one_iteration:
with ops.control_dependencies([
control_flow_ops.Assert(final_loop_vars[0], [
_runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars)
])
]):
final_loop_vars = tuple(
array_ops.identity(v) for v in final_loop_vars[1:])
set_state(final_loop_vars)

View File

@ -86,7 +86,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(s, (1234,))
self.assertOpCreated('StatelessWhile')
def test_range_tensor_explicit_limit_delta(self):
def body(i):
@ -106,7 +108,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(s, (-171207,))
self.assertOpCreated('StatelessWhile')
def test_range_tensor_explicit_limit_negative_delta(self):
def body(i):
@ -126,7 +130,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(s, (171207,))
self.assertOpCreated('StatelessWhile')
def test_range_tensor_random_delta(self):
def body(i):
@ -147,7 +153,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(s, (1234,))
self.assertOpCreated('StatelessWhile')
def test_range_tensor_random_negative_delta(self):
def body(i):
@ -168,7 +176,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(s, (171207,))
self.assertOpCreated('StatelessWhile')
def test_tensor_with_extra_test_object_vars(self):
class MutableObject(object):
@ -194,7 +204,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('state.field_1', 'state.field_2'),
opts={})
self.assertEqual((state.field_1, state.field_2), (6, 6))
self.assertOpCreated('StatelessWhile')
def test_python(self):
def body(i):
@ -214,7 +226,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(s, 1234)
self.assertNoOpsCreated()
def test_python_generator_with_extra_test(self):
def new_generator():
@ -247,6 +261,8 @@ class ForLoopTest(testing.AutoGraphTestCase):
self.assertEqual(next(gen), 4)
self.assertNoOpsCreated()
def test_python_generator_with_extra_test_no_iterations(self):
def new_generator():
for i in range(5):
@ -275,6 +291,8 @@ class ForLoopTest(testing.AutoGraphTestCase):
self.assertEqual(next(gen), 0)
self.assertNoOpsCreated()
def test_tf_dataset(self):
def body(i):
nonlocal s
@ -293,7 +311,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(s, (1234,))
self.assertOpCreated('ScanDataset')
def test_dataset_with_extra_test(self):
def body(i):
@ -313,7 +333,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(s, (12,))
self.assertOpCreated('ScanDataset')
def test_dataset_with_extra_test_collection_vars(self):
def body(i):
@ -335,7 +357,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('l[0]', 's'),
opts={})
self.assertEqual((l[0], s), (3, 3))
self.assertOpCreated('ScanDataset')
def test_dataset_with_extra_test_iteration_limiting(self):
def body(it):
@ -356,7 +380,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('i',),
opts={})
self.assertEqual(i, (3,))
self.assertOpCreated('ScanDataset')
def test_tf_dataset_no_loop_vars(self):
def body(i):
@ -374,6 +400,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
opts={})
self.assertEqual(v.read_value(), 1234)
self.assertOpCreated('ScanDataset')
def test_tf_iterator(self):
def body(i):
@ -395,6 +422,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
opts={})
self.assertEqual(s, 1234)
self.assertOpCreated('IteratorGetNextAsOptional')
def test_tf_iterator_shape_invariants(self):
def body(i):
@ -416,6 +444,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
self.assertAllEqual(s, [0, 1, 2, 3, 4])
self.assertOpCreated('IteratorGetNextAsOptional')
def test_tf_iterator_no_loop_vars(self):
def body(i):
@ -433,6 +462,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
opts={})
self.assertEqual(v.read_value(), 1234)
self.assertOpCreated('IteratorGetNextAsOptional')
def test_tf_ragged_tensor(self):
def body(i):
@ -454,6 +484,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
opts={})
self.assertEqual(s, (123,))
self.assertOpCreated('StatelessWhile')
def test_tf_ragged_tensor_higher_dimensional(self):
def body(i):
@ -479,6 +510,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
opts={})
self.assertEqual(s, (12,))
self.assertOpCreated('StatelessWhile')
def test_tf_ragged_tensor_no_loop_vars(self):
v = self.variable('v', 0, dtypes.int32)
@ -497,6 +529,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
# Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row).
self.assertEqual(v.read_value(), 123)
self.assertOpCreated('While')
def _basic_loop(self, init_value, body_fn):
def body(i):
@ -540,6 +573,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
class WhileLoopTest(testing.AutoGraphTestCase):
def test_tensor(self):
def body():
nonlocal i, s
s = s * 10 + i
@ -559,8 +593,38 @@ class WhileLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('i', 's'),
opts={})
self.assertEqual(i, 5)
self.assertEqual(s, 1234)
self.assertOpCreated('StatelessWhile')
def test_tensor_creating_variable(self):
def body():
nonlocal i, s
i = constant_op.constant(2)
s = i ** 5
def set_state(loop_vars):
nonlocal i, s
i, s = loop_vars
i = variable_operators.Undefined('i')
s = constant_op.constant(0)
control_flow.while_stmt(
test=lambda: math_ops.equal(s, 0),
body=body,
get_state=lambda: (i, s),
set_state=set_state,
symbol_names=('i', 's'),
opts={})
self.assertEqual(i, 2)
self.assertEqual(s, 32)
self.assertOpCreated('StatelessWhile')
# Check that the temporary staging of the body did not create extra ops.
# Node naming is inconsistent between V1 and V2.
self.assertGraphContains(r'(while/)?pow$', 1)
def test_tensor_with_side_effecting_condition(self):
v = self.variable('v', 0, dtypes.int32)
@ -589,6 +653,7 @@ class WhileLoopTest(testing.AutoGraphTestCase):
self.assertEqual(i, (5,))
self.assertEqual(v, (12345,))
self.assertOpCreated('While')
def test_tensor_with_python_state(self):
class MutableObject(object):
@ -613,8 +678,10 @@ class WhileLoopTest(testing.AutoGraphTestCase):
set_state=set_state,
symbol_names=('i', 'state.field'),
opts={})
self.assertEqual(i, 5)
self.assertEqual(state.field, 1234)
self.assertOpCreated('StatelessWhile')
def test_python(self):
def body():
@ -632,7 +699,9 @@ class WhileLoopTest(testing.AutoGraphTestCase):
set_state=None,
symbol_names=('i', 's'),
opts={})
self.assertEqual(s, 1234)
self.assertNoOpsCreated()
def test_python_with_tensor_state(self):
def body():
@ -650,8 +719,10 @@ class WhileLoopTest(testing.AutoGraphTestCase):
set_state=None,
symbol_names=('i', 's'),
opts={})
self.assertEqual(i, 5)
self.assertEqual(s, 1234)
self.assertOpsNotCreated(('While', 'StatelessWhile'))
def test_python_while_infinite(self):
if not __debug__:
@ -732,6 +803,7 @@ class WhileLoopTest(testing.AutoGraphTestCase):
r'.* Large unrolled loop.*Add.*', out_capturer.getvalue()))
def _basic_loop(self, init_value, body_fn):
def body():
nonlocal i, s
s = body_fn(i, s)
@ -802,6 +874,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
self.assertEqual(test_fn(constant_op.constant(True)), 1)
self.assertEqual(test_fn(constant_op.constant(False)), -1)
self.assertOpCreated('StatelessIf')
def test_tensor_no_outputs(self):
@ -831,6 +904,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
self.assertIsNone(test_fn(constant_op.constant(True)))
self.assertIsNone(test_fn(constant_op.constant(False)))
self.assertOpCreated('StatelessIf')
def test_tensor_multiple_returns(self):
@ -862,6 +936,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
self.assertEqual(test_fn(constant_op.constant(True)), (1, 2))
self.assertEqual(test_fn(constant_op.constant(False)), (-1, -2))
self.assertOpCreated('StatelessIf')
def test_python(self):
@ -887,6 +962,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
self.assertEqual(test_fn(True), 1)
self.assertEqual(test_fn(False), -1)
self.assertNoOpsCreated()
def test_python_multiple_returns(self):
@ -914,6 +990,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
self.assertEqual(test_fn(True), (1, 2))
self.assertEqual(test_fn(False), (-1, -2))
self.assertNoOpsCreated()
def _basic_cond(self, body_fn, else_fn):
def body():

View File

@ -18,10 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import types
import unittest
from tensorflow.python.eager import def_function
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -50,6 +53,10 @@ class AutoGraphTestCase(test.TestCase):
baz_actual, value_actual = test_fn()
self.assertEqual(baz_actual, value_actual)
Only assertions that require evaluation outside the function are lifted
outside the function scope. The rest execute inline, at function creation
time.
"""
def __new__(cls, *args):
@ -65,18 +72,31 @@ class AutoGraphTestCase(test.TestCase):
return obj
def _op_callback(
self, op_type, inputs, attrs, outputs, op_name=None, graph=None):
self.trace_log.append(op_type)
def _run_as_tf_function(self, fn):
def wrapper(self):
@def_function.function(autograph=False) # Testing autograph itself.
def fn_wrapper():
self.assertions = []
self.graph_assertions = []
self.trace_log = []
fn()
targets = [args for _, args in self.assertions]
return targets
actuals = self.evaluate(fn_wrapper())
for (_, args), value in zip(self.assertions, actuals):
args[:] = value
tensors = fn_wrapper()
for assertion in self.graph_assertions:
assertion(fn_wrapper.get_concrete_function().graph)
actuals = self.evaluate(tensors)
for (assertion, _), values in zip(self.assertions, actuals):
assertion(*values)
return wrapper
def variable(self, name, value, dtype):
@ -88,12 +108,39 @@ class AutoGraphTestCase(test.TestCase):
def setUp(self):
super().setUp()
os.environ['AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS'] = '1'
self.variables = {}
self.trace_log = []
op_callbacks.add_op_callback(self._op_callback)
def tearDown(self):
for fn, args in self.assertions:
fn(*args)
op_callbacks.remove_op_callback(self._op_callback)
self.trace_log = None
self.variables = None
super().tearDown()
def assertGraphContains(self, op_regex, n):
def assertion(graph):
matches = []
for node in graph.as_graph_def().node:
if re.match(op_regex, node.name):
matches.append(node)
for fn in graph.as_graph_def().library.function:
for node_def in fn.node_def:
if re.match(op_regex, node_def.name):
matches.append(node_def)
self.assertLen(matches, n)
self.graph_assertions.append(assertion)
def assertOpCreated(self, op_type):
self.assertIn(op_type, self.trace_log)
def assertOpsNotCreated(self, op_types):
self.assertEmpty(set(op_types) & set(self.trace_log))
def assertNoOpsCreated(self):
self.assertEmpty(self.trace_log)
def assertEqual(self, *args):
self.assertions.append((super().assertEqual, list(args)))