Internal cleanup: set up boilerplate for tests which require running inside a tf.function. The boilerplate helps keep the tests clean, although its mechanism may be tricky.

PiperOrigin-RevId: 322444035
Change-Id: I56bbb1e1eb2feb65e16b8171d4a1bd20e1b190a5
This commit is contained in:
Dan Moldovan 2020-07-21 14:33:30 -07:00 committed by TensorFlower Gardener
parent cdc531263a
commit be88c5f8a7
3 changed files with 202 additions and 173 deletions

View File

@ -31,26 +31,22 @@ import six
from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.autograph.operators import variables as variable_operators
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ForLoopTest(test.TestCase):
class ForLoopTest(testing.AutoGraphTestCase):
def test_tensor(self):
def body(i):
@ -70,7 +66,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(self.evaluate(s), (1234,))
self.assertEqual(s, (1234,))
def test_range_tensor(self):
def body(i):
@ -90,7 +86,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (1234,))
self.assertEqual(s, (1234,))
def test_range_tensor_explicit_limit_delta(self):
def body(i):
@ -110,7 +106,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (-171207,))
self.assertEqual(s, (-171207,))
def test_range_tensor_explicit_limit_negative_delta(self):
def body(i):
@ -130,7 +126,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (171207,))
self.assertEqual(s, (171207,))
def test_range_tensor_random_delta(self):
def body(i):
@ -151,7 +147,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (1234,))
self.assertEqual(s, (1234,))
def test_range_tensor_random_negative_delta(self):
def body(i):
@ -172,7 +168,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (171207,))
self.assertEqual(s, (171207,))
def test_tensor_with_extra_test_object_vars(self):
class MutableObject(object):
@ -198,7 +194,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('state.field_1', 'state.field_2'),
opts={})
self.assertEqual(self.evaluate((state.field_1, state.field_2)), (6, 6))
self.assertEqual((state.field_1, state.field_2), (6, 6))
def test_python(self):
def body(i):
@ -297,7 +293,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(self.evaluate(s), (1234,))
self.assertEqual(s, (1234,))
def test_dataset_with_extra_test(self):
def body(i):
@ -317,7 +313,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(self.evaluate(s), (12,))
self.assertEqual(s, (12,))
def test_dataset_with_extra_test_collection_vars(self):
def body(i):
@ -339,7 +335,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('l[0]', 's'),
opts={})
self.assertEqual(self.evaluate((l[0], s)), (3, 3))
self.assertEqual((l[0], s), (3, 3))
def test_dataset_with_extra_test_iteration_limiting(self):
def body(it):
@ -360,100 +356,83 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('i',),
opts={})
self.assertEqual(self.evaluate(i), (3,))
self.assertEqual(i, (3,))
def test_tf_dataset_no_loop_vars(self):
def body(i):
v.assign(v.read_value() * 10 + i)
v = variables.Variable(0, dtype=dtypes.int64)
self.evaluate(v.initializer)
v = self.variable('v', 0, dtypes.int64)
# tf.function required for the automatic control dependencies, and because
# ops test for its presence.
@def_function.function
def test_fn():
control_flow.for_stmt(
dataset_ops.Dataset.range(5),
extra_test=None,
body=body,
get_state=lambda: (),
set_state=lambda _: None,
symbol_names=(),
opts={})
control_flow.for_stmt(
dataset_ops.Dataset.range(5),
extra_test=None,
body=body,
get_state=lambda: (),
set_state=lambda _: None,
symbol_names=(),
opts={})
self.evaluate(test_fn())
self.assertEqual(self.evaluate(v.read_value()), 1234)
self.assertEqual(v.read_value(), 1234)
def test_tf_iterator(self):
# graph-mode iterators are only supported inside tf.function.
@def_function.function
def test_fn():
def body(i):
nonlocal s
s = s * 10 + i
def body(i):
nonlocal s
s = s * 10 + i
def set_state(loop_vars):
nonlocal s
s, = loop_vars
def set_state(loop_vars):
nonlocal s
s, = loop_vars
s = constant_op.constant(0, dtype=dtypes.int64)
control_flow.for_stmt(
iter(dataset_ops.Dataset.range(5)),
extra_test=None,
body=body,
get_state=lambda: (s,),
set_state=set_state,
symbol_names=('s',),
opts={})
return s
self.assertAllEqual(test_fn(), 1234)
s = constant_op.constant(0, dtype=dtypes.int64)
control_flow.for_stmt(
iter(dataset_ops.Dataset.range(5)),
extra_test=None,
body=body,
get_state=lambda: (s,),
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(s, 1234)
def test_tf_iterator_shape_invariants(self):
# graph-mode iterators are only supported inside tf.function.
@def_function.function
def test_fn():
def body(i):
nonlocal s
s = array_ops.concat([s, [i]], 0)
def body(i):
nonlocal s
s = array_ops.concat([s, [i]], 0)
def set_state(loop_vars):
nonlocal s
s, = loop_vars
def set_state(loop_vars):
nonlocal s
s, = loop_vars
s = constant_op.constant([], dtype=dtypes.int64)
control_flow.for_stmt(
iter(dataset_ops.Dataset.range(5)),
extra_test=None,
body=body,
get_state=lambda: (s,),
set_state=set_state,
symbol_names=('s',),
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
return s
self.assertAllEqual(test_fn(), [0, 1, 2, 3, 4])
s = constant_op.constant([], dtype=dtypes.int64)
control_flow.for_stmt(
iter(dataset_ops.Dataset.range(5)),
extra_test=None,
body=body,
get_state=lambda: (s,),
set_state=set_state,
symbol_names=('s',),
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
self.assertAllEqual(s, [0, 1, 2, 3, 4])
def test_tf_iterator_no_loop_vars(self):
def body(i):
v.assign(v.read_value() * 10 + i)
v = variables.Variable(0, dtype=dtypes.int64)
self.evaluate(v.initializer)
v = self.variable('v', 0, dtypes.int64)
# tf.function required for the automatic control dependencies.
@def_function.function
def test_fn():
control_flow.for_stmt(
iter(dataset_ops.Dataset.range(5)),
extra_test=None,
body=body,
get_state=lambda: (),
set_state=lambda _: None,
symbol_names=(),
opts={})
control_flow.for_stmt(
iter(dataset_ops.Dataset.range(5)),
extra_test=None,
body=body,
get_state=lambda: (),
set_state=lambda _: None,
symbol_names=(),
opts={})
self.evaluate(test_fn())
self.assertEqual(self.evaluate(v.read_value()), 1234)
self.assertEqual(v.read_value(), 1234)
def test_tf_ragged_tensor(self):
def body(i):
@ -473,7 +452,8 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(self.evaluate(s), (123,))
self.assertEqual(s, (123,))
def test_tf_ragged_tensor_higher_dimensional(self):
def body(i):
@ -497,30 +477,26 @@ class ForLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('s',),
opts={})
self.assertEqual(self.evaluate(s), (12,))
self.assertEqual(s, (12,))
def test_tf_ragged_tensor_no_loop_vars(self):
v = variables.Variable(0, dtype=dtypes.int32)
self.evaluate(v.initializer)
v = self.variable('v', 0, dtypes.int32)
def body(i):
v.assign(v.read_value() * 10 + i[0])
# tf.function required for the automatic control dependencies.
@def_function.function(autograph=False)
def test_fn():
control_flow.for_stmt(
ragged_factory_ops.constant([[1], [2, 4], [3]]),
extra_test=None,
body=body,
get_state=lambda: (),
set_state=lambda _: None,
symbol_names=(),
opts={})
control_flow.for_stmt(
ragged_factory_ops.constant([[1], [2, 4], [3]]),
extra_test=None,
body=body,
get_state=lambda: (),
set_state=lambda _: None,
symbol_names=(),
opts={})
self.evaluate(test_fn())
# Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row).
self.assertEqual(self.evaluate(v.read_value()), 123)
self.assertEqual(v.read_value(), 123)
def _basic_loop(self, init_value, body_fn):
def body(i):
@ -561,8 +537,7 @@ class ForLoopTest(test.TestCase):
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
@test_util.run_all_in_graph_and_eager_modes
class WhileLoopTest(test.TestCase):
class WhileLoopTest(testing.AutoGraphTestCase):
def test_tensor(self):
def body():
@ -584,40 +559,36 @@ class WhileLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('i', 's'),
opts={})
self.assertEqual(self.evaluate((i, s)), (5, 1234))
self.assertEqual(i, 5)
self.assertEqual(s, 1234)
def test_tensor_with_side_effecting_condition(self):
v = variables.Variable(0)
v = self.variable('v', 0, dtypes.int32)
# tf.function required for the automatic control dependencies.
@def_function.function
def test_fn():
def cond():
v.assign(v.read_value() * 10 + i)
return i < n
def cond():
v.assign(v.read_value() * 10 + i)
return i < n
def body():
nonlocal i
i += 1
def body():
nonlocal i
i += 1
def set_state(loop_vars):
nonlocal i
i, = loop_vars
def set_state(loop_vars):
nonlocal i
i, = loop_vars
i = 0
n = constant_op.constant(5)
control_flow.while_stmt(
test=cond,
body=body,
get_state=lambda: (i,),
set_state=set_state,
symbol_names=('i',),
opts={})
return i
i = 0
n = constant_op.constant(5)
control_flow.while_stmt(
test=cond,
body=body,
get_state=lambda: (i,),
set_state=set_state,
symbol_names=('i',),
opts={})
self.evaluate(v.initializer)
self.assertEqual(self.evaluate(test_fn()), (5,))
self.assertEqual(self.evaluate(v), (12345,))
self.assertEqual(i, (5,))
self.assertEqual(v, (12345,))
def test_tensor_with_python_state(self):
class MutableObject(object):
@ -642,7 +613,8 @@ class WhileLoopTest(test.TestCase):
set_state=set_state,
symbol_names=('i', 'state.field'),
opts={})
self.assertEqual(self.evaluate((i, state.field)), (5, 1234))
self.assertEqual(i, 5)
self.assertEqual(state.field, 1234)
def test_python(self):
def body():
@ -679,7 +651,7 @@ class WhileLoopTest(test.TestCase):
symbol_names=('i', 's'),
opts={})
self.assertEqual(i, 5)
self.assertEqual(self.evaluate(s), 1234)
self.assertEqual(s, 1234)
def test_python_while_infinite(self):
if not __debug__:
@ -800,8 +772,7 @@ class WhileLoopTest(test.TestCase):
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
@test_util.run_all_in_graph_and_eager_modes
class IfStmtTest(test.TestCase):
class IfStmtTest(testing.AutoGraphTestCase):
def test_tensor(self):
@ -829,8 +800,8 @@ class IfStmtTest(test.TestCase):
nouts=1)
return i
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
self.assertEqual(test_fn(constant_op.constant(True)), 1)
self.assertEqual(test_fn(constant_op.constant(False)), -1)
def test_tensor_no_outputs(self):
@ -858,8 +829,8 @@ class IfStmtTest(test.TestCase):
nouts=0)
return i
self.assertEqual(None, test_fn(constant_op.constant(True)))
self.assertEqual(None, test_fn(constant_op.constant(False)))
self.assertIsNone(test_fn(constant_op.constant(True)))
self.assertIsNone(test_fn(constant_op.constant(False)))
def test_tensor_multiple_returns(self):
@ -889,9 +860,8 @@ class IfStmtTest(test.TestCase):
nouts=2)
return i, j
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
self.assertEqual((-1, -2),
self.evaluate(test_fn(constant_op.constant(False))))
self.assertEqual(test_fn(constant_op.constant(True)), (1, 2))
self.assertEqual(test_fn(constant_op.constant(False)), (-1, -2))
def test_python(self):
@ -915,8 +885,8 @@ class IfStmtTest(test.TestCase):
nouts=1)
return i
self.assertEqual(1, test_fn(True))
self.assertEqual(-1, test_fn(False))
self.assertEqual(test_fn(True), 1)
self.assertEqual(test_fn(False), -1)
def test_python_multiple_returns(self):
@ -942,8 +912,8 @@ class IfStmtTest(test.TestCase):
nouts=2)
return i, j
self.assertEqual((1, 2), test_fn(True))
self.assertEqual((-1, -2), test_fn(False))
self.assertEqual(test_fn(True), (1, 2))
self.assertEqual(test_fn(False), (-1, -2))
def _basic_cond(self, body_fn, else_fn):
def body():
@ -959,16 +929,14 @@ class IfStmtTest(test.TestCase):
x, = cond_vars
x = 0
# Eager cond had different semantics, we don't test those here.
with func_graph.FuncGraph('tmp').as_default():
control_flow.if_stmt(
cond=constant_op.constant(True),
body=body,
orelse=orelse,
get_state=lambda: (x,),
set_state=set_state,
symbol_names=('x',),
nouts=1)
control_flow.if_stmt(
cond=constant_op.constant(True),
body=body,
orelse=orelse,
get_state=lambda: (x,),
set_state=set_state,
symbol_names=('x',),
nouts=1)
return x
def test_tensor_none_output(self):

