Clean up tests so that they run consistently across versions and modes.
PiperOrigin-RevId: 281583150 Change-Id: Ifb5c245473d92174b39d707c7c693f5242e76d7e
This commit is contained in:
parent
f7a7c799ab
commit
fcc86274b5
@ -39,79 +39,70 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class ForLoopTest(test.TestCase):
|
class ForLoopTest(test.TestCase):
|
||||||
|
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
with ops.Graph().as_default():
|
s = control_flow.for_stmt(
|
||||||
s = control_flow.for_stmt(
|
constant_op.constant([1, 2, 3, 4]),
|
||||||
constant_op.constant([1, 2, 3, 4]),
|
extra_test=lambda s: True,
|
||||||
extra_test=lambda s: True,
|
body=lambda i, s: (s * 10 + i,),
|
||||||
body=lambda i, s: (s * 10 + i,),
|
get_state=lambda: (),
|
||||||
get_state=lambda: (),
|
set_state=lambda _: None,
|
||||||
set_state=lambda _: None,
|
init_vars=(0,))
|
||||||
init_vars=(0,))
|
self.assertEqual(self.evaluate(s), (1234,))
|
||||||
self.assertEqual(self.evaluate(s), (1234,))
|
|
||||||
|
|
||||||
def test_range_tensor(self):
|
def test_range_tensor(self):
|
||||||
with ops.Graph().as_default():
|
s = control_flow.for_stmt(
|
||||||
s = control_flow.for_stmt(
|
math_ops.range(5),
|
||||||
math_ops.range(5),
|
extra_test=lambda s: True,
|
||||||
extra_test=lambda s: True,
|
body=lambda i, s: (s * 10 + i,),
|
||||||
body=lambda i, s: (s * 10 + i,),
|
get_state=lambda: (),
|
||||||
get_state=lambda: (),
|
set_state=lambda _: None,
|
||||||
set_state=lambda _: None,
|
init_vars=(0,))
|
||||||
init_vars=(0,))
|
self.assertEqual(self.evaluate(s), (1234,))
|
||||||
self.assertEqual(self.evaluate(s), (1234,))
|
|
||||||
|
|
||||||
def test_range_tensor_random_delta(self):
|
def test_range_tensor_random_delta(self):
|
||||||
|
random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
|
||||||
with ops.Graph().as_default():
|
s = control_flow.for_stmt(
|
||||||
random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
|
math_ops.range(0, 5, random_one),
|
||||||
s = control_flow.for_stmt(
|
extra_test=lambda s: True,
|
||||||
math_ops.range(0, 5, random_one),
|
body=lambda i, s: (s * 10 + i,),
|
||||||
extra_test=lambda s: True,
|
get_state=lambda: (),
|
||||||
body=lambda i, s: (s * 10 + i,),
|
set_state=lambda _: None,
|
||||||
get_state=lambda: (),
|
init_vars=(0,))
|
||||||
set_state=lambda _: None,
|
self.assertEqual(self.evaluate(s), (1234,))
|
||||||
init_vars=(0,))
|
|
||||||
self.assertEqual(self.evaluate(s), (1234,))
|
|
||||||
|
|
||||||
def test_range_tensor_explicit_limit_delta(self):
|
def test_range_tensor_explicit_limit_delta(self):
|
||||||
with ops.Graph().as_default():
|
s = control_flow.for_stmt(
|
||||||
s = control_flow.for_stmt(
|
math_ops.range(-17, -3, 5),
|
||||||
math_ops.range(-17, -3, 5),
|
extra_test=lambda s: True,
|
||||||
extra_test=lambda s: True,
|
body=lambda i, s: (s * 100 + i,),
|
||||||
body=lambda i, s: (s * 100 + i,),
|
get_state=lambda: (),
|
||||||
get_state=lambda: (),
|
set_state=lambda _: None,
|
||||||
set_state=lambda _: None,
|
init_vars=(0,))
|
||||||
init_vars=(0,))
|
self.assertEqual(self.evaluate(s), (-171207,))
|
||||||
self.assertEqual(self.evaluate(s), (-171207,))
|
|
||||||
|
|
||||||
def test_range_tensor_random_negative_delta(self):
|
def test_range_tensor_random_negative_delta(self):
|
||||||
with ops.Graph().as_default():
|
random_neg_five = random_ops.random_uniform((), -5, -4, dtype=dtypes.int32)
|
||||||
random_neg_five = random_ops.random_uniform((),
|
s = control_flow.for_stmt(
|
||||||
-5,
|
math_ops.range(17, 3, random_neg_five),
|
||||||
-4,
|
extra_test=lambda s: True,
|
||||||
dtype=dtypes.int32)
|
body=lambda i, s: (s * 100 + i,),
|
||||||
s = control_flow.for_stmt(
|
get_state=lambda: (),
|
||||||
math_ops.range(17, 3, random_neg_five),
|
set_state=lambda _: None,
|
||||||
extra_test=lambda s: True,
|
init_vars=(0,))
|
||||||
body=lambda i, s: (s * 100 + i,),
|
self.assertEqual(self.evaluate(s), (171207,))
|
||||||
get_state=lambda: (),
|
|
||||||
set_state=lambda _: None,
|
|
||||||
init_vars=(0,))
|
|
||||||
self.assertEqual(self.evaluate(s), (171207,))
|
|
||||||
|
|
||||||
def test_range_tensor_negative_delta(self):
|
def test_range_tensor_negative_delta(self):
|
||||||
with ops.Graph().as_default():
|
s = control_flow.for_stmt(
|
||||||
s = control_flow.for_stmt(
|
math_ops.range(17, 3, -5),
|
||||||
math_ops.range(17, 3, -5),
|
extra_test=lambda s: True,
|
||||||
extra_test=lambda s: True,
|
body=lambda i, s: (s * 100 + i,),
|
||||||
body=lambda i, s: (s * 100 + i,),
|
get_state=lambda: (),
|
||||||
get_state=lambda: (),
|
set_state=lambda _: None,
|
||||||
set_state=lambda _: None,
|
init_vars=(0,))
|
||||||
init_vars=(0,))
|
self.assertEqual(self.evaluate(s), (171207,))
|
||||||
self.assertEqual(self.evaluate(s), (171207,))
|
|
||||||
|
|
||||||
def test_tensor_with_extra_test_only_python_state(self):
|
def test_tensor_with_extra_test_only_python_state(self):
|
||||||
class MutableObject(object):
|
class MutableObject(object):
|
||||||
@ -151,15 +142,14 @@ class ForLoopTest(test.TestCase):
|
|||||||
self.assertEqual(s, (1234,))
|
self.assertEqual(s, (1234,))
|
||||||
|
|
||||||
def test_tf_dataset(self):
|
def test_tf_dataset(self):
|
||||||
with ops.Graph().as_default():
|
s = control_flow.for_stmt(
|
||||||
s = control_flow.for_stmt(
|
dataset_ops.Dataset.range(5),
|
||||||
dataset_ops.Dataset.range(5),
|
extra_test=None,
|
||||||
extra_test=None,
|
body=lambda i, s: (s * 10 + i,),
|
||||||
body=lambda i, s: (s * 10 + i,),
|
get_state=lambda: (),
|
||||||
get_state=lambda: (),
|
set_state=lambda _: None,
|
||||||
set_state=lambda _: None,
|
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
|
||||||
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
|
self.assertEqual(self.evaluate(s), (1234,))
|
||||||
self.assertEqual(self.evaluate(s), (1234,))
|
|
||||||
|
|
||||||
def test_dataset_with_extra_test(self):
|
def test_dataset_with_extra_test(self):
|
||||||
s = control_flow.for_stmt(
|
s = control_flow.for_stmt(
|
||||||
@ -209,7 +199,6 @@ class ForLoopTest(test.TestCase):
|
|||||||
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
|
init_vars=(constant_op.constant(0, dtype=dtypes.int64),))
|
||||||
self.assertEqual(self.evaluate(s), (3,))
|
self.assertEqual(self.evaluate(s), (3,))
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def test_tf_dataset_no_loop_vars(self):
|
def test_tf_dataset_no_loop_vars(self):
|
||||||
v = variables.Variable(0, dtype=dtypes.int64)
|
v = variables.Variable(0, dtype=dtypes.int64)
|
||||||
self.evaluate(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
@ -217,7 +206,8 @@ class ForLoopTest(test.TestCase):
|
|||||||
def stateless_with_side_effects(i):
|
def stateless_with_side_effects(i):
|
||||||
v.assign(v.read_value() * 10 + i)
|
v.assign(v.read_value() * 10 + i)
|
||||||
|
|
||||||
# function is important here, because ops test for its presence.
|
# tf.function required for the automatic control dependencies, and because
|
||||||
|
# ops test for its presence.
|
||||||
@def_function.function(autograph=False)
|
@def_function.function(autograph=False)
|
||||||
def test_fn():
|
def test_fn():
|
||||||
control_flow.for_stmt(
|
control_flow.for_stmt(
|
||||||
@ -228,7 +218,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=lambda _: None,
|
set_state=lambda _: None,
|
||||||
init_vars=())
|
init_vars=())
|
||||||
|
|
||||||
test_fn()
|
self.evaluate(test_fn())
|
||||||
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
||||||
|
|
||||||
def test_tf_iterator(self):
|
def test_tf_iterator(self):
|
||||||
@ -246,14 +236,14 @@ class ForLoopTest(test.TestCase):
|
|||||||
s, = test_fn()
|
s, = test_fn()
|
||||||
self.assertAllEqual(s, 1234)
|
self.assertAllEqual(s, 1234)
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def test_tf_iterator_no_loop_vars(self):
|
def test_tf_iterator_no_loop_vars(self):
|
||||||
v = variables.Variable(0, dtype=dtypes.int64)
|
v = variables.Variable(0, dtype=dtypes.int64)
|
||||||
|
self.evaluate(v.initializer)
|
||||||
|
|
||||||
def stateless_with_side_effects(i):
|
def stateless_with_side_effects(i):
|
||||||
v.assign(v.read_value() * 10 + i)
|
v.assign(v.read_value() * 10 + i)
|
||||||
|
|
||||||
# graph-mode iterators are only supported inside tf.function.
|
# tf.function required for the automatic control dependencies.
|
||||||
@def_function.function(autograph=False)
|
@def_function.function(autograph=False)
|
||||||
def test_fn():
|
def test_fn():
|
||||||
control_flow.for_stmt(
|
control_flow.for_stmt(
|
||||||
@ -264,13 +254,13 @@ class ForLoopTest(test.TestCase):
|
|||||||
set_state=lambda _: None,
|
set_state=lambda _: None,
|
||||||
init_vars=())
|
init_vars=())
|
||||||
|
|
||||||
test_fn()
|
self.evaluate(test_fn())
|
||||||
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
self.assertEqual(self.evaluate(v.read_value()), 1234)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class WhileLoopTest(test.TestCase):
|
class WhileLoopTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
n = constant_op.constant(5)
|
n = constant_op.constant(5)
|
||||||
results = control_flow.while_stmt(
|
results = control_flow.while_stmt(
|
||||||
@ -282,7 +272,6 @@ class WhileLoopTest(test.TestCase):
|
|||||||
self.assertEqual((5, 10), self.evaluate(results))
|
self.assertEqual((5, 10), self.evaluate(results))
|
||||||
|
|
||||||
def test_tensor_with_tf_side_effects_in_cond(self):
|
def test_tensor_with_tf_side_effects_in_cond(self):
|
||||||
|
|
||||||
n = constant_op.constant(5, dtype=dtypes.int64)
|
n = constant_op.constant(5, dtype=dtypes.int64)
|
||||||
v = variables.Variable(0, dtype=dtypes.int64)
|
v = variables.Variable(0, dtype=dtypes.int64)
|
||||||
|
|
||||||
@ -290,7 +279,7 @@ class WhileLoopTest(test.TestCase):
|
|||||||
v.assign(v.read_value() + 1)
|
v.assign(v.read_value() + 1)
|
||||||
return v.read_value()
|
return v.read_value()
|
||||||
|
|
||||||
# function is important here, because ops test for its presence.
|
# tf.function required for the automatic control dependencies.
|
||||||
@def_function.function(autograph=False)
|
@def_function.function(autograph=False)
|
||||||
def test_fn():
|
def test_fn():
|
||||||
return control_flow.while_stmt(
|
return control_flow.while_stmt(
|
||||||
@ -332,7 +321,6 @@ class WhileLoopTest(test.TestCase):
|
|||||||
self.assertEqual(self.evaluate(s), (5, 10))
|
self.assertEqual(self.evaluate(s), (5, 10))
|
||||||
self.assertEqual(self.evaluate(state.field), 10)
|
self.assertEqual(self.evaluate(state.field), 10)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_python_with_tensor_state(self):
|
def test_python_with_tensor_state(self):
|
||||||
n = 5
|
n = 5
|
||||||
results = control_flow.while_stmt(
|
results = control_flow.while_stmt(
|
||||||
@ -386,47 +374,61 @@ class WhileLoopTest(test.TestCase):
|
|||||||
out_capturer.getvalue()))
|
out_capturer.getvalue()))
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class IfStmtTest(test.TestCase):
|
class IfStmtTest(test.TestCase):
|
||||||
|
|
||||||
def single_return_if_stmt(self, cond):
|
|
||||||
return control_flow.if_stmt(
|
|
||||||
cond=cond,
|
|
||||||
body=lambda: 1,
|
|
||||||
orelse=lambda: -1,
|
|
||||||
get_state=lambda: (),
|
|
||||||
set_state=lambda _: None)
|
|
||||||
|
|
||||||
def multi_return_if_stmt(self, cond):
|
|
||||||
return control_flow.if_stmt(
|
|
||||||
cond=cond,
|
|
||||||
body=lambda: (1, 2),
|
|
||||||
orelse=lambda: (-1, -2),
|
|
||||||
get_state=lambda: (),
|
|
||||||
set_state=lambda _: None)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
with self.cached_session():
|
|
||||||
t = self.single_return_if_stmt(constant_op.constant(True))
|
def test_fn(cond):
|
||||||
self.assertEqual(1, self.evaluate(t))
|
return control_flow.if_stmt(
|
||||||
t = self.single_return_if_stmt(constant_op.constant(False))
|
cond=cond,
|
||||||
self.assertEqual(-1, self.evaluate(t))
|
body=lambda: constant_op.constant(1),
|
||||||
|
orelse=lambda: constant_op.constant(-1),
|
||||||
|
get_state=lambda: (),
|
||||||
|
set_state=lambda _: None)
|
||||||
|
|
||||||
|
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
|
||||||
|
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
|
||||||
|
|
||||||
|
def test_tensor_multiple_returns(self):
|
||||||
|
|
||||||
|
def test_fn(cond):
|
||||||
|
return control_flow.if_stmt(
|
||||||
|
cond=cond,
|
||||||
|
body=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)),
|
||||||
|
get_state=lambda: (),
|
||||||
|
set_state=lambda _: None)
|
||||||
|
|
||||||
|
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
|
||||||
|
self.assertEqual((-1, -2),
|
||||||
|
self.evaluate(test_fn(constant_op.constant(False))))
|
||||||
|
|
||||||
def test_python(self):
|
def test_python(self):
|
||||||
self.assertEqual(1, self.single_return_if_stmt(True))
|
|
||||||
self.assertEqual(-1, self.single_return_if_stmt(False))
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
def test_fn(cond):
|
||||||
def test_tensor_multiple_returns(self):
|
return control_flow.if_stmt(
|
||||||
with self.cached_session():
|
cond=cond,
|
||||||
t = self.multi_return_if_stmt(constant_op.constant(True))
|
body=lambda: 1,
|
||||||
self.assertAllEqual([1, 2], self.evaluate(t))
|
orelse=lambda: -1,
|
||||||
t = self.multi_return_if_stmt(constant_op.constant(False))
|
get_state=lambda: (),
|
||||||
self.assertAllEqual([-1, -2], self.evaluate(t))
|
set_state=lambda _: None)
|
||||||
|
|
||||||
|
self.assertEqual(1, test_fn(True))
|
||||||
|
self.assertEqual(-1, test_fn(False))
|
||||||
|
|
||||||
def test_python_multiple_returns(self):
|
def test_python_multiple_returns(self):
|
||||||
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
|
|
||||||
self.assertEqual((-1, -2), self.multi_return_if_stmt(False))
|
def test_fn(cond):
|
||||||
|
return control_flow.if_stmt(
|
||||||
|
cond=cond,
|
||||||
|
body=lambda: (1, 2),
|
||||||
|
orelse=lambda: (-1, -2),
|
||||||
|
get_state=lambda: (),
|
||||||
|
set_state=lambda _: None)
|
||||||
|
|
||||||
|
self.assertEqual((1, 2), test_fn(True))
|
||||||
|
self.assertEqual((-1, -2), test_fn(False))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user