Handle the case in unconnected gradients where an op has multiple outputs and only some of them having incoming grads.
The unconnected grads logic currently only checks whether there is an entry for the op in the `grads` dict but does not convert the None incoming gradient for a subset of outputs to zeros if requested. PiperOrigin-RevId: 314634870 Change-Id: Ib6a4b46a600598046603b231a66b9af2af56d5b8
This commit is contained in:
parent
1e3625b72f
commit
c05afe921b
@ -431,6 +431,63 @@ class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
|
||||
gradients.gradients([y], [x], unconnected_gradients="nonsense")
|
||||
|
||||
@parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
|
||||
unconnected_gradients.UnconnectedGradients.NONE)
|
||||
def testUnconnectedOpWithMultipleOutputs(self, unconnected_gradients_val):
|
||||
with ops.Graph().as_default():
|
||||
# a b
|
||||
# | |
|
||||
# IdentityN
|
||||
# | |
|
||||
# c d
|
||||
# |
|
||||
# Identity
|
||||
# |
|
||||
# e
|
||||
a = constant_op.constant(1.0)
|
||||
b = constant_op.constant(1.0)
|
||||
c, d = array_ops.identity_n([a, b])
|
||||
e = array_ops.identity(c)
|
||||
# The aggregated grads for the IdentityN node would look like
|
||||
# [Tensor, None]. We expect this None to be converted to zeros.
|
||||
output = gradients.gradients(
|
||||
e, d, unconnected_gradients=unconnected_gradients_val)
|
||||
if (unconnected_gradients_val ==
|
||||
unconnected_gradients.UnconnectedGradients.ZERO):
|
||||
self.assertIsNotNone(output[0])
|
||||
else:
|
||||
self.assertIsNone(output[0])
|
||||
|
||||
@parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
|
||||
unconnected_gradients.UnconnectedGradients.NONE)
|
||||
def testUnconnectedOpWithMultipleOutputsStopGradient(
|
||||
self, unconnected_gradients_val):
|
||||
with ops.Graph().as_default():
|
||||
# a b
|
||||
# | |
|
||||
# IdentityN
|
||||
# | |
|
||||
# c d
|
||||
# | |
|
||||
# SG |
|
||||
# | |
|
||||
# \ /
|
||||
# +
|
||||
# e
|
||||
a = constant_op.constant(1.0)
|
||||
b = constant_op.constant(1.0)
|
||||
c, d = array_ops.identity_n([a, b])
|
||||
e = array_ops.stop_gradient(c) + d
|
||||
# The aggregated grads for the IdentityN node would look like
|
||||
# [None, Tensor]. We expect this None to be converted to zeros.
|
||||
output = gradients.gradients(
|
||||
e, c, unconnected_gradients=unconnected_gradients_val)
|
||||
if (unconnected_gradients_val ==
|
||||
unconnected_gradients.UnconnectedGradients.ZERO):
|
||||
self.assertIsNotNone(output[0])
|
||||
else:
|
||||
self.assertIsNone(output[0])
|
||||
|
||||
|
||||
class FunctionGradientsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -781,18 +781,22 @@ def _SetGrad(grads, t, grad):
|
||||
op_grads[t.value_index] = grad
|
||||
|
||||
|
||||
def _ZerosLike(t):
|
||||
t_dtype = default_gradient.get_zeros_dtype(t)
|
||||
if t.dtype == dtypes.resource:
|
||||
return array_ops.zeros(
|
||||
resource_variable_ops.variable_shape(t), dtype=t_dtype)
|
||||
else:
|
||||
return array_ops.zeros_like(t, dtype=t_dtype)
|
||||
|
||||
|
||||
def _GetGrad(grads, t, unconnected_gradients):
|
||||
"""Gets gradient for tensor "t"."""
|
||||
op = t.op
|
||||
op_grads = grads.get(op)
|
||||
if not op_grads:
|
||||
if unconnected_gradients == UnconnectedGradients.ZERO:
|
||||
t_dtype = default_gradient.get_zeros_dtype(t)
|
||||
if t.dtype == dtypes.resource:
|
||||
return array_ops.zeros(
|
||||
resource_variable_ops.variable_shape(t), dtype=t_dtype)
|
||||
else:
|
||||
return array_ops.zeros_like(t, dtype=t_dtype)
|
||||
return _ZerosLike(t)
|
||||
elif unconnected_gradients == UnconnectedGradients.NONE:
|
||||
return None
|
||||
else:
|
||||
@ -800,6 +804,10 @@ def _GetGrad(grads, t, unconnected_gradients):
|
||||
"Unknown value for unconnected_gradients: %r" % unconnected_gradients)
|
||||
|
||||
t_grad = op_grads[t.value_index]
|
||||
# This can happen if some other output of `t.op` has non-None grad.
|
||||
if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None:
|
||||
return _ZerosLike(t)
|
||||
|
||||
assert not isinstance(
|
||||
t_grad, list), ("gradients list should have been aggregated by now.")
|
||||
return t_grad
|
||||
|
Loading…
Reference in New Issue
Block a user