Add option in tf.gradients() to return zero tensors for unconnected gradients.

tf.gradients currently returns [None] when the gradient of an unconnected
variable is requested. This backwards-compatible change adds an option to
instead return zero tensors that match the shape of the input tensor.

PiperOrigin-RevId: 215725488
A. Unique TensorFlower 2018-10-04 06:09:42 -07:00 committed by TensorFlower Gardener
8 changed files with 135 additions and 7 deletions
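In brief, the new keyword looks like this in use. A minimal TF 1.x graph-mode sketch, mirroring the docstring example added below (the `tf.Session` scaffolding is assumed, not part of the change):

```python
import tensorflow as tf

a = tf.ones([1, 2])
b = tf.ones([3, 1])  # b has no graph path to a

g_none = tf.gradients([b], [a])                                # default: [None]
g_zero = tf.gradients([b], [a], unconnected_gradients="zero")  # zeros shaped like a

with tf.Session() as sess:
    print(sess.run(g_zero))  # [array([[0., 0.]], dtype=float32)]
```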


@@ -2152,6 +2152,7 @@ py_library(
":array_grad",
":array_ops",
":bitwise_ops",
":check_ops",
":cond_v2_impl",
":control_flow_grad",
":control_flow_ops",
@@ -2172,8 +2173,11 @@ py_library(
":random_grad",
":resource_variable_ops",
":spectral_grad",
":tensor_array_ops",
":tensor_util",
":util",
":variable_scope",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",


@@ -25,4 +25,5 @@ from tensorflow.python.ops.custom_gradient import custom_gradient
from tensorflow.python.ops.gradients_impl import AggregationMethod
from tensorflow.python.ops.gradients_impl import gradients
from tensorflow.python.ops.gradients_impl import hessians
from tensorflow.python.ops.gradients_impl import UnconnectedGradients
# pylint: enable=unused-import


@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import contextlib
import enum # pylint: disable=g-bad-import-order
import sys
import warnings
@@ -537,6 +538,26 @@ def _Consumers(t, func_graphs):
return consumers
@tf_export("UnconnectedGradients")
class UnconnectedGradients(enum.Enum):
"""Controls how gradient computation behaves when y does not depend on x.
The gradient of y with respect to x can be zero in two different ways: there
could be no differentiable path in the graph connecting x to y (and so we can
statically prove that the gradient is zero) or it could be that runtime values
of tensors in a particular execution lead to a gradient of zero (say, if a
relu unit happens to not be activated). To allow you to distinguish between
these two cases you can choose what value gets returned for the gradient when
there is no path in the graph from x to y:
* `NONE`: Indicates that [None] will be returned if there is no path from x
to y.
* `ZERO`: Indicates that a zero tensor will be returned in the shape of x.
"""
NONE = "none"
ZERO = "zero"
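To make the docstring's distinction concrete, a small graph-mode sketch (assuming TF 1.x with a `tf.Session`): a relu path from x to y yields a real gradient tensor whose runtime value happens to be zero, while a tensor with no path to y yields no gradient at all:

```python
import tensorflow as tf

x = tf.constant(-1.0)
y = tf.nn.relu(x)           # a differentiable path exists; relu is inactive at -1.0
g = tf.gradients([y], [x])  # a concrete tensor is returned...
with tf.Session() as sess:
    print(sess.run(g))      # [0.0] ...whose runtime value is zero

w = tf.constant(2.0)           # w has no path to y at all
print(tf.gradients([y], [w]))  # [None]: statically unconnected (the NONE case)
```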
@tf_export("gradients")
def gradients(ys,
xs,
@@ -545,7 +566,8 @@ def gradients(ys,
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None):
stop_gradients=None,
unconnected_gradients=UnconnectedGradients.NONE):
"""Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
`ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
@@ -596,6 +618,23 @@ def gradients(ys,
All integer tensors are considered constant with respect to all `xs`, as if
they were included in `stop_gradients`.
`unconnected_gradients` determines the value returned for each x in xs if it
is unconnected in the graph to ys. By default this is None to safeguard
against errors. Mathematically these gradients are zero, which can be requested
using the `'zero'` option. `tf.UnconnectedGradients` provides the
following options and behaviors:
```python
a = tf.ones([1, 2])
b = tf.ones([3, 1])
g1 = tf.gradients([b], [a], unconnected_gradients='none')
g1  # [None]
g2 = tf.gradients([b], [a], unconnected_gradients='zero')
sess.run(g2) # [array([[0., 0.]], dtype=float32)]
```
Args:
ys: A `Tensor` or list of tensors to be differentiated.
xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -611,6 +650,10 @@ def gradients(ys,
Accepted values are constants defined in the class `AggregationMethod`.
stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
through.
unconnected_gradients: Optional. Specifies the gradient value returned when
the given input tensors are unconnected. Accepted values are constants
defined in the class `tf.UnconnectedGradients` and the default value is
`none`.
Returns:
A list of `sum(dy/dx)` for each x in `xs`.
@@ -627,7 +670,8 @@ def gradients(ys,
# mutating new ops.
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
gate_gradients, aggregation_method, stop_gradients)
gate_gradients, aggregation_method, stop_gradients,
unconnected_gradients)
def _GradientsHelper(ys,
@@ -638,6 +682,7 @@ def _GradientsHelper(ys,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None,
unconnected_gradients=UnconnectedGradients.NONE,
src_graph=None):
"""Implementation of gradients()."""
if context.executing_eagerly():
@@ -645,6 +690,11 @@ def _GradientsHelper(ys,
"is enabled. Use tf.GradientTape instead.")
if src_graph is None:
src_graph = ops.get_default_graph()
try:
unconnected_gradients = UnconnectedGradients(unconnected_gradients)
except ValueError:
raise ValueError(
"Unknown value for unconnected_gradients: %r" % unconnected_gradients)
# If src_graph is a _FuncGraph (i.e. a function body), gather it and all
# ancestor graphs. This is necessary for correctly handling captured values.
@@ -856,7 +906,7 @@ def _GradientsHelper(ys,
if loop_state:
loop_state.PostProcessing()
return [_GetGrad(grads, x) for x in xs]
return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
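A note on the coercion in the hunk above: `UnconnectedGradients` is a plain `enum.Enum`, so its string values round-trip through the constructor, which is what lets callers pass either the enum member or its string form. A minimal illustration (using the import path added in gradients.py above):

```python
from tensorflow.python.ops.gradients_impl import UnconnectedGradients

assert UnconnectedGradients("zero") is UnconnectedGradients.ZERO
assert UnconnectedGradients("none") is UnconnectedGradients.NONE
# Any other value raises ValueError, which _GradientsHelper re-raises
# as "Unknown value for unconnected_gradients: ...".
```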
def _HasAnyNotNoneGrads(grads, op):
@@ -924,12 +974,19 @@ def _SetGrad(grads, t, grad):
op_grads[t.value_index] = grad
def _GetGrad(grads, t):
def _GetGrad(grads, t, unconnected_gradients):
"""Gets gradient for tensor "t"."""
op = t.op
op_grads = grads.get(op)
if not op_grads:
if unconnected_gradients == UnconnectedGradients.ZERO:
return array_ops.zeros_like(t)
elif unconnected_gradients == UnconnectedGradients.NONE:
return None
else:
raise ValueError(
"Unknown value for unconnected_gradients: %r" % unconnected_gradients)
t_grad = op_grads[t.value_index]
assert not isinstance(
t_grad, list), ("gradients list should have been aggregated by now.")
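For reference, the ZERO branch above relies on `array_ops.zeros_like(t)`, which builds a dense zero tensor with the same shape and dtype as `t`, so the returned gradient matches each unconnected x exactly. A minimal sketch of that behavior (TF 1.x; `tf.zeros_like` is the public alias):

```python
import tensorflow as tf

x = tf.ones([2, 3], dtype=tf.float32)
z = tf.zeros_like(x)  # same shape [2, 3] and dtype as x, filled with zeros
with tf.Session() as sess:
    print(sess.run(z))  # [[0. 0. 0.]
                        #  [0. 0. 0.]]
```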


@@ -350,6 +350,40 @@ class GradientsTest(test_util.TensorFlowTestCase):
for a, b in zip(npgrad1, npgrad2):
np.testing.assert_allclose(a, b)
def testUnconnectedGradientsNoneUnconnectedGradients(self):
with ops.Graph().as_default():
x = constant(1.0, shape=[2, 2])
y = constant(3.0, shape=[3, 1])
grad = gradients.gradients(
[y], [x], unconnected_gradients="none")
self.assertIsNone(grad[0])
def testUnconnectedGradientsZerosUnconnectedGradients(self):
with ops.Graph().as_default():
x = constant(1.0, shape=[2, 2])
y = constant(3.0, shape=[3, 1])
grads = gradients.gradients(
[y], [x], unconnected_gradients="zero")
with self.cached_session() as sess:
self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], sess.run(grads)[0])
def testUnconnectedGradientsZeroConnectedGradients(self):
with ops.Graph().as_default():
x = constant(1.0)
y = x * 3.0
grad = gradients.gradients(
[y], [x], unconnected_gradients="zero")
with self.cached_session() as sess:
self.assertEqual(3.0, sess.run(grad)[0])
def testUnknownUnconnectedGradientsValueGiven(self):
with ops.Graph().as_default():
x = constant(1.0)
y = constant(1.0)
with self.assertRaisesRegexp(
ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
gradients.gradients([y], [x], unconnected_gradients="nonsense")
class FunctionGradientsTest(test_util.TensorFlowTestCase):


@@ -0,0 +1,12 @@
path: "tensorflow.UnconnectedGradients"
tf_class {
is_instance: "<enum \'UnconnectedGradients\'>"
member {
name: "NONE"
mtype: "<enum \'UnconnectedGradients\'>"
}
member {
name: "ZERO"
mtype: "<enum \'UnconnectedGradients\'>"
}
}


@@ -248,6 +248,10 @@ tf_module {
name: "TextLineReader"
mtype: "<type \'type\'>"
}
member {
name: "UnconnectedGradients"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "VERSION"
mtype: "<type \'str\'>"
@@ -1234,7 +1238,7 @@ tf_module {
}
member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
}
member_method {
name: "greater"


@@ -0,0 +1,12 @@
path: "tensorflow.UnconnectedGradients"
tf_class {
is_instance: "<enum \'UnconnectedGradients\'>"
member {
name: "NONE"
mtype: "<enum \'UnconnectedGradients\'>"
}
member {
name: "ZERO"
mtype: "<enum \'UnconnectedGradients\'>"
}
}


@@ -220,6 +220,10 @@ tf_module {
name: "TensorShape"
mtype: "<type \'type\'>"
}
member {
name: "UnconnectedGradients"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "VERSION"
mtype: "<type \'str\'>"
@@ -1134,7 +1138,7 @@ tf_module {
}
member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
}
member_method {
name: "greater"