diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py index 7b6217cf78e..64ed0210bf0 100644 --- a/tensorflow/python/autograph/operators/control_flow_test.py +++ b/tensorflow/python/autograph/operators/control_flow_test.py @@ -39,79 +39,70 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class ForLoopTest(test.TestCase): def test_tensor(self): - with ops.Graph().as_default(): - s = control_flow.for_stmt( - constant_op.constant([1, 2, 3, 4]), - extra_test=lambda s: True, - body=lambda i, s: (s * 10 + i,), - get_state=lambda: (), - set_state=lambda _: None, - init_vars=(0,)) - self.assertEqual(self.evaluate(s), (1234,)) + s = control_flow.for_stmt( + constant_op.constant([1, 2, 3, 4]), + extra_test=lambda s: True, + body=lambda i, s: (s * 10 + i,), + get_state=lambda: (), + set_state=lambda _: None, + init_vars=(0,)) + self.assertEqual(self.evaluate(s), (1234,)) def test_range_tensor(self): - with ops.Graph().as_default(): - s = control_flow.for_stmt( - math_ops.range(5), - extra_test=lambda s: True, - body=lambda i, s: (s * 10 + i,), - get_state=lambda: (), - set_state=lambda _: None, - init_vars=(0,)) - self.assertEqual(self.evaluate(s), (1234,)) + s = control_flow.for_stmt( + math_ops.range(5), + extra_test=lambda s: True, + body=lambda i, s: (s * 10 + i,), + get_state=lambda: (), + set_state=lambda _: None, + init_vars=(0,)) + self.assertEqual(self.evaluate(s), (1234,)) def test_range_tensor_random_delta(self): - - with ops.Graph().as_default(): - random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32) - s = control_flow.for_stmt( - math_ops.range(0, 5, random_one), - extra_test=lambda s: True, - body=lambda i, s: (s * 10 + i,), - get_state=lambda: (), - set_state=lambda _: None, - init_vars=(0,)) - self.assertEqual(self.evaluate(s), (1234,)) + random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32) + s = control_flow.for_stmt( + math_ops.range(0, 5, random_one), + extra_test=lambda s: True, + body=lambda i, s: (s * 10 + i,), + get_state=lambda: (), + set_state=lambda _: None, + init_vars=(0,)) + self.assertEqual(self.evaluate(s), (1234,)) def test_range_tensor_explicit_limit_delta(self): - with ops.Graph().as_default(): - s = control_flow.for_stmt( - math_ops.range(-17, -3, 5), - extra_test=lambda s: True, - body=lambda i, s: (s * 100 + i,), - get_state=lambda: (), - set_state=lambda _: None, - init_vars=(0,)) - self.assertEqual(self.evaluate(s), (-171207,)) + s = control_flow.for_stmt( + math_ops.range(-17, -3, 5), + extra_test=lambda s: True, + body=lambda i, s: (s * 100 + i,), + get_state=lambda: (), + set_state=lambda _: None, + init_vars=(0,)) + self.assertEqual(self.evaluate(s), (-171207,)) def test_range_tensor_random_negative_delta(self): - with ops.Graph().as_default(): - random_neg_five = random_ops.random_uniform((), - -5, - -4, - dtype=dtypes.int32) - s = control_flow.for_stmt( - math_ops.range(17, 3, random_neg_five), - extra_test=lambda s: True, - body=lambda i, s: (s * 100 + i,), - get_state=lambda: (), - set_state=lambda _: None, - init_vars=(0,)) - self.assertEqual(self.evaluate(s), (171207,)) + random_neg_five = random_ops.random_uniform((), -5, -4, dtype=dtypes.int32) + s = control_flow.for_stmt( + math_ops.range(17, 3, random_neg_five), + extra_test=lambda s: True, + body=lambda i, s: (s * 100 + i,), + get_state=lambda: (), + set_state=lambda _: None, + init_vars=(0,)) + self.assertEqual(self.evaluate(s), (171207,)) def test_range_tensor_negative_delta(self): - with ops.Graph().as_default(): - s = control_flow.for_stmt( - math_ops.range(17, 3, -5), - extra_test=lambda s: True, - body=lambda i, s: (s * 100 + i,), - get_state=lambda: (), - set_state=lambda _: None, - init_vars=(0,)) - self.assertEqual(self.evaluate(s), (171207,)) + s = control_flow.for_stmt( + math_ops.range(17, 3, -5), + extra_test=lambda s: True, + body=lambda i, s: (s * 100 + i,), + get_state=lambda: (), + set_state=lambda _: None, + init_vars=(0,)) + self.assertEqual(self.evaluate(s), (171207,)) def test_tensor_with_extra_test_only_python_state(self): class MutableObject(object): @@ -151,15 +142,14 @@ class ForLoopTest(test.TestCase): self.assertEqual(s, (1234,)) def test_tf_dataset(self): - with ops.Graph().as_default(): - s = control_flow.for_stmt( - dataset_ops.Dataset.range(5), - extra_test=None, - body=lambda i, s: (s * 10 + i,), - get_state=lambda: (), - set_state=lambda _: None, - init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) - self.assertEqual(self.evaluate(s), (1234,)) + s = control_flow.for_stmt( + dataset_ops.Dataset.range(5), + extra_test=None, + body=lambda i, s: (s * 10 + i,), + get_state=lambda: (), + set_state=lambda _: None, + init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) + self.assertEqual(self.evaluate(s), (1234,)) def test_dataset_with_extra_test(self): s = control_flow.for_stmt( @@ -209,7 +199,6 @@ class ForLoopTest(test.TestCase): init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) self.assertEqual(self.evaluate(s), (3,)) - @test_util.run_v2_only def test_tf_dataset_no_loop_vars(self): v = variables.Variable(0, dtype=dtypes.int64) self.evaluate(v.initializer) @@ -217,7 +206,8 @@ class ForLoopTest(test.TestCase): def stateless_with_side_effects(i): v.assign(v.read_value() * 10 + i) - # function is important here, because ops test for its presence. + # tf.function required for the automatic control dependencies, and because + # ops test for its presence. @def_function.function(autograph=False) def test_fn(): control_flow.for_stmt( @@ -228,7 +218,7 @@ class ForLoopTest(test.TestCase): set_state=lambda _: None, init_vars=()) - test_fn() + self.evaluate(test_fn()) self.assertEqual(self.evaluate(v.read_value()), 1234) def test_tf_iterator(self): @@ -246,14 +236,14 @@ class ForLoopTest(test.TestCase): s, = test_fn() self.assertAllEqual(s, 1234) - @test_util.run_v2_only def test_tf_iterator_no_loop_vars(self): v = variables.Variable(0, dtype=dtypes.int64) + self.evaluate(v.initializer) def stateless_with_side_effects(i): v.assign(v.read_value() * 10 + i) - # graph-mode iterators are only supported inside tf.function. + # tf.function required for the automatic control dependencies. @def_function.function(autograph=False) def test_fn(): control_flow.for_stmt( @@ -264,13 +254,13 @@ class ForLoopTest(test.TestCase): set_state=lambda _: None, init_vars=()) - test_fn() + self.evaluate(test_fn()) self.assertEqual(self.evaluate(v.read_value()), 1234) +@test_util.run_all_in_graph_and_eager_modes class WhileLoopTest(test.TestCase): - @test_util.run_deprecated_v1 def test_tensor(self): n = constant_op.constant(5) results = control_flow.while_stmt( @@ -282,7 +272,6 @@ class WhileLoopTest(test.TestCase): self.assertEqual((5, 10), self.evaluate(results)) def test_tensor_with_tf_side_effects_in_cond(self): - n = constant_op.constant(5, dtype=dtypes.int64) v = variables.Variable(0, dtype=dtypes.int64) @@ -290,7 +279,7 @@ class WhileLoopTest(test.TestCase): v.assign(v.read_value() + 1) return v.read_value() - # function is important here, because ops test for its presence. + # tf.function required for the automatic control dependencies. @def_function.function(autograph=False) def test_fn(): return control_flow.while_stmt( @@ -332,7 +321,6 @@ class WhileLoopTest(test.TestCase): self.assertEqual(self.evaluate(s), (5, 10)) self.assertEqual(self.evaluate(state.field), 10) - @test_util.run_deprecated_v1 def test_python_with_tensor_state(self): n = 5 results = control_flow.while_stmt( @@ -386,47 +374,61 @@ class WhileLoopTest(test.TestCase): out_capturer.getvalue())) +@test_util.run_all_in_graph_and_eager_modes class IfStmtTest(test.TestCase): - def single_return_if_stmt(self, cond): - return control_flow.if_stmt( - cond=cond, - body=lambda: 1, - orelse=lambda: -1, - get_state=lambda: (), - set_state=lambda _: None) - - def multi_return_if_stmt(self, cond): - return control_flow.if_stmt( - cond=cond, - body=lambda: (1, 2), - orelse=lambda: (-1, -2), - get_state=lambda: (), - set_state=lambda _: None) - - @test_util.run_deprecated_v1 def test_tensor(self): - with self.cached_session(): - t = self.single_return_if_stmt(constant_op.constant(True)) - self.assertEqual(1, self.evaluate(t)) - t = self.single_return_if_stmt(constant_op.constant(False)) - self.assertEqual(-1, self.evaluate(t)) + + def test_fn(cond): + return control_flow.if_stmt( + cond=cond, + body=lambda: constant_op.constant(1), + orelse=lambda: constant_op.constant(-1), + get_state=lambda: (), + set_state=lambda _: None) + + self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True)))) + self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False)))) + + def test_tensor_multiple_returns(self): + + def test_fn(cond): + return control_flow.if_stmt( + cond=cond, + body=lambda: (constant_op.constant(1), constant_op.constant(2)), + orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)), + get_state=lambda: (), + set_state=lambda _: None) + + self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True)))) + self.assertEqual((-1, -2), + self.evaluate(test_fn(constant_op.constant(False)))) def test_python(self): - self.assertEqual(1, self.single_return_if_stmt(True)) - self.assertEqual(-1, self.single_return_if_stmt(False)) - @test_util.run_deprecated_v1 - def test_tensor_multiple_returns(self): - with self.cached_session(): - t = self.multi_return_if_stmt(constant_op.constant(True)) - self.assertAllEqual([1, 2], self.evaluate(t)) - t = self.multi_return_if_stmt(constant_op.constant(False)) - self.assertAllEqual([-1, -2], self.evaluate(t)) + def test_fn(cond): + return control_flow.if_stmt( + cond=cond, + body=lambda: 1, + orelse=lambda: -1, + get_state=lambda: (), + set_state=lambda _: None) + + self.assertEqual(1, test_fn(True)) + self.assertEqual(-1, test_fn(False)) def test_python_multiple_returns(self): - self.assertEqual((1, 2), self.multi_return_if_stmt(True)) - self.assertEqual((-1, -2), self.multi_return_if_stmt(False)) + + def test_fn(cond): + return control_flow.if_stmt( + cond=cond, + body=lambda: (1, 2), + orelse=lambda: (-1, -2), + get_state=lambda: (), + set_state=lambda _: None) + + self.assertEqual((1, 2), test_fn(True)) + self.assertEqual((-1, -2), test_fn(False)) if __name__ == '__main__':