Clean up tests so that they run consistently across versions and modes.
PiperOrigin-RevId: 281583150 Change-Id: Ifb5c245473d92174b39d707c7c693f5242e76d7e
This commit is contained in:
parent
f7a7c799ab
commit
fcc86274b5
@ -39,79 +39,70 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ForLoopTest(test.TestCase):
|
||||
|
||||
def test_tensor(self):
|
||||
with ops.Graph().as_default():
|
||||
s = control_flow.for_stmt(
|
||||
constant_op.constant([1, 2, 3, 4]),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
s = control_flow.for_stmt(
|
||||
constant_op.constant([1, 2, 3, 4]),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
|
||||
def test_range_tensor(self):
|
||||
with ops.Graph().as_default():
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
|
||||
def test_range_tensor_random_delta(self):
|
||||
|
||||
with ops.Graph().as_default():
|
||||
random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(0, 5, random_one),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(0, 5, random_one),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
|
||||
def test_range_tensor_explicit_limit_delta(self):
|
||||
with ops.Graph().as_default():
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(-17, -3, 5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 100 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (-171207,))
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(-17, -3, 5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 100 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (-171207,))
|
||||
|
||||
def test_range_tensor_random_negative_delta(self):
|
||||
with ops.Graph().as_default():
|
||||
random_neg_five = random_ops.random_uniform((),
|
||||
-5,
|
||||
-4,
|
||||
dtype=dtypes.int32)
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(17, 3, random_neg_five),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 100 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
random_neg_five = random_ops.random_uniform((), -5, -4, dtype=dtypes.int32)
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(17, 3, random_neg_five),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 100 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
|
||||
def test_range_tensor_negative_delta(self):
|
||||
with ops.Graph().as_default():
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(17, 3, -5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 100 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(17, 3, -5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 100 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
|
||||
def test_tensor_with_extra_test_only_python_state(self):
|
||||
class MutableObject(object):
|
||||
@ -151,15 +142,14 @@ class ForLoopTest(test.TestCase):
|
||||
self.assertEqual(s, (1234,))
|
||||
|
||||
def test_tf_dataset(self):
|
||||
with ops.Graph().as_default():
|
||||
s = control_flow.for_stmt(
|
||||
dataset_ops.Dataset.range(5),
|
||||
extra_test=None,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
s = control_flow.for_stmt(
|
||||
dataset_ops.Dataset.range(5),
|
||||
extra_test=None,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
|
||||
def test_dataset_with_extra_test(self):
|
||||
s = control_flow.for_stmt(
|
||||
@ -209,7 +199,6 @@ class ForLoopTest(test.TestCase):
|
||||
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
|
||||
self.assertEqual(self.evaluate(s), (3,))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_tf_dataset_no_loop_vars(self):
|
||||
v = variables.Variable(0, dtype=dtypes.int64)
|
||||
self.evaluate(v.initializer)
|
||||
@ -217,7 +206,8 @@ class ForLoopTest(test.TestCase):
|
||||
def stateless_with_side_effects(i):
|
||||
v.assign(v.read_value() * 10 + i)
|
||||
|
||||
# function is important here, because ops test for its presence.
|
||||
# tf.function required for the automatic control dependencies, and because
|
||||
# ops test for its presence.
|
||||
@def_function.function(autograph=False)
|
||||
def test_fn():
|
||||
control_flow.for_stmt(
|
||||
@ -228,7 +218,7 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=lambda _: None,
|
||||
init_vars=())
|
||||
|
||||
test_fn()
|
||||
self.evaluate(test_fn())
|
||||
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
||||
|
||||
def test_tf_iterator(self):
|
||||
@ -246,14 +236,14 @@ class ForLoopTest(test.TestCase):
|
||||
s, = test_fn()
|
||||
self.assertAllEqual(s, 1234)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_tf_iterator_no_loop_vars(self):
|
||||
v = variables.Variable(0, dtype=dtypes.int64)
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
def stateless_with_side_effects(i):
|
||||
v.assign(v.read_value() * 10 + i)
|
||||
|
||||
# graph-mode iterators are only supported inside tf.function.
|
||||
# tf.function required for the automatic control dependencies.
|
||||
@def_function.function(autograph=False)
|
||||
def test_fn():
|
||||
control_flow.for_stmt(
|
||||
@ -264,13 +254,13 @@ class ForLoopTest(test.TestCase):
|
||||
set_state=lambda _: None,
|
||||
init_vars=())
|
||||
|
||||
test_fn()
|
||||
self.evaluate(test_fn())
|
||||
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class WhileLoopTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_tensor(self):
|
||||
n = constant_op.constant(5)
|
||||
results = control_flow.while_stmt(
|
||||
@ -282,7 +272,6 @@ class WhileLoopTest(test.TestCase):
|
||||
self.assertEqual((5, 10), self.evaluate(results))
|
||||
|
||||
def test_tensor_with_tf_side_effects_in_cond(self):
|
||||
|
||||
n = constant_op.constant(5, dtype=dtypes.int64)
|
||||
v = variables.Variable(0, dtype=dtypes.int64)
|
||||
|
||||
@ -290,7 +279,7 @@ class WhileLoopTest(test.TestCase):
|
||||
v.assign(v.read_value() + 1)
|
||||
return v.read_value()
|
||||
|
||||
# function is important here, because ops test for its presence.
|
||||
# tf.function required for the automatic control dependencies.
|
||||
@def_function.function(autograph=False)
|
||||
def test_fn():
|
||||
return control_flow.while_stmt(
|
||||
@ -332,7 +321,6 @@ class WhileLoopTest(test.TestCase):
|
||||
self.assertEqual(self.evaluate(s), (5, 10))
|
||||
self.assertEqual(self.evaluate(state.field), 10)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_python_with_tensor_state(self):
|
||||
n = 5
|
||||
results = control_flow.while_stmt(
|
||||
@ -386,47 +374,61 @@ class WhileLoopTest(test.TestCase):
|
||||
out_capturer.getvalue()))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class IfStmtTest(test.TestCase):
|
||||
|
||||
def single_return_if_stmt(self, cond):
|
||||
return control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: 1,
|
||||
orelse=lambda: -1,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None)
|
||||
|
||||
def multi_return_if_stmt(self, cond):
|
||||
return control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: (1, 2),
|
||||
orelse=lambda: (-1, -2),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_tensor(self):
|
||||
with self.cached_session():
|
||||
t = self.single_return_if_stmt(constant_op.constant(True))
|
||||
self.assertEqual(1, self.evaluate(t))
|
||||
t = self.single_return_if_stmt(constant_op.constant(False))
|
||||
self.assertEqual(-1, self.evaluate(t))
|
||||
|
||||
def test_fn(cond):
|
||||
return control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: constant_op.constant(1),
|
||||
orelse=lambda: constant_op.constant(-1),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None)
|
||||
|
||||
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
|
||||
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
|
||||
|
||||
def test_tensor_multiple_returns(self):
|
||||
|
||||
def test_fn(cond):
|
||||
return control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None)
|
||||
|
||||
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
|
||||
self.assertEqual((-1, -2),
|
||||
self.evaluate(test_fn(constant_op.constant(False))))
|
||||
|
||||
def test_python(self):
|
||||
self.assertEqual(1, self.single_return_if_stmt(True))
|
||||
self.assertEqual(-1, self.single_return_if_stmt(False))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_tensor_multiple_returns(self):
|
||||
with self.cached_session():
|
||||
t = self.multi_return_if_stmt(constant_op.constant(True))
|
||||
self.assertAllEqual([1, 2], self.evaluate(t))
|
||||
t = self.multi_return_if_stmt(constant_op.constant(False))
|
||||
self.assertAllEqual([-1, -2], self.evaluate(t))
|
||||
def test_fn(cond):
|
||||
return control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: 1,
|
||||
orelse=lambda: -1,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None)
|
||||
|
||||
self.assertEqual(1, test_fn(True))
|
||||
self.assertEqual(-1, test_fn(False))
|
||||
|
||||
def test_python_multiple_returns(self):
|
||||
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
|
||||
self.assertEqual((-1, -2), self.multi_return_if_stmt(False))
|
||||
|
||||
def test_fn(cond):
|
||||
return control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: (1, 2),
|
||||
orelse=lambda: (-1, -2),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None)
|
||||
|
||||
self.assertEqual((1, 2), test_fn(True))
|
||||
self.assertEqual((-1, -2), test_fn(False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user