Internal cleanup: set up boilerplate for tests which require running inside a tf.function. The boilerplate helps keep the tests clean, although its mechanism may be tricky.

PiperOrigin-RevId: 322444035
Change-Id: I56bbb1e1eb2feb65e16b8171d4a1bd20e1b190a5
This commit is contained in:
Dan Moldovan 2020-07-21 14:33:30 -07:00 committed by TensorFlower Gardener
parent cdc531263a
commit be88c5f8a7
3 changed files with 202 additions and 173 deletions

View File

@ -31,26 +31,22 @@ import six
from tensorflow.python.autograph.operators import control_flow from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.autograph.operators import variables as variable_operators from tensorflow.python.autograph.operators import variables as variable_operators
from tensorflow.python.autograph.utils import ag_logging from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import testing
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class ForLoopTest(testing.AutoGraphTestCase):
class ForLoopTest(test.TestCase):
def test_tensor(self): def test_tensor(self):
def body(i): def body(i):
@ -70,7 +66,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={}) opts={})
self.assertEqual(self.evaluate(s), (1234,)) self.assertEqual(s, (1234,))
def test_range_tensor(self): def test_range_tensor(self):
def body(i): def body(i):
@ -90,7 +86,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={'iterate_names': 'i'}) opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (1234,)) self.assertEqual(s, (1234,))
def test_range_tensor_explicit_limit_delta(self): def test_range_tensor_explicit_limit_delta(self):
def body(i): def body(i):
@ -110,7 +106,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={'iterate_names': 'i'}) opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (-171207,)) self.assertEqual(s, (-171207,))
def test_range_tensor_explicit_limit_negative_delta(self): def test_range_tensor_explicit_limit_negative_delta(self):
def body(i): def body(i):
@ -130,7 +126,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={'iterate_names': 'i'}) opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (171207,)) self.assertEqual(s, (171207,))
def test_range_tensor_random_delta(self): def test_range_tensor_random_delta(self):
def body(i): def body(i):
@ -151,7 +147,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={'iterate_names': 'i'}) opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (1234,)) self.assertEqual(s, (1234,))
def test_range_tensor_random_negative_delta(self): def test_range_tensor_random_negative_delta(self):
def body(i): def body(i):
@ -172,7 +168,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={'iterate_names': 'i'}) opts={'iterate_names': 'i'})
self.assertEqual(self.evaluate(s), (171207,)) self.assertEqual(s, (171207,))
def test_tensor_with_extra_test_object_vars(self): def test_tensor_with_extra_test_object_vars(self):
class MutableObject(object): class MutableObject(object):
@ -198,7 +194,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('state.field_1', 'state.field_2'), symbol_names=('state.field_1', 'state.field_2'),
opts={}) opts={})
self.assertEqual(self.evaluate((state.field_1, state.field_2)), (6, 6)) self.assertEqual((state.field_1, state.field_2), (6, 6))
def test_python(self): def test_python(self):
def body(i): def body(i):
@ -297,7 +293,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={}) opts={})
self.assertEqual(self.evaluate(s), (1234,)) self.assertEqual(s, (1234,))
def test_dataset_with_extra_test(self): def test_dataset_with_extra_test(self):
def body(i): def body(i):
@ -317,7 +313,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={}) opts={})
self.assertEqual(self.evaluate(s), (12,)) self.assertEqual(s, (12,))
def test_dataset_with_extra_test_collection_vars(self): def test_dataset_with_extra_test_collection_vars(self):
def body(i): def body(i):
@ -339,7 +335,7 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('l[0]', 's'), symbol_names=('l[0]', 's'),
opts={}) opts={})
self.assertEqual(self.evaluate((l[0], s)), (3, 3)) self.assertEqual((l[0], s), (3, 3))
def test_dataset_with_extra_test_iteration_limiting(self): def test_dataset_with_extra_test_iteration_limiting(self):
def body(it): def body(it):
@ -360,100 +356,83 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('i',), symbol_names=('i',),
opts={}) opts={})
self.assertEqual(self.evaluate(i), (3,)) self.assertEqual(i, (3,))
def test_tf_dataset_no_loop_vars(self): def test_tf_dataset_no_loop_vars(self):
def body(i): def body(i):
v.assign(v.read_value() * 10 + i) v.assign(v.read_value() * 10 + i)
v = variables.Variable(0, dtype=dtypes.int64) v = self.variable('v', 0, dtypes.int64)
self.evaluate(v.initializer)
# tf.function required for the automatic control dependencies, and because control_flow.for_stmt(
# ops test for its presence. dataset_ops.Dataset.range(5),
@def_function.function extra_test=None,
def test_fn(): body=body,
control_flow.for_stmt( get_state=lambda: (),
dataset_ops.Dataset.range(5), set_state=lambda _: None,
extra_test=None, symbol_names=(),
body=body, opts={})
get_state=lambda: (),
set_state=lambda _: None,
symbol_names=(),
opts={})
self.evaluate(test_fn()) self.assertEqual(v.read_value(), 1234)
self.assertEqual(self.evaluate(v.read_value()), 1234)
def test_tf_iterator(self): def test_tf_iterator(self):
# graph-mode iterators are only supported inside tf.function. def body(i):
@def_function.function nonlocal s
def test_fn(): s = s * 10 + i
def body(i):
nonlocal s
s = s * 10 + i
def set_state(loop_vars): def set_state(loop_vars):
nonlocal s nonlocal s
s, = loop_vars s, = loop_vars
s = constant_op.constant(0, dtype=dtypes.int64) s = constant_op.constant(0, dtype=dtypes.int64)
control_flow.for_stmt( control_flow.for_stmt(
iter(dataset_ops.Dataset.range(5)), iter(dataset_ops.Dataset.range(5)),
extra_test=None, extra_test=None,
body=body, body=body,
get_state=lambda: (s,), get_state=lambda: (s,),
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={}) opts={})
return s
self.assertAllEqual(test_fn(), 1234) self.assertEqual(s, 1234)
def test_tf_iterator_shape_invariants(self): def test_tf_iterator_shape_invariants(self):
# graph-mode iterators are only supported inside tf.function. def body(i):
@def_function.function nonlocal s
def test_fn(): s = array_ops.concat([s, [i]], 0)
def body(i):
nonlocal s
s = array_ops.concat([s, [i]], 0)
def set_state(loop_vars): def set_state(loop_vars):
nonlocal s nonlocal s
s, = loop_vars s, = loop_vars
s = constant_op.constant([], dtype=dtypes.int64) s = constant_op.constant([], dtype=dtypes.int64)
control_flow.for_stmt( control_flow.for_stmt(
iter(dataset_ops.Dataset.range(5)), iter(dataset_ops.Dataset.range(5)),
extra_test=None, extra_test=None,
body=body, body=body,
get_state=lambda: (s,), get_state=lambda: (s,),
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]}) opts={'shape_invariants': [(s, tensor_shape.TensorShape([None]))]})
return s
self.assertAllEqual(test_fn(), [0, 1, 2, 3, 4]) self.assertAllEqual(s, [0, 1, 2, 3, 4])
def test_tf_iterator_no_loop_vars(self): def test_tf_iterator_no_loop_vars(self):
def body(i): def body(i):
v.assign(v.read_value() * 10 + i) v.assign(v.read_value() * 10 + i)
v = variables.Variable(0, dtype=dtypes.int64) v = self.variable('v', 0, dtypes.int64)
self.evaluate(v.initializer)
# tf.function required for the automatic control dependencies. control_flow.for_stmt(
@def_function.function iter(dataset_ops.Dataset.range(5)),
def test_fn(): extra_test=None,
control_flow.for_stmt( body=body,
iter(dataset_ops.Dataset.range(5)), get_state=lambda: (),
extra_test=None, set_state=lambda _: None,
body=body, symbol_names=(),
get_state=lambda: (), opts={})
set_state=lambda _: None,
symbol_names=(),
opts={})
self.evaluate(test_fn()) self.assertEqual(v.read_value(), 1234)
self.assertEqual(self.evaluate(v.read_value()), 1234)
def test_tf_ragged_tensor(self): def test_tf_ragged_tensor(self):
def body(i): def body(i):
@ -473,7 +452,8 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={}) opts={})
self.assertEqual(self.evaluate(s), (123,))
self.assertEqual(s, (123,))
def test_tf_ragged_tensor_higher_dimensional(self): def test_tf_ragged_tensor_higher_dimensional(self):
def body(i): def body(i):
@ -497,30 +477,26 @@ class ForLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('s',), symbol_names=('s',),
opts={}) opts={})
self.assertEqual(self.evaluate(s), (12,))
self.assertEqual(s, (12,))
def test_tf_ragged_tensor_no_loop_vars(self): def test_tf_ragged_tensor_no_loop_vars(self):
v = variables.Variable(0, dtype=dtypes.int32) v = self.variable('v', 0, dtypes.int32)
self.evaluate(v.initializer)
def body(i): def body(i):
v.assign(v.read_value() * 10 + i[0]) v.assign(v.read_value() * 10 + i[0])
# tf.function required for the automatic control dependencies. control_flow.for_stmt(
@def_function.function(autograph=False) ragged_factory_ops.constant([[1], [2, 4], [3]]),
def test_fn(): extra_test=None,
control_flow.for_stmt( body=body,
ragged_factory_ops.constant([[1], [2, 4], [3]]), get_state=lambda: (),
extra_test=None, set_state=lambda _: None,
body=body, symbol_names=(),
get_state=lambda: (), opts={})
set_state=lambda _: None,
symbol_names=(),
opts={})
self.evaluate(test_fn())
# Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row). # Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row).
self.assertEqual(self.evaluate(v.read_value()), 123) self.assertEqual(v.read_value(), 123)
def _basic_loop(self, init_value, body_fn): def _basic_loop(self, init_value, body_fn):
def body(i): def body(i):
@ -561,8 +537,7 @@ class ForLoopTest(test.TestCase):
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32)) self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
@test_util.run_all_in_graph_and_eager_modes class WhileLoopTest(testing.AutoGraphTestCase):
class WhileLoopTest(test.TestCase):
def test_tensor(self): def test_tensor(self):
def body(): def body():
@ -584,40 +559,36 @@ class WhileLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('i', 's'), symbol_names=('i', 's'),
opts={}) opts={})
self.assertEqual(self.evaluate((i, s)), (5, 1234)) self.assertEqual(i, 5)
self.assertEqual(s, 1234)
def test_tensor_with_side_effecting_condition(self): def test_tensor_with_side_effecting_condition(self):
v = variables.Variable(0) v = self.variable('v', 0, dtypes.int32)
# tf.function required for the automatic control dependencies. def cond():
@def_function.function v.assign(v.read_value() * 10 + i)
def test_fn(): return i < n
def cond():
v.assign(v.read_value() * 10 + i)
return i < n
def body(): def body():
nonlocal i nonlocal i
i += 1 i += 1
def set_state(loop_vars): def set_state(loop_vars):
nonlocal i nonlocal i
i, = loop_vars i, = loop_vars
i = 0 i = 0
n = constant_op.constant(5) n = constant_op.constant(5)
control_flow.while_stmt( control_flow.while_stmt(
test=cond, test=cond,
body=body, body=body,
get_state=lambda: (i,), get_state=lambda: (i,),
set_state=set_state, set_state=set_state,
symbol_names=('i',), symbol_names=('i',),
opts={}) opts={})
return i
self.evaluate(v.initializer) self.assertEqual(i, (5,))
self.assertEqual(self.evaluate(test_fn()), (5,)) self.assertEqual(v, (12345,))
self.assertEqual(self.evaluate(v), (12345,))
def test_tensor_with_python_state(self): def test_tensor_with_python_state(self):
class MutableObject(object): class MutableObject(object):
@ -642,7 +613,8 @@ class WhileLoopTest(test.TestCase):
set_state=set_state, set_state=set_state,
symbol_names=('i', 'state.field'), symbol_names=('i', 'state.field'),
opts={}) opts={})
self.assertEqual(self.evaluate((i, state.field)), (5, 1234)) self.assertEqual(i, 5)
self.assertEqual(state.field, 1234)
def test_python(self): def test_python(self):
def body(): def body():
@ -679,7 +651,7 @@ class WhileLoopTest(test.TestCase):
symbol_names=('i', 's'), symbol_names=('i', 's'),
opts={}) opts={})
self.assertEqual(i, 5) self.assertEqual(i, 5)
self.assertEqual(self.evaluate(s), 1234) self.assertEqual(s, 1234)
def test_python_while_infinite(self): def test_python_while_infinite(self):
if not __debug__: if not __debug__:
@ -800,8 +772,7 @@ class WhileLoopTest(test.TestCase):
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32)) self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
@test_util.run_all_in_graph_and_eager_modes class IfStmtTest(testing.AutoGraphTestCase):
class IfStmtTest(test.TestCase):
def test_tensor(self): def test_tensor(self):
@ -829,8 +800,8 @@ class IfStmtTest(test.TestCase):
nouts=1) nouts=1)
return i return i
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True)))) self.assertEqual(test_fn(constant_op.constant(True)), 1)
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False)))) self.assertEqual(test_fn(constant_op.constant(False)), -1)
def test_tensor_no_outputs(self): def test_tensor_no_outputs(self):
@ -858,8 +829,8 @@ class IfStmtTest(test.TestCase):
nouts=0) nouts=0)
return i return i
self.assertEqual(None, test_fn(constant_op.constant(True))) self.assertIsNone(test_fn(constant_op.constant(True)))
self.assertEqual(None, test_fn(constant_op.constant(False))) self.assertIsNone(test_fn(constant_op.constant(False)))
def test_tensor_multiple_returns(self): def test_tensor_multiple_returns(self):
@ -889,9 +860,8 @@ class IfStmtTest(test.TestCase):
nouts=2) nouts=2)
return i, j return i, j
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True)))) self.assertEqual(test_fn(constant_op.constant(True)), (1, 2))
self.assertEqual((-1, -2), self.assertEqual(test_fn(constant_op.constant(False)), (-1, -2))
self.evaluate(test_fn(constant_op.constant(False))))
def test_python(self): def test_python(self):
@ -915,8 +885,8 @@ class IfStmtTest(test.TestCase):
nouts=1) nouts=1)
return i return i
self.assertEqual(1, test_fn(True)) self.assertEqual(test_fn(True), 1)
self.assertEqual(-1, test_fn(False)) self.assertEqual(test_fn(False), -1)
def test_python_multiple_returns(self): def test_python_multiple_returns(self):
@ -942,8 +912,8 @@ class IfStmtTest(test.TestCase):
nouts=2) nouts=2)
return i, j return i, j
self.assertEqual((1, 2), test_fn(True)) self.assertEqual(test_fn(True), (1, 2))
self.assertEqual((-1, -2), test_fn(False)) self.assertEqual(test_fn(False), (-1, -2))
def _basic_cond(self, body_fn, else_fn): def _basic_cond(self, body_fn, else_fn):
def body(): def body():
@ -959,16 +929,14 @@ class IfStmtTest(test.TestCase):
x, = cond_vars x, = cond_vars
x = 0 x = 0
# Eager cond had different semantics, we don't test those here. control_flow.if_stmt(
with func_graph.FuncGraph('tmp').as_default(): cond=constant_op.constant(True),
control_flow.if_stmt( body=body,
cond=constant_op.constant(True), orelse=orelse,
body=body, get_state=lambda: (x,),
orelse=orelse, set_state=set_state,
get_state=lambda: (x,), symbol_names=('x',),
set_state=set_state, nouts=1)
symbol_names=('x',),
nouts=1)
return x return x
def test_tensor_none_output(self): def test_tensor_none_output(self):

