Allow creating symbols in TF loops. This is on a best-effort basis, and only works in cases when the symbol being created doesn't depend on previous iterations.
This change only adds the feature, but does not enable it yet. It will be enabled separately. PiperOrigin-RevId: 325282159 Change-Id: I29fd9792454c6dac4189d0756d517f9ff0390700
This commit is contained in:
parent
b9a5452924
commit
c0374a95df
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import contextlib
|
||||
import imp
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
|
||||
import six
|
||||
@ -100,6 +101,7 @@ class TestCase(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# AutoGraph tests must run in graph mode to properly test control flow.
|
||||
os.environ['AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS'] = '1'
|
||||
self.graph = ops.Graph().as_default()
|
||||
self.graph.__enter__()
|
||||
|
||||
|
@ -66,22 +66,48 @@ else:
|
||||
pass
|
||||
```
|
||||
|
||||
Similarly, variables may not be defined inside a TensorFlow loop, unless they
|
||||
are local to the loop. A variable is local to the loop if (1) it's not used
|
||||
after the loop and (2) the value from a previour iteration is not used in the
|
||||
next iteration:
|
||||
Similarly, variables must usually be defined before a TensorFlow loop.
|
||||
|
||||
The most common example that is not allowed is a loop which initializes some
|
||||
accumulator variable in the first iteration:
|
||||
|
||||
```
|
||||
del x
|
||||
while tf.random.uniform(()) > 0.5: # Error -- x must be defined before the loop
|
||||
for i in tf.range(100): # Error -- x must be defined before the loop
|
||||
if i == 0:
|
||||
x = tf.constant(1)
|
||||
else:
|
||||
x = x + 1
|
||||
tf.print(x)
|
||||
```
|
||||
|
||||
When the variable is only used inside the loop and does not depend on previous
|
||||
iterations, then it's ok to only be initialized inside the loop.
|
||||
|
||||
```
|
||||
del x
|
||||
while tf.random.uniform(()) > 0.5: # Okay -- x is not used after the loop
|
||||
x = tf.constant(1)
|
||||
```
|
||||
|
||||
* New in TF 2.4 *
|
||||
|
||||
As long as it doesn't depend on previous iterations, the variable may also be
|
||||
used after the loop, however in that case the loop must execute at least one
|
||||
iteration, and will raise a runtime error otherwise.
|
||||
|
||||
```
|
||||
del x
|
||||
for i in tf.range(10): # Okay -- x does not depend on previous iterations
|
||||
x = tf.constant(1)
|
||||
tf.print(x)
|
||||
```
|
||||
|
||||
```
|
||||
del x
|
||||
while tf.random.uniform(()) > 0.5: # Okay -- x is local to the loop
|
||||
while tf.constant(False): # Error -- loop must initialize x!
|
||||
x = tf.constant(1)
|
||||
tf.print(x)
|
||||
```
|
||||
|
||||
Avoid these limitations by defining a default value before the control flow
|
||||
@ -98,6 +124,34 @@ Note: `None` values and undefined symbols are allowed in Eager control flow,
|
||||
because Eager execution uses Python control flow, rather than TensorFlow
|
||||
control flow ops.
|
||||
|
||||
#### Special case: creating Tensors in a loop
|
||||
|
||||
* New in TF 2.4 *
|
||||
|
||||
A very common use-case is to run a training loop that creates some outputs:
|
||||
|
||||
```
|
||||
for i in tf.range(num_steps):
|
||||
outputs = train(next(data_iterator))
|
||||
```
|
||||
|
||||
Often times these outputs can be nested structures of Tensors, which makes them
|
||||
impractical to initialize ahead of the loop.
|
||||
|
||||
To help with this use-case, AutoGraph lets you run such loops, under certain
|
||||
conditions:
|
||||
|
||||
* outputs must be a Tensor, Python numeric, or a structure of these
|
||||
* outputs must not depend on the value from a previous iteration; in other
|
||||
words, the outputs may only appear to the left of an assignment operation
|
||||
* the loop must run at least one iteration
|
||||
|
||||
If the type of outputs is not recognized, then the usual
|
||||
"outputs must be defined before the loop" is raised at graph construction.
|
||||
|
||||
AutoGraph also inserts a `tf.Assert` statement that raises a runtime error
|
||||
if the loop did not execute at least one iteration.
|
||||
|
||||
### Indirect modifications and hidden side effects in TensorFlow control flow
|
||||
|
||||
Key Point: We recommend using a functional programming style, immutable Python
|
||||
|
@ -60,6 +60,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
@ -79,6 +80,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -99,19 +101,70 @@ INEFFICIENT_UNROLL_MIN_OPS = 1
|
||||
# datasets. Before it can be used though, we need to standardize the interface.
|
||||
|
||||
|
||||
def _verify_loop_init_vars(values, symbol_names):
|
||||
"""Ensures that all values in the state are defined when entering a loop."""
|
||||
for name, value in zip(symbol_names, values):
|
||||
if value is None:
|
||||
raise ValueError("'{}' may not be None before the loop.".format(name))
|
||||
if isinstance(value, variables.UndefinedReturnValue):
|
||||
# Assumption: the loop will only capture the variable which tracks the
|
||||
# return value if the loop contained a return statement.
|
||||
# TODO(mdan): This should be checked at the place where return occurs.
|
||||
raise ValueError(
|
||||
'return statements are not supported within a TensorFlow loop.')
|
||||
if isinstance(value, variables.Undefined):
|
||||
raise ValueError("'{}' must be defined before the loop.".format(name))
|
||||
def _is_none_or_undef(value):
|
||||
"""Tests whether a value is None or undefined.
|
||||
|
||||
AutoGraph represents undefined symbols using special objects of type Undefined
|
||||
or UndefinedReturnValue.
|
||||
|
||||
Args:
|
||||
value: value to test
|
||||
Returns:
|
||||
Boolean
|
||||
"""
|
||||
return ((value is None)
|
||||
or isinstance(value, variables.UndefinedReturnValue)
|
||||
or isinstance(value, variables.Undefined))
|
||||
|
||||
|
||||
def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None):
|
||||
"""Ensures that all values in the state are valid to use in a TF loop.
|
||||
|
||||
The init_vars may contain placeholder values derived from first_iter_vars.
|
||||
|
||||
Args:
|
||||
init_vars: initial loop variables (as taken before entering the loop)
|
||||
symbol_names: corresponding names of the initial loop variables
|
||||
first_iter_vars: loop variables after one iteration of the loop
|
||||
"""
|
||||
if not symbol_names:
|
||||
return
|
||||
if first_iter_vars is None:
|
||||
first_iter_vars = (None,) * len(symbol_names)
|
||||
|
||||
assert len(symbol_names) == len(init_vars)
|
||||
assert len(symbol_names) == len(first_iter_vars)
|
||||
for name, val, fi_val in zip(symbol_names, init_vars, first_iter_vars):
|
||||
if isinstance(val, variables.UndefinedReturnValue):
|
||||
if fi_val:
|
||||
raise ValueError(
|
||||
'the return value from a TensorFlow loop may only be a {}; got {}'
|
||||
.format(LEGAL_LOOP_TYPES, type(fi_val)))
|
||||
else:
|
||||
# TODO(mdan): This can be handled by removing the return value.
|
||||
raise NotImplementedError(
|
||||
'a return statement cannot be placed inside this TensorFlow loop;'
|
||||
' this may happen if a return statement depends on a'
|
||||
' static Python condition such as a hyperparameter')
|
||||
|
||||
error_msg = None
|
||||
if val is None:
|
||||
error_msg = "'{}' may not be None before the loop".format(name)
|
||||
elif isinstance(val, variables.Undefined):
|
||||
error_msg = "'{}' must be defined before the loop".format(name)
|
||||
|
||||
# This only happens when we could not infer a placeholder for the
|
||||
# variable. The canonical case when that happens is when _placeholder_value
|
||||
# couldnot infer a placeholder for it. That means it's of an unknown type
|
||||
# or it's still undefined after staging one iteration.
|
||||
if error_msg is not None:
|
||||
if fi_val:
|
||||
error_msg += (", unless it's a {}; got {}".format(
|
||||
LEGAL_LOOP_TYPES, type(fi_val)))
|
||||
else:
|
||||
# TODO(mdan): This can be handled by removing the loop var.
|
||||
error_msg += '.'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def _is_subshape(left, right):
|
||||
@ -876,21 +929,134 @@ def _shape_invariants_mapping_to_positional_list(mapping, keys):
|
||||
return tuple(result)
|
||||
|
||||
|
||||
# Textual description of what a legal TF loop variable is. This description
|
||||
# summarizes types that _placeholder_value below can handle. Keep the two
|
||||
# together and in sync.
|
||||
LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof'
|
||||
|
||||
|
||||
def _placeholder_value(like, original):
|
||||
if isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
|
||||
return original
|
||||
if isinstance(like, (int, float, bool)):
|
||||
return type(like)(0)
|
||||
if tensor_util.is_tensor(like):
|
||||
return array_ops.zeros(like.shape, like.dtype)
|
||||
elif isinstance(like, (list, tuple, dict)):
|
||||
return nest.map_structure(_placeholder_value, like)
|
||||
return original
|
||||
|
||||
|
||||
def _try_handling_undefineds(
|
||||
body, get_state, set_state, init_vars, nulls, symbol_names):
|
||||
"""Makes a best-effort attempt to substitute undefineds with placeholders.
|
||||
|
||||
Note: this substitution requires two things to happen:
|
||||
1. the types of loop variables could be inferred (usually by staging one
|
||||
iteration)
|
||||
2. these types could be replaced by placeholders (e.g. zero values, for
|
||||
tensors.
|
||||
|
||||
Args:
|
||||
body: a function representing the loop body. See while_stmt.
|
||||
get_state: state getter for the loop statement. See while_stmt.
|
||||
set_state: state getter for the loop statement. See while_stmt.
|
||||
init_vars: loop variables before entering the loop. See while_stmt.
|
||||
nulls: list of boolean flags indicating whether the corresponding loop
|
||||
var is None or undefined.
|
||||
symbol_names: list of loop variable names. See while_stmt.
|
||||
Returns:
|
||||
A tuple (success, new_init_vars). success is a boolean flag indicating
|
||||
whether types could be successfully inferred (step 1 above). new_init_vars
|
||||
contains the loop vars, with None or undefined values replaced by
|
||||
placeholders, where possible (step 2 above).
|
||||
"""
|
||||
state_modified = False
|
||||
|
||||
if not os.getenv('AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS', ''):
|
||||
_verify_loop_init_vars(init_vars, symbol_names)
|
||||
return False, init_vars
|
||||
|
||||
try:
|
||||
# Stage an iteration of the loop body in a temporary graph.
|
||||
with func_graph.FuncGraph('tmp').as_default():
|
||||
# This call to set_state helps report nicer error messages when symbols
|
||||
# are inconsistently used.
|
||||
set_state(init_vars)
|
||||
state_modified = True
|
||||
|
||||
body()
|
||||
first_iter_vars = get_state()
|
||||
except (UnboundLocalError, TypeError, ValueError, KeyError):
|
||||
# Fall back to the old functionality. It will likely result in an input
|
||||
# validation failure.
|
||||
first_iter_vars = None
|
||||
finally:
|
||||
if state_modified:
|
||||
set_state(init_vars)
|
||||
|
||||
if first_iter_vars is not None:
|
||||
# Note: the actual placeholder value doesn't matter, because as the staging
|
||||
# proved, it will be replaced by an actual value before being read.
|
||||
init_vars = tuple(
|
||||
(_placeholder_value(iv, v) if n else v)
|
||||
for v, n, iv in zip(init_vars, nulls, first_iter_vars))
|
||||
success = True
|
||||
else:
|
||||
success = False
|
||||
|
||||
# This check runs regardless, in case we captured non-Tensor inputs.
|
||||
_verify_loop_init_vars(init_vars, symbol_names, first_iter_vars)
|
||||
|
||||
return success, init_vars
|
||||
|
||||
|
||||
def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars):
|
||||
"""Creates an error message asking for the loop to iterate at least once."""
|
||||
var_names = []
|
||||
for sn, n, v in zip(symbol_names, nulls, init_vars):
|
||||
if not n:
|
||||
continue
|
||||
if isinstance(v, variables.UndefinedReturnValue):
|
||||
var_names.append('the function return value')
|
||||
else:
|
||||
var_names.append(sn)
|
||||
var_names = ', '.join(var_names)
|
||||
return 'loop must iterate at least once to initialize {}'.format(var_names)
|
||||
|
||||
|
||||
def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
|
||||
"""Overload of while_stmt that stages a TF while_stmt."""
|
||||
init_vars = get_state()
|
||||
_verify_loop_init_vars(init_vars, symbol_names)
|
||||
orig_init_vars = init_vars
|
||||
|
||||
nulls = tuple(_is_none_or_undef(v) for v in init_vars)
|
||||
if any(nulls):
|
||||
require_one_iteration, init_vars = _try_handling_undefineds(
|
||||
body, get_state, set_state, init_vars, nulls, symbol_names)
|
||||
else:
|
||||
require_one_iteration = False
|
||||
|
||||
def aug_test(*loop_vars):
|
||||
if require_one_iteration:
|
||||
loop_vars = loop_vars[1:]
|
||||
|
||||
set_state(loop_vars)
|
||||
return test()
|
||||
|
||||
def aug_body(*loop_vars):
|
||||
if require_one_iteration:
|
||||
loop_vars = loop_vars[1:]
|
||||
|
||||
set_state(loop_vars)
|
||||
body()
|
||||
new_loop_vars = get_state()
|
||||
_verify_tf_loop_vars(
|
||||
init_vars, loop_vars, new_loop_vars, symbol_names, opts)
|
||||
|
||||
if require_one_iteration:
|
||||
new_loop_vars = (True,) + new_loop_vars
|
||||
|
||||
return new_loop_vars
|
||||
|
||||
if 'shape_invariants' in opts:
|
||||
@ -904,8 +1070,23 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
|
||||
# This enforces consistency across versions.
|
||||
while_loop_opts['return_same_structure'] = True
|
||||
|
||||
if require_one_iteration:
|
||||
aug_init_vars = (False,) + init_vars
|
||||
else:
|
||||
aug_init_vars = init_vars
|
||||
|
||||
final_loop_vars = control_flow_ops.while_loop(
|
||||
aug_test, aug_body, init_vars, **while_loop_opts)
|
||||
aug_test, aug_body, aug_init_vars, **while_loop_opts)
|
||||
|
||||
if require_one_iteration:
|
||||
with ops.control_dependencies([
|
||||
control_flow_ops.Assert(final_loop_vars[0], [
|
||||
_runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars)
|
||||
])
|
||||
]):
|
||||
final_loop_vars = tuple(
|
||||
array_ops.identity(v) for v in final_loop_vars[1:])
|
||||
|
||||
set_state(final_loop_vars)
|
||||
|
||||
|
||||
|
@ -86,7 +86,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
|
||||
self.assertEqual(s, (1234,))
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_range_tensor_explicit_limit_delta(self):
|
||||
def body(i):
|
||||
@ -106,7 +108,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
|
||||
self.assertEqual(s, (-171207,))
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_range_tensor_explicit_limit_negative_delta(self):
|
||||
def body(i):
|
||||
@ -126,7 +130,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
|
||||
self.assertEqual(s, (171207,))
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_range_tensor_random_delta(self):
|
||||
def body(i):
|
||||
@ -147,7 +153,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
|
||||
self.assertEqual(s, (1234,))
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_range_tensor_random_negative_delta(self):
|
||||
def body(i):
|
||||
@ -168,7 +176,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={'iterate_names': 'i'})
|
||||
|
||||
self.assertEqual(s, (171207,))
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_tensor_with_extra_test_object_vars(self):
|
||||
class MutableObject(object):
|
||||
@ -194,7 +204,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('state.field_1', 'state.field_2'),
|
||||
opts={})
|
||||
|
||||
self.assertEqual((state.field_1, state.field_2), (6, 6))
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_python(self):
|
||||
def body(i):
|
||||
@ -214,7 +226,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(s, 1234)
|
||||
self.assertNoOpsCreated()
|
||||
|
||||
def test_python_generator_with_extra_test(self):
|
||||
def new_generator():
|
||||
@ -247,6 +261,8 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
|
||||
self.assertEqual(next(gen), 4)
|
||||
|
||||
self.assertNoOpsCreated()
|
||||
|
||||
def test_python_generator_with_extra_test_no_iterations(self):
|
||||
def new_generator():
|
||||
for i in range(5):
|
||||
@ -275,6 +291,8 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
|
||||
self.assertEqual(next(gen), 0)
|
||||
|
||||
self.assertNoOpsCreated()
|
||||
|
||||
def test_tf_dataset(self):
|
||||
def body(i):
|
||||
nonlocal s
|
||||
@ -293,7 +311,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(s, (1234,))
|
||||
self.assertOpCreated('ScanDataset')
|
||||
|
||||
def test_dataset_with_extra_test(self):
|
||||
def body(i):
|
||||
@ -313,7 +333,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(s, (12,))
|
||||
self.assertOpCreated('ScanDataset')
|
||||
|
||||
def test_dataset_with_extra_test_collection_vars(self):
|
||||
def body(i):
|
||||
@ -335,7 +357,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('l[0]', 's'),
|
||||
opts={})
|
||||
|
||||
self.assertEqual((l[0], s), (3, 3))
|
||||
self.assertOpCreated('ScanDataset')
|
||||
|
||||
def test_dataset_with_extra_test_iteration_limiting(self):
|
||||
def body(it):
|
||||
@ -356,7 +380,9 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('i',),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(i, (3,))
|
||||
self.assertOpCreated('ScanDataset')
|
||||
|
||||
def test_tf_dataset_no_loop_vars(self):
|
||||
def body(i):
|
||||
@ -374,6 +400,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
opts={})
|
||||
|
||||
self.assertEqual(v.read_value(), 1234)
|
||||
self.assertOpCreated('ScanDataset')
|
||||
|
||||
def test_tf_iterator(self):
|
||||
def body(i):
|
||||
@ -395,6 +422,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
opts={})
|
||||
|
||||
self.assertEqual(s, 1234)
|
||||
self.assertOpCreated('IteratorGetNextAsOptional')
|
||||
|
||||
def test_tf_iterator_shape_invariants(self):
|
||||
def body(i):
|
||||
@ -416,6 +444,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
|
||||
|
||||
self.assertAllEqual(s, [0, 1, 2, 3, 4])
|
||||
self.assertOpCreated('IteratorGetNextAsOptional')
|
||||
|
||||
def test_tf_iterator_no_loop_vars(self):
|
||||
def body(i):
|
||||
@ -433,6 +462,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
opts={})
|
||||
|
||||
self.assertEqual(v.read_value(), 1234)
|
||||
self.assertOpCreated('IteratorGetNextAsOptional')
|
||||
|
||||
def test_tf_ragged_tensor(self):
|
||||
def body(i):
|
||||
@ -454,6 +484,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
opts={})
|
||||
|
||||
self.assertEqual(s, (123,))
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_tf_ragged_tensor_higher_dimensional(self):
|
||||
def body(i):
|
||||
@ -479,6 +510,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
opts={})
|
||||
|
||||
self.assertEqual(s, (12,))
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_tf_ragged_tensor_no_loop_vars(self):
|
||||
v = self.variable('v', 0, dtypes.int32)
|
||||
@ -497,6 +529,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
|
||||
# Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row).
|
||||
self.assertEqual(v.read_value(), 123)
|
||||
self.assertOpCreated('While')
|
||||
|
||||
def _basic_loop(self, init_value, body_fn):
|
||||
def body(i):
|
||||
@ -540,6 +573,7 @@ class ForLoopTest(testing.AutoGraphTestCase):
|
||||
class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
|
||||
def test_tensor(self):
|
||||
|
||||
def body():
|
||||
nonlocal i, s
|
||||
s = s * 10 + i
|
||||
@ -559,8 +593,38 @@ class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('i', 's'),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(i, 5)
|
||||
self.assertEqual(s, 1234)
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_tensor_creating_variable(self):
|
||||
|
||||
def body():
|
||||
nonlocal i, s
|
||||
i = constant_op.constant(2)
|
||||
s = i ** 5
|
||||
|
||||
def set_state(loop_vars):
|
||||
nonlocal i, s
|
||||
i, s = loop_vars
|
||||
|
||||
i = variable_operators.Undefined('i')
|
||||
s = constant_op.constant(0)
|
||||
control_flow.while_stmt(
|
||||
test=lambda: math_ops.equal(s, 0),
|
||||
body=body,
|
||||
get_state=lambda: (i, s),
|
||||
set_state=set_state,
|
||||
symbol_names=('i', 's'),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(i, 2)
|
||||
self.assertEqual(s, 32)
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
# Check that the temporary staging of the body did not create extra ops.
|
||||
# Node naming is inconsistent between V1 and V2.
|
||||
self.assertGraphContains(r'(while/)?pow$', 1)
|
||||
|
||||
def test_tensor_with_side_effecting_condition(self):
|
||||
v = self.variable('v', 0, dtypes.int32)
|
||||
@ -589,6 +653,7 @@ class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
|
||||
self.assertEqual(i, (5,))
|
||||
self.assertEqual(v, (12345,))
|
||||
self.assertOpCreated('While')
|
||||
|
||||
def test_tensor_with_python_state(self):
|
||||
class MutableObject(object):
|
||||
@ -613,8 +678,10 @@ class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=set_state,
|
||||
symbol_names=('i', 'state.field'),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(i, 5)
|
||||
self.assertEqual(state.field, 1234)
|
||||
self.assertOpCreated('StatelessWhile')
|
||||
|
||||
def test_python(self):
|
||||
def body():
|
||||
@ -632,7 +699,9 @@ class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=None,
|
||||
symbol_names=('i', 's'),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(s, 1234)
|
||||
self.assertNoOpsCreated()
|
||||
|
||||
def test_python_with_tensor_state(self):
|
||||
def body():
|
||||
@ -650,8 +719,10 @@ class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
set_state=None,
|
||||
symbol_names=('i', 's'),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(i, 5)
|
||||
self.assertEqual(s, 1234)
|
||||
self.assertOpsNotCreated(('While', 'StatelessWhile'))
|
||||
|
||||
def test_python_while_infinite(self):
|
||||
if not __debug__:
|
||||
@ -732,6 +803,7 @@ class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
r'.* Large unrolled loop.*Add.*', out_capturer.getvalue()))
|
||||
|
||||
def _basic_loop(self, init_value, body_fn):
|
||||
|
||||
def body():
|
||||
nonlocal i, s
|
||||
s = body_fn(i, s)
|
||||
@ -802,6 +874,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
|
||||
|
||||
self.assertEqual(test_fn(constant_op.constant(True)), 1)
|
||||
self.assertEqual(test_fn(constant_op.constant(False)), -1)
|
||||
self.assertOpCreated('StatelessIf')
|
||||
|
||||
def test_tensor_no_outputs(self):
|
||||
|
||||
@ -831,6 +904,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
|
||||
|
||||
self.assertIsNone(test_fn(constant_op.constant(True)))
|
||||
self.assertIsNone(test_fn(constant_op.constant(False)))
|
||||
self.assertOpCreated('StatelessIf')
|
||||
|
||||
def test_tensor_multiple_returns(self):
|
||||
|
||||
@ -862,6 +936,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
|
||||
|
||||
self.assertEqual(test_fn(constant_op.constant(True)), (1, 2))
|
||||
self.assertEqual(test_fn(constant_op.constant(False)), (-1, -2))
|
||||
self.assertOpCreated('StatelessIf')
|
||||
|
||||
def test_python(self):
|
||||
|
||||
@ -887,6 +962,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
|
||||
|
||||
self.assertEqual(test_fn(True), 1)
|
||||
self.assertEqual(test_fn(False), -1)
|
||||
self.assertNoOpsCreated()
|
||||
|
||||
def test_python_multiple_returns(self):
|
||||
|
||||
@ -914,6 +990,7 @@ class IfStmtTest(testing.AutoGraphTestCase):
|
||||
|
||||
self.assertEqual(test_fn(True), (1, 2))
|
||||
self.assertEqual(test_fn(False), (-1, -2))
|
||||
self.assertNoOpsCreated()
|
||||
|
||||
def _basic_cond(self, body_fn, else_fn):
|
||||
def body():
|
||||
|
@ -18,10 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import op_callbacks
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -50,6 +53,10 @@ class AutoGraphTestCase(test.TestCase):
|
||||
|
||||
baz_actual, value_actual = test_fn()
|
||||
self.assertEqual(baz_actual, value_actual)
|
||||
|
||||
Only assertions that require evaluation outside the function are lifted
|
||||
outside the function scope. The rest execute inline, at function creation
|
||||
time.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args):
|
||||
@ -65,18 +72,31 @@ class AutoGraphTestCase(test.TestCase):
|
||||
|
||||
return obj
|
||||
|
||||
def _op_callback(
|
||||
self, op_type, inputs, attrs, outputs, op_name=None, graph=None):
|
||||
self.trace_log.append(op_type)
|
||||
|
||||
def _run_as_tf_function(self, fn):
|
||||
|
||||
def wrapper(self):
|
||||
@def_function.function(autograph=False) # Testing autograph itself.
|
||||
def fn_wrapper():
|
||||
self.assertions = []
|
||||
self.graph_assertions = []
|
||||
self.trace_log = []
|
||||
fn()
|
||||
targets = [args for _, args in self.assertions]
|
||||
return targets
|
||||
actuals = self.evaluate(fn_wrapper())
|
||||
for (_, args), value in zip(self.assertions, actuals):
|
||||
args[:] = value
|
||||
|
||||
tensors = fn_wrapper()
|
||||
|
||||
for assertion in self.graph_assertions:
|
||||
assertion(fn_wrapper.get_concrete_function().graph)
|
||||
|
||||
actuals = self.evaluate(tensors)
|
||||
for (assertion, _), values in zip(self.assertions, actuals):
|
||||
assertion(*values)
|
||||
|
||||
return wrapper
|
||||
|
||||
def variable(self, name, value, dtype):
|
||||
@ -88,12 +108,39 @@ class AutoGraphTestCase(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ['AUTOGRAPH_CREATE_SYMBOLS_IN_LOOPS'] = '1'
|
||||
self.variables = {}
|
||||
self.trace_log = []
|
||||
op_callbacks.add_op_callback(self._op_callback)
|
||||
|
||||
def tearDown(self):
|
||||
for fn, args in self.assertions:
|
||||
fn(*args)
|
||||
op_callbacks.remove_op_callback(self._op_callback)
|
||||
self.trace_log = None
|
||||
self.variables = None
|
||||
super().tearDown()
|
||||
|
||||
def assertGraphContains(self, op_regex, n):
|
||||
def assertion(graph):
|
||||
matches = []
|
||||
for node in graph.as_graph_def().node:
|
||||
if re.match(op_regex, node.name):
|
||||
matches.append(node)
|
||||
for fn in graph.as_graph_def().library.function:
|
||||
for node_def in fn.node_def:
|
||||
if re.match(op_regex, node_def.name):
|
||||
matches.append(node_def)
|
||||
self.assertLen(matches, n)
|
||||
|
||||
self.graph_assertions.append(assertion)
|
||||
|
||||
def assertOpCreated(self, op_type):
|
||||
self.assertIn(op_type, self.trace_log)
|
||||
|
||||
def assertOpsNotCreated(self, op_types):
|
||||
self.assertEmpty(set(op_types) & set(self.trace_log))
|
||||
|
||||
def assertNoOpsCreated(self):
|
||||
self.assertEmpty(self.trace_log)
|
||||
|
||||
def assertEqual(self, *args):
|
||||
self.assertions.append((super().assertEqual, list(args)))
|
||||
|
Loading…
Reference in New Issue
Block a user