Clean up tests so that they run consistently across versions and modes.

PiperOrigin-RevId: 281583150
Change-Id: Ifb5c245473d92174b39d707c7c693f5242e76d7e
This commit is contained in:
Dan Moldovan 2019-11-20 13:02:51 -08:00 committed by TensorFlower Gardener
parent f7a7c799ab
commit fcc86274b5

View File

@ -39,79 +39,70 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ForLoopTest(test.TestCase): class ForLoopTest(test.TestCase):
def test_tensor(self): def test_tensor(self):
with ops.Graph().as_default(): s = control_flow.for_stmt(
s = control_flow.for_stmt( constant_op.constant([1, 2, 3, 4]),
constant_op.constant([1, 2, 3, 4]), extra_test=lambda s: True,
extra_test=lambda s: True, body=lambda i, s: (s * 10 + i,),
body=lambda i, s: (s * 10 + i,), get_state=lambda: (),
get_state=lambda: (), set_state=lambda _: None,
set_state=lambda _: None, init_vars=(0,))
init_vars=(0,)) self.assertEqual(self.evaluate(s), (1234,))
self.assertEqual(self.evaluate(s), (1234,))
def test_range_tensor(self): def test_range_tensor(self):
with ops.Graph().as_default(): s = control_flow.for_stmt(
s = control_flow.for_stmt( math_ops.range(5),
math_ops.range(5), extra_test=lambda s: True,
extra_test=lambda s: True, body=lambda i, s: (s * 10 + i,),
body=lambda i, s: (s * 10 + i,), get_state=lambda: (),
get_state=lambda: (), set_state=lambda _: None,
set_state=lambda _: None, init_vars=(0,))
init_vars=(0,)) self.assertEqual(self.evaluate(s), (1234,))
self.assertEqual(self.evaluate(s), (1234,))
def test_range_tensor_random_delta(self): def test_range_tensor_random_delta(self):
random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
with ops.Graph().as_default(): s = control_flow.for_stmt(
random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32) math_ops.range(0, 5, random_one),
s = control_flow.for_stmt( extra_test=lambda s: True,
math_ops.range(0, 5, random_one), body=lambda i, s: (s * 10 + i,),
extra_test=lambda s: True, get_state=lambda: (),
body=lambda i, s: (s * 10 + i,), set_state=lambda _: None,
get_state=lambda: (), init_vars=(0,))
set_state=lambda _: None, self.assertEqual(self.evaluate(s), (1234,))
init_vars=(0,))
self.assertEqual(self.evaluate(s), (1234,))
def test_range_tensor_explicit_limit_delta(self): def test_range_tensor_explicit_limit_delta(self):
with ops.Graph().as_default(): s = control_flow.for_stmt(
s = control_flow.for_stmt( math_ops.range(-17, -3, 5),
math_ops.range(-17, -3, 5), extra_test=lambda s: True,
extra_test=lambda s: True, body=lambda i, s: (s * 100 + i,),
body=lambda i, s: (s * 100 + i,), get_state=lambda: (),
get_state=lambda: (), set_state=lambda _: None,
set_state=lambda _: None, init_vars=(0,))
init_vars=(0,)) self.assertEqual(self.evaluate(s), (-171207,))
self.assertEqual(self.evaluate(s), (-171207,))
def test_range_tensor_random_negative_delta(self): def test_range_tensor_random_negative_delta(self):
with ops.Graph().as_default(): random_neg_five = random_ops.random_uniform((), -5, -4, dtype=dtypes.int32)
random_neg_five = random_ops.random_uniform((), s = control_flow.for_stmt(
-5, math_ops.range(17, 3, random_neg_five),
-4, extra_test=lambda s: True,
dtype=dtypes.int32) body=lambda i, s: (s * 100 + i,),
s = control_flow.for_stmt( get_state=lambda: (),
math_ops.range(17, 3, random_neg_five), set_state=lambda _: None,
extra_test=lambda s: True, init_vars=(0,))
body=lambda i, s: (s * 100 + i,), self.assertEqual(self.evaluate(s), (171207,))
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (171207,))
def test_range_tensor_negative_delta(self): def test_range_tensor_negative_delta(self):
with ops.Graph().as_default(): s = control_flow.for_stmt(
s = control_flow.for_stmt( math_ops.range(17, 3, -5),
math_ops.range(17, 3, -5), extra_test=lambda s: True,
extra_test=lambda s: True, body=lambda i, s: (s * 100 + i,),
body=lambda i, s: (s * 100 + i,), get_state=lambda: (),
get_state=lambda: (), set_state=lambda _: None,
set_state=lambda _: None, init_vars=(0,))
init_vars=(0,)) self.assertEqual(self.evaluate(s), (171207,))
self.assertEqual(self.evaluate(s), (171207,))
def test_tensor_with_extra_test_only_python_state(self): def test_tensor_with_extra_test_only_python_state(self):
class MutableObject(object): class MutableObject(object):
@ -151,15 +142,14 @@ class ForLoopTest(test.TestCase):
self.assertEqual(s, (1234,)) self.assertEqual(s, (1234,))
def test_tf_dataset(self): def test_tf_dataset(self):
with ops.Graph().as_default(): s = control_flow.for_stmt(
s = control_flow.for_stmt( dataset_ops.Dataset.range(5),
dataset_ops.Dataset.range(5), extra_test=None,
extra_test=None, body=lambda i, s: (s * 10 + i,),
body=lambda i, s: (s * 10 + i,), get_state=lambda: (),
get_state=lambda: (), set_state=lambda _: None,
set_state=lambda _: None, init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) self.assertEqual(self.evaluate(s), (1234,))
self.assertEqual(self.evaluate(s), (1234,))
def test_dataset_with_extra_test(self): def test_dataset_with_extra_test(self):
s = control_flow.for_stmt( s = control_flow.for_stmt(
@ -209,7 +199,6 @@ class ForLoopTest(test.TestCase):
init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
self.assertEqual(self.evaluate(s), (3,)) self.assertEqual(self.evaluate(s), (3,))
@test_util.run_v2_only
def test_tf_dataset_no_loop_vars(self): def test_tf_dataset_no_loop_vars(self):
v = variables.Variable(0, dtype=dtypes.int64) v = variables.Variable(0, dtype=dtypes.int64)
self.evaluate(v.initializer) self.evaluate(v.initializer)
@ -217,7 +206,8 @@ class ForLoopTest(test.TestCase):
def stateless_with_side_effects(i): def stateless_with_side_effects(i):
v.assign(v.read_value() * 10 + i) v.assign(v.read_value() * 10 + i)
# function is important here, because ops test for its presence. # tf.function required for the automatic control dependencies, and because
# ops test for its presence.
@def_function.function(autograph=False) @def_function.function(autograph=False)
def test_fn(): def test_fn():
control_flow.for_stmt( control_flow.for_stmt(
@ -228,7 +218,7 @@ class ForLoopTest(test.TestCase):
set_state=lambda _: None, set_state=lambda _: None,
init_vars=()) init_vars=())
test_fn() self.evaluate(test_fn())
self.assertEqual(self.evaluate(v.read_value()), 1234) self.assertEqual(self.evaluate(v.read_value()), 1234)
def test_tf_iterator(self): def test_tf_iterator(self):
@ -246,14 +236,14 @@ class ForLoopTest(test.TestCase):
s, = test_fn() s, = test_fn()
self.assertAllEqual(s, 1234) self.assertAllEqual(s, 1234)
@test_util.run_v2_only
def test_tf_iterator_no_loop_vars(self): def test_tf_iterator_no_loop_vars(self):
v = variables.Variable(0, dtype=dtypes.int64) v = variables.Variable(0, dtype=dtypes.int64)
self.evaluate(v.initializer)
def stateless_with_side_effects(i): def stateless_with_side_effects(i):
v.assign(v.read_value() * 10 + i) v.assign(v.read_value() * 10 + i)
# graph-mode iterators are only supported inside tf.function. # tf.function required for the automatic control dependencies.
@def_function.function(autograph=False) @def_function.function(autograph=False)
def test_fn(): def test_fn():
control_flow.for_stmt( control_flow.for_stmt(
@ -264,13 +254,13 @@ class ForLoopTest(test.TestCase):
set_state=lambda _: None, set_state=lambda _: None,
init_vars=()) init_vars=())
test_fn() self.evaluate(test_fn())
self.assertEqual(self.evaluate(v.read_value()), 1234) self.assertEqual(self.evaluate(v.read_value()), 1234)
@test_util.run_all_in_graph_and_eager_modes
class WhileLoopTest(test.TestCase): class WhileLoopTest(test.TestCase):
@test_util.run_deprecated_v1
def test_tensor(self): def test_tensor(self):
n = constant_op.constant(5) n = constant_op.constant(5)
results = control_flow.while_stmt( results = control_flow.while_stmt(
@ -282,7 +272,6 @@ class WhileLoopTest(test.TestCase):
self.assertEqual((5, 10), self.evaluate(results)) self.assertEqual((5, 10), self.evaluate(results))
def test_tensor_with_tf_side_effects_in_cond(self): def test_tensor_with_tf_side_effects_in_cond(self):
n = constant_op.constant(5, dtype=dtypes.int64) n = constant_op.constant(5, dtype=dtypes.int64)
v = variables.Variable(0, dtype=dtypes.int64) v = variables.Variable(0, dtype=dtypes.int64)
@ -290,7 +279,7 @@ class WhileLoopTest(test.TestCase):
v.assign(v.read_value() + 1) v.assign(v.read_value() + 1)
return v.read_value() return v.read_value()
# function is important here, because ops test for its presence. # tf.function required for the automatic control dependencies.
@def_function.function(autograph=False) @def_function.function(autograph=False)
def test_fn(): def test_fn():
return control_flow.while_stmt( return control_flow.while_stmt(
@ -332,7 +321,6 @@ class WhileLoopTest(test.TestCase):
self.assertEqual(self.evaluate(s), (5, 10)) self.assertEqual(self.evaluate(s), (5, 10))
self.assertEqual(self.evaluate(state.field), 10) self.assertEqual(self.evaluate(state.field), 10)
@test_util.run_deprecated_v1
def test_python_with_tensor_state(self): def test_python_with_tensor_state(self):
n = 5 n = 5
results = control_flow.while_stmt( results = control_flow.while_stmt(
@ -386,47 +374,61 @@ class WhileLoopTest(test.TestCase):
out_capturer.getvalue())) out_capturer.getvalue()))
@test_util.run_all_in_graph_and_eager_modes
class IfStmtTest(test.TestCase): class IfStmtTest(test.TestCase):
def single_return_if_stmt(self, cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: 1,
orelse=lambda: -1,
get_state=lambda: (),
set_state=lambda _: None)
def multi_return_if_stmt(self, cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: (1, 2),
orelse=lambda: (-1, -2),
get_state=lambda: (),
set_state=lambda _: None)
@test_util.run_deprecated_v1
def test_tensor(self): def test_tensor(self):
with self.cached_session():
t = self.single_return_if_stmt(constant_op.constant(True)) def test_fn(cond):
self.assertEqual(1, self.evaluate(t)) return control_flow.if_stmt(
t = self.single_return_if_stmt(constant_op.constant(False)) cond=cond,
self.assertEqual(-1, self.evaluate(t)) body=lambda: constant_op.constant(1),
orelse=lambda: constant_op.constant(-1),
get_state=lambda: (),
set_state=lambda _: None)
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
def test_tensor_multiple_returns(self):
def test_fn(cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: (constant_op.constant(1), constant_op.constant(2)),
orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)),
get_state=lambda: (),
set_state=lambda _: None)
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
self.assertEqual((-1, -2),
self.evaluate(test_fn(constant_op.constant(False))))
def test_python(self): def test_python(self):
self.assertEqual(1, self.single_return_if_stmt(True))
self.assertEqual(-1, self.single_return_if_stmt(False))
@test_util.run_deprecated_v1 def test_fn(cond):
def test_tensor_multiple_returns(self): return control_flow.if_stmt(
with self.cached_session(): cond=cond,
t = self.multi_return_if_stmt(constant_op.constant(True)) body=lambda: 1,
self.assertAllEqual([1, 2], self.evaluate(t)) orelse=lambda: -1,
t = self.multi_return_if_stmt(constant_op.constant(False)) get_state=lambda: (),
self.assertAllEqual([-1, -2], self.evaluate(t)) set_state=lambda _: None)
self.assertEqual(1, test_fn(True))
self.assertEqual(-1, test_fn(False))
def test_python_multiple_returns(self): def test_python_multiple_returns(self):
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
self.assertEqual((-1, -2), self.multi_return_if_stmt(False)) def test_fn(cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: (1, 2),
orelse=lambda: (-1, -2),
get_state=lambda: (),
set_state=lambda _: None)
self.assertEqual((1, 2), test_fn(True))
self.assertEqual((-1, -2), test_fn(False))
if __name__ == '__main__': if __name__ == '__main__':