Add test for pruning of unused outputs in while_v2.
PiperOrigin-RevId: 255286547
This commit is contained in:
parent
9bfc776f03
commit
6fcb6e8761
@ -3844,6 +3844,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:control_flow_util",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_ops",
|
||||
|
@ -24,6 +24,8 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -38,7 +40,6 @@ from tensorflow.python.ops import map_fn
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import while_v2
|
||||
from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
|
||||
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -51,7 +52,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
ret = while_loop_v2(
|
||||
lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
|
||||
grad = gradients_impl.gradients(ret, [x])
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
self.assertEqual(self.evaluate(ret), 16.)
|
||||
self.assertSequenceEqual(self.evaluate(grad), [32.])
|
||||
|
||||
@ -131,7 +132,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
# Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
|
||||
grad = gradients_impl.gradients(ret, [x]) # [2*x*y]
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
|
||||
self.assertSequenceEqual(self.evaluate(grad), [9.])
|
||||
|
||||
@ -157,7 +158,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
|
||||
grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
|
||||
grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
|
||||
self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
|
||||
self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
|
||||
@ -189,7 +190,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
return_same_structure=False) # x**4
|
||||
grad = gradients_impl.gradients(ret2, [x]) # 4x**3
|
||||
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
self.assertSequenceEqual(self.evaluate(grad), [32.])
|
||||
self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
|
||||
|
||||
@ -201,13 +202,12 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
return_same_structure=False) # x**4
|
||||
grad = gradients_impl.gradients(ret, [x]) # 4x**3
|
||||
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
self.assertEqual(self.evaluate(ret), 16.)
|
||||
self.assertSequenceEqual(self.evaluate(grad), [32.])
|
||||
self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testPruning(self):
|
||||
def _testPruning(self):
|
||||
x = constant_op.constant(1)
|
||||
|
||||
tensor_list = list_ops.empty_tensor_list(
|
||||
@ -220,7 +220,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
def Body(x, tl):
|
||||
return x + 1, list_ops.tensor_list_push_back(tl, x)
|
||||
|
||||
outputs = while_loop_v1(Cond, Body, [x, tensor_list])
|
||||
outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
|
||||
|
||||
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
|
||||
train_op.append(outputs[0])
|
||||
@ -235,12 +235,37 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
return tf_optimizer.OptimizeGraph(config, mg)
|
||||
|
||||
g = GetOptimizedGraph()
|
||||
self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
|
||||
# TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
|
||||
# away, causing an extra Enter node.
|
||||
enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
|
||||
self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
|
||||
# Test that the TensorList is pruned out.
|
||||
self.assertEmpty([
|
||||
n for n in g.node if n.op == "Enter" and
|
||||
n.attr["T"].type == dtypes.variant.as_datatype_enum
|
||||
])
|
||||
|
||||
stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
|
||||
train_op.append(stack)
|
||||
g = GetOptimizedGraph()
|
||||
self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
|
||||
# TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
|
||||
# away, causing an extra Enter node.
|
||||
enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
|
||||
self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
|
||||
# Test that the TensorList is not pruned out.
|
||||
self.assertNotEmpty([
|
||||
n for n in g.node if n.op == "Enter" and
|
||||
n.attr["T"].type == dtypes.variant.as_datatype_enum
|
||||
])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testPruningV1(self):
|
||||
self._testPruning()
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
@test_util.run_deprecated_v1
|
||||
def testPruningV2(self):
|
||||
self._testPruning()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testCaptureExternalTensorInCond(self):
|
||||
|
Loading…
Reference in New Issue
Block a user