Clean up tests so that they run consistently across versions and modes.

PiperOrigin-RevId: 281583150
Change-Id: Ifb5c245473d92174b39d707c7c693f5242e76d7e
This commit is contained in:
Dan Moldovan 2019-11-20 13:02:51 -08:00 committed by TensorFlower Gardener
parent f7a7c799ab
commit fcc86274b5

View File

@ -39,79 +39,70 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ForLoopTest(test.TestCase):
def test_tensor(self):
with ops.Graph().as_default():
s = control_flow.for_stmt(
constant_op.constant([1, 2, 3, 4]),
extra_test=lambda s: True,
body=lambda i, s: (s * 10 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (1234,))
s = control_flow.for_stmt(
constant_op.constant([1, 2, 3, 4]),
extra_test=lambda s: True,
body=lambda i, s: (s * 10 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (1234,))
def test_range_tensor(self):
with ops.Graph().as_default():
s = control_flow.for_stmt(
math_ops.range(5),
extra_test=lambda s: True,
body=lambda i, s: (s * 10 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (1234,))
s = control_flow.for_stmt(
math_ops.range(5),
extra_test=lambda s: True,
body=lambda i, s: (s * 10 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (1234,))
def test_range_tensor_random_delta(self):
with ops.Graph().as_default():
random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
s = control_flow.for_stmt(
math_ops.range(0, 5, random_one),
extra_test=lambda s: True,
body=lambda i, s: (s * 10 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (1234,))
random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
s = control_flow.for_stmt(
math_ops.range(0, 5, random_one),
extra_test=lambda s: True,
body=lambda i, s: (s * 10 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (1234,))
def test_range_tensor_explicit_limit_delta(self):
with ops.Graph().as_default():
s = control_flow.for_stmt(
math_ops.range(-17, -3, 5),
extra_test=lambda s: True,
body=lambda i, s: (s * 100 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (-171207,))
s = control_flow.for_stmt(
math_ops.range(-17, -3, 5),
extra_test=lambda s: True,
body=lambda i, s: (s * 100 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (-171207,))
def test_range_tensor_random_negative_delta(self):
with ops.Graph().as_default():
random_neg_five = random_ops.random_uniform((),
-5,
-4,
dtype=dtypes.int32)
s = control_flow.for_stmt(
math_ops.range(17, 3, random_neg_five),
extra_test=lambda s: True,
body=lambda i, s: (s * 100 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (171207,))
random_neg_five = random_ops.random_uniform((), -5, -4, dtype=dtypes.int32)
s = control_flow.for_stmt(
math_ops.range(17, 3, random_neg_five),
extra_test=lambda s: True,
body=lambda i, s: (s * 100 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (171207,))
def test_range_tensor_negative_delta(self):
with ops.Graph().as_default():
s = control_flow.for_stmt(
math_ops.range(17, 3, -5),
extra_test=lambda s: True,
body=lambda i, s: (s * 100 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (171207,))
s = control_flow.for_stmt(
math_ops.range(17, 3, -5),
extra_test=lambda s: True,
body=lambda i, s: (s * 100 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(0,))
self.assertEqual(self.evaluate(s), (171207,))
def test_tensor_with_extra_test_only_python_state(self):
class MutableObject(object):
@ -151,15 +142,14 @@ class ForLoopTest(test.TestCase):
self.assertEqual(s, (1234,))
def test_tf_dataset(self):
with ops.Graph().as_default():
s = control_flow.for_stmt(
dataset_ops.Dataset.range(5),
extra_test=None,
body=lambda i, s: (s * 10 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
self.assertEqual(self.evaluate(s), (1234,))
s = control_flow.for_stmt(
dataset_ops.Dataset.range(5),
extra_test=None,
body=lambda i, s: (s * 10 + i,),
get_state=lambda: (),
set_state=lambda _: None,
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
self.assertEqual(self.evaluate(s), (1234,))
def test_dataset_with_extra_test(self):
s = control_flow.for_stmt(
@ -209,7 +199,6 @@ class ForLoopTest(test.TestCase):
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
self.assertEqual(self.evaluate(s), (3,))
@test_util.run_v2_only
def test_tf_dataset_no_loop_vars(self):
v = variables.Variable(0, dtype=dtypes.int64)
self.evaluate(v.initializer)
@ -217,7 +206,8 @@ class ForLoopTest(test.TestCase):
def stateless_with_side_effects(i):
v.assign(v.read_value() * 10 + i)
# function is important here, because ops test for its presence.
# tf.function required for the automatic control dependencies, and because
# ops test for its presence.
@def_function.function(autograph=False)
def test_fn():
control_flow.for_stmt(
@ -228,7 +218,7 @@ class ForLoopTest(test.TestCase):
set_state=lambda _: None,
init_vars=())
test_fn()
self.evaluate(test_fn())
self.assertEqual(self.evaluate(v.read_value()), 1234)
def test_tf_iterator(self):
@ -246,14 +236,14 @@ class ForLoopTest(test.TestCase):
s, = test_fn()
self.assertAllEqual(s, 1234)
@test_util.run_v2_only
def test_tf_iterator_no_loop_vars(self):
v = variables.Variable(0, dtype=dtypes.int64)
self.evaluate(v.initializer)
def stateless_with_side_effects(i):
v.assign(v.read_value() * 10 + i)
# graph-mode iterators are only supported inside tf.function.
# tf.function required for the automatic control dependencies.
@def_function.function(autograph=False)
def test_fn():
control_flow.for_stmt(
@ -264,13 +254,13 @@ class ForLoopTest(test.TestCase):
set_state=lambda _: None,
init_vars=())
test_fn()
self.evaluate(test_fn())
self.assertEqual(self.evaluate(v.read_value()), 1234)
@test_util.run_all_in_graph_and_eager_modes
class WhileLoopTest(test.TestCase):
@test_util.run_deprecated_v1
def test_tensor(self):
n = constant_op.constant(5)
results = control_flow.while_stmt(
@ -282,7 +272,6 @@ class WhileLoopTest(test.TestCase):
self.assertEqual((5, 10), self.evaluate(results))
def test_tensor_with_tf_side_effects_in_cond(self):
n = constant_op.constant(5, dtype=dtypes.int64)
v = variables.Variable(0, dtype=dtypes.int64)
@ -290,7 +279,7 @@ class WhileLoopTest(test.TestCase):
v.assign(v.read_value() + 1)
return v.read_value()
# function is important here, because ops test for its presence.
# tf.function required for the automatic control dependencies.
@def_function.function(autograph=False)
def test_fn():
return control_flow.while_stmt(
@ -332,7 +321,6 @@ class WhileLoopTest(test.TestCase):
self.assertEqual(self.evaluate(s), (5, 10))
self.assertEqual(self.evaluate(state.field), 10)
@test_util.run_deprecated_v1
def test_python_with_tensor_state(self):
n = 5
results = control_flow.while_stmt(
@ -386,47 +374,61 @@ class WhileLoopTest(test.TestCase):
out_capturer.getvalue()))
@test_util.run_all_in_graph_and_eager_modes
class IfStmtTest(test.TestCase):
def single_return_if_stmt(self, cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: 1,
orelse=lambda: -1,
get_state=lambda: (),
set_state=lambda _: None)
def multi_return_if_stmt(self, cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: (1, 2),
orelse=lambda: (-1, -2),
get_state=lambda: (),
set_state=lambda _: None)
@test_util.run_deprecated_v1
def test_tensor(self):
with self.cached_session():
t = self.single_return_if_stmt(constant_op.constant(True))
self.assertEqual(1, self.evaluate(t))
t = self.single_return_if_stmt(constant_op.constant(False))
self.assertEqual(-1, self.evaluate(t))
def test_fn(cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: constant_op.constant(1),
orelse=lambda: constant_op.constant(-1),
get_state=lambda: (),
set_state=lambda _: None)
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
def test_tensor_multiple_returns(self):
def test_fn(cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: (constant_op.constant(1), constant_op.constant(2)),
orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)),
get_state=lambda: (),
set_state=lambda _: None)
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
self.assertEqual((-1, -2),
self.evaluate(test_fn(constant_op.constant(False))))
def test_python(self):
self.assertEqual(1, self.single_return_if_stmt(True))
self.assertEqual(-1, self.single_return_if_stmt(False))
@test_util.run_deprecated_v1
def test_tensor_multiple_returns(self):
with self.cached_session():
t = self.multi_return_if_stmt(constant_op.constant(True))
self.assertAllEqual([1, 2], self.evaluate(t))
t = self.multi_return_if_stmt(constant_op.constant(False))
self.assertAllEqual([-1, -2], self.evaluate(t))
def test_fn(cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: 1,
orelse=lambda: -1,
get_state=lambda: (),
set_state=lambda _: None)
self.assertEqual(1, test_fn(True))
self.assertEqual(-1, test_fn(False))
def test_python_multiple_returns(self):
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
self.assertEqual((-1, -2), self.multi_return_if_stmt(False))
def test_fn(cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: (1, 2),
orelse=lambda: (-1, -2),
get_state=lambda: (),
set_state=lambda _: None)
self.assertEqual((1, 2), test_fn(True))
self.assertEqual((-1, -2), test_fn(False))
if __name__ == '__main__':