Merge pull request #31437 from saxenasaurabh/cherrypicks_01FLT
Do not accumulate Const nodes created in forward pass in while_v2.
commit ff98617eb0
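
For context, a minimal sketch of the scenario this change optimizes (the snippet below is illustrative, not part of the patch, and assumes tf.compat.v1 graph mode with v2 control flow as on this branch). In while_v2, the first gradient computation rewrites the forward `While` op so that intermediates needed by the backward pass are accumulated into TensorLists, one `TensorListPushBack` per captured tensor. A Const such as the 2.0 below is loop invariant, so accumulating it on every iteration is pure overhead; after this change it is simply re-created as a Const in the backward graph instead.

    # Illustrative repro, assuming tf.compat.v1 graph mode with while_v2.
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    v0 = tf.constant(2.)
    # The body multiplies by a Const; Mul's gradient needs both inputs.
    ret = tf.while_loop(lambda v: v < 8., lambda v: v * 2.0, [v0])
    grad = tf.gradients(ret, [v0])[0]

    with tf.Session() as sess:
      print(sess.run([ret, grad]))  # [8.0, 4.0]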
@@ -25,10 +25,10 @@ from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.compat import compat
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import random_ops
-from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import meta_graph
@@ -836,13 +836,13 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     self.assertLen(while_op.outputs, 3)
 
     gradients_impl.gradients(output, x)
-    # while_op should have been rewritten to output 2.0 intermediate.
-    # outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator]
-    self.assertLen(while_op.outputs, 5)
+    # while_op should have been rewritten to output intermediates.
+    # outputs = [loop_counter, max_iters, x, x_accumulator]
+    self.assertLen(while_op.outputs, 4)
 
     gradients_impl.gradients(output, x)
     # Computing the gradient again shouldn't rewrite while_op again.
-    self.assertLen(while_op.outputs, 5)
+    self.assertLen(while_op.outputs, 4)
 
   @test_util.run_deprecated_v1
   def testRandomUniformShape(self):
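
The expectation changes above follow directly from the fix: the in-place rewrite triggered by the first gradients() call now appends an accumulator output only for x, not for the Const 2.0, so the While op grows from 3 to 4 outputs rather than from 3 to 5. A hedged sketch of that inspection (variable names are illustrative; the exact output counts are those asserted by this branch's test):

    # Assumes tf.compat.v1 graph mode with while_v2, as in this test file.
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    x = tf.constant(2.)
    output = tf.while_loop(lambda v: v < 8., lambda v: v * 2.0, [x],
                           maximum_iterations=3)
    while_op = output.op.inputs[0].op  # skip the trailing Identity
    print(len(while_op.outputs))       # 3: [loop_counter, max_iters, x]
    tf.gradients(output, x)            # rewrites the forward While op
    print(len(while_op.outputs))       # 4: one accumulator added, for x only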
@@ -895,6 +895,28 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(ret, 16.)
     self.assertAllEqual(grad, 32.)
 
+  @test_util.run_deprecated_v1
+  def testDoNotAccumulateConstNodes(self):
+
+    def Body(v):
+      return v * 2.0
+
+    v0 = constant_op.constant(2.)
+    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
+    # Gradients computation has the side-effect of updating the forward op
+    # which is what we want to test.
+    unused_grad = gradients_impl.gradients(ret, [v0])[0]
+    # ret is separated from the `While` op by an `Identity` so we skip over
+    # that.
+    forward_while_op = ret.op.inputs[0].op
+    body_graph = while_v2._get_graph(forward_while_op, "body")
+    push_back_nodes = [
+        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
+    ]
+    # Gradient of `Mul` requires accumulating both its inputs. But since one
+    # of those is a Const (2.0), we should have just one accumulator.
+    self.assertLen(push_back_nodes, 1)
+
 
 def ScalarShape():
   return ops.convert_to_tensor([], dtype=dtypes.int32)
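
Why the new test expects exactly one `TensorListPushBack`: for z = x * y, Mul's gradient is dz/dx = grad_z * y and dz/dy = grad_z * x, so the backward body needs both forward inputs on every iteration. The loop variable v must still be accumulated, but the Const 2.0 can be rematerialized. A standalone sketch of that dependency, outside any loop (illustrative only):

    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    x = tf.constant(3.)
    y = tf.constant(2.)
    z = x * y
    gx, gy = tf.gradients(z, [x, y])

    with tf.Session() as sess:
      print(sess.run([gx, gy]))  # [2.0, 3.0]: each grad is the other input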
@@ -942,6 +942,17 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
       self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
       return captured_tensor
 
+    # Do not accumulate Const nodes. Instead copy them directly in the backward
+    # graph.
+    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
+    # graph compile time consts in general.
+    # TODO(srbs): Consider making this a loop input.
+    if constant_op.is_constant(tensor):
+      real_value = constant_op.constant(
+          tensor_util.constant_value(tensor), dtype=tensor.dtype)
+      self._indirect_captures[ops.tensor_id(tensor)] = real_value
+      return real_value
+
     # Resource tensors are not accumulated and handled specially.
     if tensor.dtype == dtypes.resource:
       return self._resource_capture_helper(tensor)
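
The fix leans on two internal helpers that already exist on this branch: `constant_op.is_constant`, which reports whether a tensor is produced by a Const node, and `tensor_util.constant_value`, which extracts its value as a numpy array so an equivalent Const can be rebuilt in the backward graph. A minimal sketch of that round trip (internal APIs, subject to change):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.framework import tensor_util

    with ops.Graph().as_default():
      c = constant_op.constant(2.0)
      assert constant_op.is_constant(c)
      # Rebuild an equivalent Const, as the new _capture_helper branch does.
      copy = constant_op.constant(
          tensor_util.constant_value(c), dtype=c.dtype)
      assert tensor_util.constant_value(copy) == 2.0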