Merge pull request #31437 from saxenasaurabh/cherrypicks_01FLT

Do not accumulate Const nodes created in forward pass in while_v2.
commit ff98617eb0
Goldie Gadde, 2019-08-08 10:49:33 -07:00 (committed by GitHub)
2 changed files with 38 additions and 5 deletions
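
For context, here is a minimal sketch of the pattern this change optimizes (public TF APIs; the snippet is illustrative and not part of the commit). The gradient of `v * 2.0` needs both inputs of the `Mul`, but `2.0` is a Const whose value is known statically, so the backward pass can copy it instead of accumulating it in a TensorList on every loop iteration:

    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    tf.enable_control_flow_v2()  # use the functional While (while_v2) lowering

    v0 = tf.constant(2.0)
    # Doubles v until it reaches 8: 2 -> 4 -> 8.
    ret = tf.while_loop(lambda v: v < 8.0, lambda v: v * 2.0, [v0])
    grad = tf.gradients(ret, [v0])  # rewrites the forward While op in place

    with tf.Session() as sess:
      print(sess.run([ret, grad]))  # [8.0, [4.0]]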


@@ -25,10 +25,10 @@ from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.compat import compat
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import random_ops
-from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import meta_graph
@@ -836,13 +836,13 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     self.assertLen(while_op.outputs, 3)
 
     gradients_impl.gradients(output, x)
-    # while_op should have been rewritten to output 2.0 intermediate.
-    # outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator]
-    self.assertLen(while_op.outputs, 5)
+    # while_op should have been rewritten to output intermediates.
+    # outputs = [loop_counter, max_iters, x, x_accumulator]
+    self.assertLen(while_op.outputs, 4)
 
     gradients_impl.gradients(output, x)
     # Computing the gradient again shouldn't rewrite while_op again.
-    self.assertLen(while_op.outputs, 5)
+    self.assertLen(while_op.outputs, 4)
 
   @test_util.run_deprecated_v1
   def testRandomUniformShape(self):
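
The assertions above check a side effect of `tf.gradients`: the first gradient call rewrites the already-built forward `While` op in place, appending one accumulator output per intermediate the backward pass needs, and later calls must not rewrite it again. A hedged illustration with public APIs, reusing `ret` and `v0` from the sketch near the top and the same Identity-skipping trick the test uses:

    while_op = ret.op.inputs[0].op     # `ret` is separated from While by an Identity
    n_outputs = len(while_op.outputs)  # already includes the accumulator outputs
    tf.gradients(ret, [v0])            # computing the gradient a second time...
    assert len(while_op.outputs) == n_outputs  # ...adds no further outputs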
@@ -895,6 +895,28 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(ret, 16.)
     self.assertAllEqual(grad, 32.)
 
+  @test_util.run_deprecated_v1
+  def testDoNotAccumulateConstNodes(self):
+
+    def Body(v):
+      return v * 2.0
+
+    v0 = constant_op.constant(2.)
+    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
+    # Gradients computation has the side-effect of updating the forward op
+    # which is what we want to test.
+    unused_grad = gradients_impl.gradients(ret, [v0])[0]
+    # ret is separated from the `While` op by an `Identity` so we skip over
+    # that.
+    forward_while_op = ret.op.inputs[0].op
+    body_graph = while_v2._get_graph(forward_while_op, "body")
+    push_back_nodes = [
+        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
+    ]
+    # Gradient of `Mul` requires accumulating both its inputs. But since one
+    # of those is a Const (2.0), we should have just one accumulator.
+    self.assertLen(push_back_nodes, 1)
+
 
 def ScalarShape():
   return ops.convert_to_tensor([], dtype=dtypes.int32)
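
For readers without access to TF internals such as `while_v2._get_graph`: the loop body is stored as a FunctionDef in the graph's function library, so a similar count can be made from the serialized GraphDef with public APIs only. A rough sketch (it scans every function in the library, so it assumes the loop above is the only one in the graph):

    gdef = tf.get_default_graph().as_graph_def()
    push_backs = [n for f in gdef.library.function for n in f.node_def
                  if n.op == "TensorListPushBack"]
    # With this fix: one accumulator for the loop variable, none for the Const 2.0.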


@@ -942,6 +942,17 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
       self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
       return captured_tensor
 
+    # Do not accumulate Const nodes. Instead copy them directly in the backward
+    # graph.
+    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
+    # graph compile time consts in general.
+    # TODO(srbs): Consider making this a loop input.
+    if constant_op.is_constant(tensor):
+      real_value = constant_op.constant(
+          tensor_util.constant_value(tensor), dtype=tensor.dtype)
+      self._indirect_captures[ops.tensor_id(tensor)] = real_value
+      return real_value
+
     # Resource tensors are not accumulated and handled specially.
     if tensor.dtype == dtypes.resource:
       return self._resource_capture_helper(tensor)
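
`constant_op.is_constant` and `tensor_util.constant_value` are TF-internal helpers, but the pattern is simple: if the captured tensor is produced by a Const op, its value is available at graph-construction time, so it can be re-created in the gradient graph instead of being captured and accumulated. A rough sketch of that pattern in isolation (internal APIs, subject to change):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.framework import tensor_util

    with ops.Graph().as_default():
      c = constant_op.constant(2.0)
      assert constant_op.is_constant(c)      # the producing op is a Const
      value = tensor_util.constant_value(c)  # numpy value, known statically

    with ops.Graph().as_default():           # stands in for the backward graph
      c_copy = constant_op.constant(value, dtype=c.dtype)

Copying the Const is essentially free, whereas an accumulator costs a TensorListPushBack per forward iteration plus the memory to hold every pushed value.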