Handle the case in unconnected gradients where an op has multiple outputs and only some of them have incoming grads.

The unconnected grads logic currently only checks whether there is an entry for the op in the `grads` dict but does not convert the None incoming gradient for a subset of outputs to zeros if requested.

PiperOrigin-RevId: 314634870
Change-Id: Ib6a4b46a600598046603b231a66b9af2af56d5b8
This commit is contained in:
Saurabh Saxena 2020-06-03 17:17:49 -07:00 committed by TensorFlower Gardener
parent 1e3625b72f
commit c05afe921b
2 changed files with 71 additions and 6 deletions

View File

@ -431,6 +431,63 @@ class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
gradients.gradients([y], [x], unconnected_gradients="nonsense")
@parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
                          unconnected_gradients.UnconnectedGradients.NONE)
def testUnconnectedOpWithMultipleOutputs(self, unconnected_gradients_val):
  """Grad of an IdentityN output with no incoming grad honors the mode."""
  with ops.Graph().as_default():
    # Graph under test:
    #     a    b
    #     |    |
    #    IdentityN
    #     |    |
    #     c    d
    #     |
    #  Identity
    #     |
    #     e
    a = constant_op.constant(1.0)
    b = constant_op.constant(1.0)
    c, d = array_ops.identity_n([a, b])
    e = array_ops.identity(c)
    # Aggregated grads for the IdentityN node look like [Tensor, None]:
    # only output `c` is on the path to `e`. With ZERO the None entry must
    # be converted to zeros; with NONE it must stay None.
    output = gradients.gradients(
        e, d, unconnected_gradients=unconnected_gradients_val)
    expect_zeros = (
        unconnected_gradients_val ==
        unconnected_gradients.UnconnectedGradients.ZERO)
    if expect_zeros:
      self.assertIsNotNone(output[0])
    else:
      self.assertIsNone(output[0])
@parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
                          unconnected_gradients.UnconnectedGradients.NONE)
def testUnconnectedOpWithMultipleOutputsStopGradient(
    self, unconnected_gradients_val):
  """Grad of an IdentityN output blocked by stop_gradient honors the mode."""
  with ops.Graph().as_default():
    # Graph under test:
    #     a    b
    #     |    |
    #    IdentityN
    #     |    |
    #     c    d
    #     |    |
    #    SG    |
    #     |    |
    #      \  /
    #       +
    #       e
    a = constant_op.constant(1.0)
    b = constant_op.constant(1.0)
    c, d = array_ops.identity_n([a, b])
    e = array_ops.stop_gradient(c) + d
    # Aggregated grads for the IdentityN node look like [None, Tensor]:
    # stop_gradient blocks the path through `c`. With ZERO the None entry
    # must be converted to zeros; with NONE it must stay None.
    output = gradients.gradients(
        e, c, unconnected_gradients=unconnected_gradients_val)
    expect_zeros = (
        unconnected_gradients_val ==
        unconnected_gradients.UnconnectedGradients.ZERO)
    if expect_zeros:
      self.assertIsNotNone(output[0])
    else:
      self.assertIsNone(output[0])
class FunctionGradientsTest(test_util.TensorFlowTestCase):

View File

@ -781,18 +781,22 @@ def _SetGrad(grads, t, grad):
op_grads[t.value_index] = grad
def _ZerosLike(t):
  """Returns a zeros tensor matching `t`, handling resource-dtype tensors."""
  zeros_dtype = default_gradient.get_zeros_dtype(t)
  if t.dtype != dtypes.resource:
    return array_ops.zeros_like(t, dtype=zeros_dtype)
  # A resource tensor has no materialized value to mirror; build zeros from
  # the shape of the variable the handle points to instead.
  return array_ops.zeros(
      resource_variable_ops.variable_shape(t), dtype=zeros_dtype)
def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t".

  Args:
    grads: dict mapping an op to its accumulated output gradients, one
      entry per op output (indexed by `value_index`).
    t: the tensor whose gradient is requested.
    unconnected_gradients: an `UnconnectedGradients` enum value controlling
      what is returned when `t` has no incoming gradient.

  Returns:
    The gradient tensor for `t`; zeros shaped like `t` (ZERO mode) or None
    (NONE mode) when `t` is unconnected.

  Raises:
    ValueError: if `unconnected_gradients` is not a recognized enum value.
  """
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    # No output of `op` received any gradient: `t` is fully unconnected.
    if unconnected_gradients == UnconnectedGradients.ZERO:
      return _ZerosLike(t)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
  t_grad = op_grads[t.value_index]
  # This can happen if some other output of `t.op` has non-None grad.
  if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None:
    return _ZerosLike(t)
  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
  return t_grad