Add test for pruning of unused outputs in while_v2.
PiperOrigin-RevId: 255286547
This commit is contained in:
		
							parent
							
								
									9bfc776f03
								
							
						
					
					
						commit
						6fcb6e8761
					
				@ -3844,6 +3844,7 @@ cuda_py_test(
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:control_flow_ops",
 | 
			
		||||
        "//tensorflow/python:control_flow_util",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
        "//tensorflow/python:framework",
 | 
			
		||||
        "//tensorflow/python:framework_ops",
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,8 @@ from tensorflow.core.protobuf import config_pb2
 | 
			
		||||
from tensorflow.core.protobuf import rewriter_config_pb2
 | 
			
		||||
from tensorflow.python.eager import backprop
 | 
			
		||||
from tensorflow.python.eager import context
 | 
			
		||||
from tensorflow.python.ops import control_flow_ops
 | 
			
		||||
from tensorflow.python.ops import control_flow_util
 | 
			
		||||
from tensorflow.python.eager import def_function
 | 
			
		||||
from tensorflow.python.framework import constant_op
 | 
			
		||||
from tensorflow.python.framework import dtypes
 | 
			
		||||
@ -38,7 +40,6 @@ from tensorflow.python.ops import map_fn
 | 
			
		||||
from tensorflow.python.ops import math_ops
 | 
			
		||||
from tensorflow.python.ops import variables
 | 
			
		||||
from tensorflow.python.ops import while_v2
 | 
			
		||||
from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
 | 
			
		||||
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
 | 
			
		||||
from tensorflow.python.platform import test
 | 
			
		||||
 | 
			
		||||
@ -51,7 +52,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 | 
			
		||||
    ret = while_loop_v2(
 | 
			
		||||
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
 | 
			
		||||
    grad = gradients_impl.gradients(ret, [x])
 | 
			
		||||
    with self.cached_session() as sess:
 | 
			
		||||
    with self.cached_session():
 | 
			
		||||
      self.assertEqual(self.evaluate(ret), 16.)
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(grad), [32.])
 | 
			
		||||
 | 
			
		||||
@ -131,7 +132,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
 | 
			
		||||
    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
 | 
			
		||||
    with self.cached_session() as sess:
 | 
			
		||||
    with self.cached_session():
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(grad), [9.])
 | 
			
		||||
 | 
			
		||||
@ -157,7 +158,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 | 
			
		||||
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
 | 
			
		||||
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
 | 
			
		||||
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
 | 
			
		||||
    with self.cached_session() as sess:
 | 
			
		||||
    with self.cached_session():
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
 | 
			
		||||
@ -189,7 +190,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 | 
			
		||||
        return_same_structure=False)  # x**4
 | 
			
		||||
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
 | 
			
		||||
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
 | 
			
		||||
    with self.cached_session() as sess:
 | 
			
		||||
    with self.cached_session():
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(grad), [32.])
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
 | 
			
		||||
 | 
			
		||||
@ -201,13 +202,12 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 | 
			
		||||
        return_same_structure=False)  # x**4
 | 
			
		||||
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
 | 
			
		||||
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
 | 
			
		||||
    with self.cached_session() as sess:
 | 
			
		||||
    with self.cached_session():
 | 
			
		||||
      self.assertEqual(self.evaluate(ret), 16.)
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(grad), [32.])
 | 
			
		||||
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
 | 
			
		||||
 | 
			
		||||
  @test_util.run_v1_only("b/120545219")
 | 
			
		||||
  def testPruning(self):
 | 
			
		||||
  def _testPruning(self):
 | 
			
		||||
    x = constant_op.constant(1)
 | 
			
		||||
 | 
			
		||||
    tensor_list = list_ops.empty_tensor_list(
 | 
			
		||||
@ -220,7 +220,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 | 
			
		||||
    def Body(x, tl):
 | 
			
		||||
      return x + 1, list_ops.tensor_list_push_back(tl, x)
 | 
			
		||||
 | 
			
		||||
    outputs = while_loop_v1(Cond, Body, [x, tensor_list])
 | 
			
		||||
    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
 | 
			
		||||
 | 
			
		||||
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
 | 
			
		||||
    train_op.append(outputs[0])
 | 
			
		||||
@ -235,12 +235,37 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 | 
			
		||||
      return tf_optimizer.OptimizeGraph(config, mg)
 | 
			
		||||
 | 
			
		||||
    g = GetOptimizedGraph()
 | 
			
		||||
    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
 | 
			
		||||
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
 | 
			
		||||
    # away, causing an extra Enter node.
 | 
			
		||||
    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
 | 
			
		||||
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
 | 
			
		||||
    # Test that the TensorList is pruned out.
 | 
			
		||||
    self.assertEmpty([
 | 
			
		||||
        n for n in g.node if n.op == "Enter" and
 | 
			
		||||
        n.attr["T"].type == dtypes.variant.as_datatype_enum
 | 
			
		||||
    ])
 | 
			
		||||
 | 
			
		||||
    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
 | 
			
		||||
    train_op.append(stack)
 | 
			
		||||
    g = GetOptimizedGraph()
 | 
			
		||||
    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
 | 
			
		||||
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
 | 
			
		||||
    # away, causing an extra Enter node.
 | 
			
		||||
    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
 | 
			
		||||
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
 | 
			
		||||
    # Test that the TensorList is not pruned out.
 | 
			
		||||
    self.assertNotEmpty([
 | 
			
		||||
        n for n in g.node if n.op == "Enter" and
 | 
			
		||||
        n.attr["T"].type == dtypes.variant.as_datatype_enum
 | 
			
		||||
    ])
 | 
			
		||||
 | 
			
		||||
  @test_util.run_deprecated_v1
 | 
			
		||||
  def testPruningV1(self):
 | 
			
		||||
    self._testPruning()
 | 
			
		||||
 | 
			
		||||
  @test_util.enable_control_flow_v2
 | 
			
		||||
  @test_util.run_deprecated_v1
 | 
			
		||||
  def testPruningV2(self):
 | 
			
		||||
    self._testPruning()
 | 
			
		||||
 | 
			
		||||
  @test_util.run_deprecated_v1
 | 
			
		||||
  def testCaptureExternalTensorInCond(self):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user