View File

@ -22,4 +22,3 @@ from tensorflow.python.autograph.utils.context_managers import control_dependenc
from tensorflow.python.autograph.utils.misc import alias_tensors
from tensorflow.python.autograph.utils.py_func import wrap_py_func
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
from tensorflow.python.autograph.utils.testing import fake_tf

View File

@ -18,20 +18,82 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import imp
import types
import unittest
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def fake_tf():
"""Creates a fake module that looks like TensorFlow, for testing."""
mod = imp.new_module('tensorflow')
mod_contents = {}
mod_contents.update(gen_math_ops.__dict__)
mod_contents.update(math_ops.__dict__)
mod_contents.update(ops.__dict__)
mod_contents.update(mod.__dict__)
mod.__dict__.update(mod_contents)
return mod
class AutoGraphTestCase(test.TestCase):
"""Tests specialized for AutoGraph, which run as tf.functions.
These tests use a staged programming-like approach: most of the test code runs
as-is inside a tf.function, but the assertions are lifted outside the
function, and run with the corresponding function values instead.
For example, the test:
def test_foo(self):
baz = bar();
self.assertEqual(baz, value)
is equivalent to writing:
def test_foo(self):
@tf.function
def test_fn():
baz = bar();
return baz, value
baz_actual, value_actual = test_fn()
self.assertEqual(baz_actual, value_actual)
"""
def __new__(cls, *args):
obj = super().__new__(cls)
for name in cls.__dict__:
if not name.startswith(unittest.TestLoader.testMethodPrefix):
continue
m = getattr(obj, name)
if callable(m):
wrapper = obj._run_as_tf_function(m)
setattr(obj, name, types.MethodType(wrapper, obj))
return obj
def _run_as_tf_function(self, fn):
def wrapper(self):
@def_function.function(autograph=False) # Testing autograph itself.
def fn_wrapper():
self.assertions = []
fn()
targets = [args for _, args in self.assertions]
return targets
actuals = self.evaluate(fn_wrapper())
for (_, args), value in zip(self.assertions, actuals):
args[:] = value
return wrapper
def variable(self, name, value, dtype):
with ops.init_scope():
if name not in self.variables:
self.variables[name] = variables.Variable(value, dtype=dtype)
self.evaluate(self.variables[name].initializer)
return self.variables[name]
def setUp(self):
super().setUp()
self.variables = {}
def tearDown(self):
for fn, args in self.assertions:
fn(*args)
super().tearDown()
def assertEqual(self, *args):
self.assertions.append((super().assertEqual, list(args)))