diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index a378a150d37..cc1fa819dac 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3844,6 +3844,7 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:control_flow_util", "//tensorflow/python:dtypes", "//tensorflow/python:framework", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 040c51c5415..48c2ad74310 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -24,6 +24,8 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -38,7 +40,6 @@ from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.ops import while_v2 -from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1 from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2 from tensorflow.python.platform import test @@ -51,7 +52,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): ret = while_loop_v2( lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False) grad = gradients_impl.gradients(ret, [x]) - with self.cached_session() as sess: + with self.cached_session(): self.assertEqual(self.evaluate(ret), 16.) self.assertSequenceEqual(self.evaluate(grad), [32.]) @@ -131,7 +132,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0. grad = gradients_impl.gradients(ret, [x]) # [2*x*y] - with self.cached_session() as sess: + with self.cached_session(): self.assertSequenceEqual(self.evaluate(ret), [45., 3.]) self.assertSequenceEqual(self.evaluate(grad), [9.]) @@ -157,7 +158,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2] grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1] grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1] - with self.cached_session() as sess: + with self.cached_session(): self.assertSequenceEqual(self.evaluate(ret), [120., 23.]) self.assertSequenceEqual(self.evaluate(gradx_0), [39.]) self.assertSequenceEqual(self.evaluate(gradx_1), [4.]) @@ -189,7 +190,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): return_same_structure=False) # x**4 grad = gradients_impl.gradients(ret2, [x]) # 4x**3 grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2 - with self.cached_session() as sess: + with self.cached_session(): self.assertSequenceEqual(self.evaluate(grad), [32.]) self.assertSequenceEqual(self.evaluate(grad_grad), [48.]) @@ -201,13 +202,12 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): return_same_structure=False) # x**4 grad = gradients_impl.gradients(ret, [x]) # 4x**3 grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2 - with self.cached_session() as sess: + with self.cached_session(): self.assertEqual(self.evaluate(ret), 16.) self.assertSequenceEqual(self.evaluate(grad), [32.]) self.assertSequenceEqual(self.evaluate(grad_grad), [48.]) - @test_util.run_v1_only("b/120545219") - def testPruning(self): + def _testPruning(self): x = constant_op.constant(1) tensor_list = list_ops.empty_tensor_list( @@ -220,7 +220,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): def Body(x, tl): return x + 1, list_ops.tensor_list_push_back(tl, x) - outputs = while_loop_v1(Cond, Body, [x, tensor_list]) + outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list]) train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) train_op.append(outputs[0]) @@ -235,12 +235,37 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): return tf_optimizer.OptimizeGraph(config, mg) g = GetOptimizedGraph() - self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1) + # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned + # away, causing an extra Enter node. + enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1 + self.assertLen([n for n in g.node if n.op == "Enter"], enter_count) + # Test that the TensorList is pruned out. + self.assertEmpty([ + n for n in g.node if n.op == "Enter" and + n.attr["T"].type == dtypes.variant.as_datatype_enum + ]) stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype) train_op.append(stack) g = GetOptimizedGraph() - self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2) + # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned + # away, causing an extra Enter node. + enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2 + self.assertLen([n for n in g.node if n.op == "Enter"], enter_count) + # Test that the TensorList is not pruned out. + self.assertNotEmpty([ + n for n in g.node if n.op == "Enter" and + n.attr["T"].type == dtypes.variant.as_datatype_enum + ]) + + @test_util.run_deprecated_v1 + def testPruningV1(self): + self._testPruning() + + @test_util.enable_control_flow_v2 + @test_util.run_deprecated_v1 + def testPruningV2(self): + self._testPruning() @test_util.run_deprecated_v1 def testCaptureExternalTensorInCond(self):