From 9c3422db3bfb7923ef50c2280b2100dad465ff5c Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Tue, 6 Aug 2019 11:31:24 -0700 Subject: [PATCH] Do not accumulate Const nodes created in forward pass in while_v2. PiperOrigin-RevId: 261958798 --- .../python/kernel_tests/while_v2_test.py | 32 ++++++++++++++++--- tensorflow/python/ops/while_v2.py | 11 +++++++ 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 3222f4b14f4..3d91f7ddb7e 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -25,10 +25,10 @@ from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.compat import compat from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import random_ops -from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph @@ -836,13 +836,13 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): self.assertLen(while_op.outputs, 3) gradients_impl.gradients(output, x) - # while_op should have been rewritten to output 2.0 intermediate. - # outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator] - self.assertLen(while_op.outputs, 5) + # while_op should have been rewritten to output intermediates. + # outputs = [loop_counter, max_iters, x, x_accumulator] + self.assertLen(while_op.outputs, 4) gradients_impl.gradients(output, x) # Computing the gradient again shouldn't rewrite while_op again. - self.assertLen(while_op.outputs, 5) + self.assertLen(while_op.outputs, 4) @test_util.run_deprecated_v1 def testRandomUniformShape(self): @@ -895,6 +895,28 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): self.assertAllEqual(ret, 16.) self.assertAllEqual(grad, 32.) + @test_util.run_deprecated_v1 + def testDoNotAccumulateConstNodes(self): + + def Body(v): + return v * 2.0 + + v0 = constant_op.constant(2.) + ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0] + # Gradients computation has the side-effect of updating the forward op + # which is what we want to test. + unused_grad = gradients_impl.gradients(ret, [v0])[0] + # ret is separated from the `While` op by an `Identity` so we skip over + # that. + forward_while_op = ret.op.inputs[0].op + body_graph = while_v2._get_graph(forward_while_op, "body") + push_back_nodes = [ + o for o in body_graph.get_operations() if o.type == "TensorListPushBack" + ] + # Gradient of `Mul` requires accumulating both its inputs. But since one + # of those is a Const (2.0), we should have just one accumulator. + self.assertLen(push_back_nodes, 1) + def ScalarShape(): return ops.convert_to_tensor([], dtype=dtypes.int32) diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index b32986543fd..5f2a39d5fd8 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -942,6 +942,17 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor return captured_tensor + # Do not accumulate Const nodes. Instead copy them directly in the backward + # graph. + # TODO(srbs): This just checks for `Const` nodes. Consider checking for + # graph compile time consts in general. + # TODO(srbs): Consider making this a loop input. + if constant_op.is_constant(tensor): + real_value = constant_op.constant( + tensor_util.constant_value(tensor), dtype=tensor.dtype) + self._indirect_captures[ops.tensor_id(tensor)] = real_value + return real_value + # Resource tensors are not accumulated and handled specially. if tensor.dtype == dtypes.resource: return self._resource_capture_helper(tensor)