Merge pull request #31437 from saxenasaurabh/cherrypicks_01FLT
Do not accumulate Const nodes created in forward pass in while_v2.
This commit is contained in:
commit
ff98617eb0
@ -25,10 +25,10 @@ from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.ops import control_flow_util_v2
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import meta_graph
|
||||
@ -836,13 +836,13 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertLen(while_op.outputs, 3)
|
||||
|
||||
gradients_impl.gradients(output, x)
|
||||
# while_op should have been rewritten to output 2.0 intermediate.
|
||||
# outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator]
|
||||
self.assertLen(while_op.outputs, 5)
|
||||
# while_op should have been rewritten to output intermediates.
|
||||
# outputs = [loop_counter, max_iters, x, x_accumulator]
|
||||
self.assertLen(while_op.outputs, 4)
|
||||
|
||||
gradients_impl.gradients(output, x)
|
||||
# Computing the gradient again shouldn't rewrite while_op again.
|
||||
self.assertLen(while_op.outputs, 5)
|
||||
self.assertLen(while_op.outputs, 4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRandomUniformShape(self):
|
||||
@ -895,6 +895,28 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(ret, 16.)
|
||||
self.assertAllEqual(grad, 32.)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testDoNotAccumulateConstNodes(self):
|
||||
|
||||
def Body(v):
|
||||
return v * 2.0
|
||||
|
||||
v0 = constant_op.constant(2.)
|
||||
ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
|
||||
# Gradients computation has the side-effect of updating the forward op
|
||||
# which is what we want to test.
|
||||
unused_grad = gradients_impl.gradients(ret, [v0])[0]
|
||||
# ret is separated from the `While` op by an `Identity` so we skip over
|
||||
# that.
|
||||
forward_while_op = ret.op.inputs[0].op
|
||||
body_graph = while_v2._get_graph(forward_while_op, "body")
|
||||
push_back_nodes = [
|
||||
o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
|
||||
]
|
||||
# Gradient of `Mul` requires accumulating both its inputs. But since one
|
||||
# of those is a Const (2.0), we should have just one accumulator.
|
||||
self.assertLen(push_back_nodes, 1)
|
||||
|
||||
|
||||
def ScalarShape():
|
||||
return ops.convert_to_tensor([], dtype=dtypes.int32)
|
||||
|
@ -942,6 +942,17 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
||||
self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
|
||||
return captured_tensor
|
||||
|
||||
# Do not accumulate Const nodes. Instead copy them directly in the backward
|
||||
# graph.
|
||||
# TODO(srbs): This just checks for `Const` nodes. Consider checking for
|
||||
# graph compile time consts in general.
|
||||
# TODO(srbs): Consider making this a loop input.
|
||||
if constant_op.is_constant(tensor):
|
||||
real_value = constant_op.constant(
|
||||
tensor_util.constant_value(tensor), dtype=tensor.dtype)
|
||||
self._indirect_captures[ops.tensor_id(tensor)] = real_value
|
||||
return real_value
|
||||
|
||||
# Resource tensors are not accumulated and handled specially.
|
||||
if tensor.dtype == dtypes.resource:
|
||||
return self._resource_capture_helper(tensor)
|
||||
|
Loading…
x
Reference in New Issue
Block a user