Add test for pruning of unused outputs in while_v2.
PiperOrigin-RevId: 255286547
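In short, the new test checks that when a while loop's TensorList output is never consumed, Grappler prunes the accumulator from the optimized graph. A minimal graph-mode sketch of that scenario, written against the public tf.compat.v1 API rather than the internal modules the test itself imports (the names below are illustrative, not from the commit):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

x = tf.constant(1)
# Unbounded TensorList of int32 scalars, threaded through the loop.
tl = tf.raw_ops.EmptyTensorList(
    element_shape=tf.constant(-1),
    max_num_elements=tf.constant(-1),
    element_dtype=tf.int32)

def cond(x, tl):
  return x < 5

def body(x, tl):
  # Push the counter onto the list; the list output is never fetched.
  return x + 1, tf.raw_ops.TensorListPushBack(input_handle=tl, tensor=x)

counter, unused_list = tf.while_loop(cond, body, [x, tl])

with tf.Session() as sess:
  # Only the counter is fetched, so the TensorList accumulation is dead
  # code that Grappler may prune from the optimized graph.
  print(sess.run(counter))  # 5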
This commit is contained in:
  parent 9bfc776f03
  commit 6fcb6e8761
@@ -3844,6 +3844,7 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework",
         "//tensorflow/python:framework_ops",
@@ -24,6 +24,8 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -38,7 +40,6 @@ from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import while_v2
-from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
 from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
 from tensorflow.python.platform import test

@@ -51,7 +52,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret = while_loop_v2(
         lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
     grad = gradients_impl.gradients(ret, [x])
-    with self.cached_session() as sess:
+    with self.cached_session():
       self.assertEqual(self.evaluate(ret), 16.)
       self.assertSequenceEqual(self.evaluate(grad), [32.])

@@ -131,7 +132,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):

     # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
     grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
-    with self.cached_session() as sess:
+    with self.cached_session():
       self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
       self.assertSequenceEqual(self.evaluate(grad), [9.])

@@ -157,7 +158,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
     grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
     grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
-    with self.cached_session() as sess:
+    with self.cached_session():
       self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
       self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
       self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
@@ -189,7 +190,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
         return_same_structure=False)  # x**4
     grad = gradients_impl.gradients(ret2, [x])  # 4x**3
     grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
-    with self.cached_session() as sess:
+    with self.cached_session():
       self.assertSequenceEqual(self.evaluate(grad), [32.])
       self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

@@ -201,13 +202,12 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
         return_same_structure=False)  # x**4
     grad = gradients_impl.gradients(ret, [x])  # 4x**3
     grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
-    with self.cached_session() as sess:
+    with self.cached_session():
       self.assertEqual(self.evaluate(ret), 16.)
       self.assertSequenceEqual(self.evaluate(grad), [32.])
       self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

-  @test_util.run_v1_only("b/120545219")
-  def testPruning(self):
+  def _testPruning(self):
     x = constant_op.constant(1)

     tensor_list = list_ops.empty_tensor_list(
@@ -220,7 +220,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     def Body(x, tl):
       return x + 1, list_ops.tensor_list_push_back(tl, x)

-    outputs = while_loop_v1(Cond, Body, [x, tensor_list])
+    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

     train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
     train_op.append(outputs[0])
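Note on the change just above: switching from the while_loop_v1 alias to control_flow_ops.while_loop is what lets one test body exercise both runtimes, since that entry point dispatches to while_v2 when control flow v2 is enabled, which is what the @test_util.enable_control_flow_v2 decorator added below toggles. A rough sketch of that dispatch, assuming the TF 1.14-era internals this diff already imports:

from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util

# What the decorator toggles under the hood (illustrative; the real
# decorator also restores the previous value after the test).
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

# The same call now builds a while_v2 loop instead of a v1 one.
loop = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0])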
@@ -235,12 +235,37 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
       return tf_optimizer.OptimizeGraph(config, mg)

     g = GetOptimizedGraph()
-    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
+    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
+    # away, causing an extra Enter node.
+    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
+    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
+    # Test that the TensorList is pruned out.
+    self.assertEmpty([
+        n for n in g.node if n.op == "Enter" and
+        n.attr["T"].type == dtypes.variant.as_datatype_enum
+    ])

     stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
     train_op.append(stack)
     g = GetOptimizedGraph()
-    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
+    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
+    # away, causing an extra Enter node.
+    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
+    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
+    # Test that the TensorList is not pruned out.
+    self.assertNotEmpty([
+        n for n in g.node if n.op == "Enter" and
+        n.attr["T"].type == dtypes.variant.as_datatype_enum
+    ])
+
+  @test_util.run_deprecated_v1
+  def testPruningV1(self):
+    self._testPruning()
+
+  @test_util.enable_control_flow_v2
+  @test_util.run_deprecated_v1
+  def testPruningV2(self):
+    self._testPruning()

   @test_util.run_deprecated_v1
   def testCaptureExternalTensorInCond(self):
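For reference, the assertions added above boil down to counting "Enter" (loop entry) nodes in the optimized GraphDef, where an Enter of dtype DT_VARIANT corresponds to the TensorList accumulator. A hypothetical standalone helper showing the same inspection (count_enter_nodes is not part of the commit):

from tensorflow.python.framework import dtypes

def count_enter_nodes(graph_def, dtype=None):
  """Counts Enter nodes in graph_def, optionally filtered by dtype."""
  nodes = [n for n in graph_def.node if n.op == "Enter"]
  if dtype is not None:
    nodes = [n for n in nodes if n.attr["T"].type == dtype.as_datatype_enum]
  return len(nodes)

# Against the optimized graph g from _testPruning:
#   count_enter_nodes(g)                  # total Enter nodes (loop vars)
#   count_enter_nodes(g, dtypes.variant)  # 0 once the TensorList is pruned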