Add test for pruning of unused outputs in while_v2.

PiperOrigin-RevId: 255286547
This commit is contained in:
Saurabh Saxena 2019-06-26 16:19:31 -07:00 committed by TensorFlower Gardener
parent 9bfc776f03
commit 6fcb6e8761
2 changed files with 37 additions and 11 deletions

View File

@ -3844,6 +3844,7 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",

View File

@ -24,6 +24,8 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -38,7 +40,6 @@ from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test
@ -51,7 +52,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
ret = while_loop_v2(
lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
grad = gradients_impl.gradients(ret, [x])
with self.cached_session() as sess:
with self.cached_session():
self.assertEqual(self.evaluate(ret), 16.)
self.assertSequenceEqual(self.evaluate(grad), [32.])
@ -131,7 +132,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
# Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
grad = gradients_impl.gradients(ret, [x]) # [2*x*y]
with self.cached_session() as sess:
with self.cached_session():
self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
self.assertSequenceEqual(self.evaluate(grad), [9.])
@ -157,7 +158,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
with self.cached_session() as sess:
with self.cached_session():
self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
@ -189,7 +190,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
return_same_structure=False) # x**4
grad = gradients_impl.gradients(ret2, [x]) # 4x**3
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
with self.cached_session() as sess:
with self.cached_session():
self.assertSequenceEqual(self.evaluate(grad), [32.])
self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
@ -201,13 +202,12 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
return_same_structure=False) # x**4
grad = gradients_impl.gradients(ret, [x]) # 4x**3
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
with self.cached_session() as sess:
with self.cached_session():
self.assertEqual(self.evaluate(ret), 16.)
self.assertSequenceEqual(self.evaluate(grad), [32.])
self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
@test_util.run_v1_only("b/120545219")
def testPruning(self):
def _testPruning(self):
x = constant_op.constant(1)
tensor_list = list_ops.empty_tensor_list(
@ -220,7 +220,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
def Body(x, tl):
return x + 1, list_ops.tensor_list_push_back(tl, x)
outputs = while_loop_v1(Cond, Body, [x, tensor_list])
outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(outputs[0])
@ -235,12 +235,37 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
return tf_optimizer.OptimizeGraph(config, mg)
g = GetOptimizedGraph()
self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
# TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
# away, causing an extra Enter node.
enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
# Test that the TensorList is pruned out.
self.assertEmpty([
n for n in g.node if n.op == "Enter" and
n.attr["T"].type == dtypes.variant.as_datatype_enum
])
stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
train_op.append(stack)
g = GetOptimizedGraph()
self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
# TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
# away, causing an extra Enter node.
enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
# Test that the TensorList is not pruned out.
self.assertNotEmpty([
n for n in g.node if n.op == "Enter" and
n.attr["T"].type == dtypes.variant.as_datatype_enum
])
@test_util.run_deprecated_v1
def testPruningV1(self):
self._testPruning()
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
def testPruningV2(self):
self._testPruning()
@test_util.run_deprecated_v1
def testCaptureExternalTensorInCond(self):