View File

@ -22,4 +22,3 @@ from tensorflow.python.autograph.utils.context_managers import control_dependenc
from tensorflow.python.autograph.utils.misc import alias_tensors from tensorflow.python.autograph.utils.misc import alias_tensors
from tensorflow.python.autograph.utils.py_func import wrap_py_func from tensorflow.python.autograph.utils.py_func import wrap_py_func
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
from tensorflow.python.autograph.utils.testing import fake_tf

View File

@ -18,20 +18,82 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import imp import types
import unittest
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import variables
from tensorflow.python.ops import math_ops from tensorflow.python.platform import test
def fake_tf(): class AutoGraphTestCase(test.TestCase):
"""Creates a fake module that looks like TensorFlow, for testing.""" """Tests specialized for AutoGraph, which run as tf.functions.
mod = imp.new_module('tensorflow')
mod_contents = {} These tests use a staged programming-like approach: most of the test code runs
mod_contents.update(gen_math_ops.__dict__) as-is inside a tf.function, but the assertions are lifted outside the
mod_contents.update(math_ops.__dict__) function, and run with the corresponding function values instead.
mod_contents.update(ops.__dict__)
mod_contents.update(mod.__dict__) For example, the test:
mod.__dict__.update(mod_contents)
return mod def test_foo(self):
baz = bar();
self.assertEqual(baz, value)
is equivalent to writing:
def test_foo(self):
@tf.function
def test_fn():
baz = bar();
return baz, value
baz_actual, value_actual = test_fn()
self.assertEqual(baz_actual, value_actual)
"""
def __new__(cls, *args):
obj = super().__new__(cls)
for name in cls.__dict__:
if not name.startswith(unittest.TestLoader.testMethodPrefix):
continue
m = getattr(obj, name)
if callable(m):
wrapper = obj._run_as_tf_function(m)
setattr(obj, name, types.MethodType(wrapper, obj))
return obj
def _run_as_tf_function(self, fn):
def wrapper(self):
@def_function.function(autograph=False) # Testing autograph itself.
def fn_wrapper():
self.assertions = []
fn()
targets = [args for _, args in self.assertions]
return targets
actuals = self.evaluate(fn_wrapper())
for (_, args), value in zip(self.assertions, actuals):
args[:] = value
return wrapper
def variable(self, name, value, dtype):
with ops.init_scope():
if name not in self.variables:
self.variables[name] = variables.Variable(value, dtype=dtype)
self.evaluate(self.variables[name].initializer)
return self.variables[name]
def setUp(self):
super().setUp()
self.variables = {}
def tearDown(self):
for fn, args in self.assertions:
fn(*args)
super().tearDown()
def assertEqual(self, *args):
self.assertions.append((super().assertEqual, list(args)))