Add option in tf.gradients() to return zero tensors for unconnected gradients.

tf.gradients currently returns [None] when the gradient of an unconnected variable is requested. This backwards-compatible change adds an option to instead return zero tensors that match the shape of the input tensors.

PiperOrigin-RevId: 215725488
parent 2c9369c8d8
commit 82ea80b979
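A minimal sketch of the behavior this change enables, assuming TF 1.x graph mode (the tensors and the session here are illustrative and mirror the docstring example added below):

```python
import tensorflow as tf

a = tf.ones([1, 2])
b = tf.ones([3, 1])  # b does not depend on a

with tf.Session() as sess:
  # Default behavior is unchanged: an unconnected gradient comes back as None.
  print(tf.gradients([b], [a]))  # [None]
  # With the new option, a zero tensor shaped like the input is returned instead.
  print(sess.run(tf.gradients([b], [a], unconnected_gradients="zero")))
  # [array([[0., 0.]], dtype=float32)]
```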
@@ -2152,6 +2152,7 @@ py_library(
        ":array_grad",
        ":array_ops",
        ":bitwise_ops",
        ":check_ops",
        ":cond_v2_impl",
        ":control_flow_grad",
        ":control_flow_ops",
@@ -2172,8 +2173,11 @@ py_library(
        ":random_grad",
        ":resource_variable_ops",
        ":spectral_grad",
        ":tensor_array_ops",
        ":tensor_util",
        ":util",
        ":variable_scope",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:tape",
@@ -25,4 +25,5 @@ from tensorflow.python.ops.custom_gradient import custom_gradient
from tensorflow.python.ops.gradients_impl import AggregationMethod
from tensorflow.python.ops.gradients_impl import gradients
from tensorflow.python.ops.gradients_impl import hessians
from tensorflow.python.ops.gradients_impl import UnconnectedGradients
# pylint: enable=unused-import
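Because the class is decorated with `@tf_export("UnconnectedGradients")` and re-exported here, it becomes reachable as `tf.UnconnectedGradients`. A small illustrative check, assuming a build that includes this change; as the `_GradientsHelper` hunk below shows, the argument is coerced with `UnconnectedGradients(...)`, so the enum member and its string value are interchangeable:

```python
import tensorflow as tf

# The enum members map onto the lowercase strings accepted by tf.gradients().
print(tf.UnconnectedGradients.NONE.value)  # "none"
print(tf.UnconnectedGradients.ZERO.value)  # "zero"

# Enum coercion makes the string form and the member form equivalent.
assert tf.UnconnectedGradients("zero") is tf.UnconnectedGradients.ZERO
```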
@@ -20,6 +20,7 @@ from __future__ import print_function

import collections
import contextlib
import enum  # pylint: disable=g-bad-import-order
import sys
import warnings

@@ -537,6 +538,26 @@ def _Consumers(t, func_graphs):
  return consumers


@tf_export("UnconnectedGradients")
class UnconnectedGradients(enum.Enum):
  """Controls how gradient computation behaves when y does not depend on x.

  The gradient of y with respect to x can be zero in two different ways: there
  could be no differentiable path in the graph connecting x to y (and so we can
  statically prove that the gradient is zero) or it could be that runtime values
  of tensors in a particular execution lead to a gradient of zero (say, if a
  relu unit happens to not be activated). To allow you to distinguish between
  these two cases you can choose what value gets returned for the gradient when
  there is no path in the graph from x to y:

  * `NONE`: Indicates that [None] will be returned if there is no path from x
    to y
  * `ZERO`: Indicates that a zero tensor will be returned in the shape of x.
  """
  NONE = "none"
  ZERO = "zero"


@tf_export("gradients")
def gradients(ys,
              xs,
@@ -545,7 +566,8 @@ def gradients(ys,
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None,
              stop_gradients=None):
              stop_gradients=None,
              unconnected_gradients=UnconnectedGradients.NONE):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
@@ -596,6 +618,23 @@ def gradients(ys,
  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.

  `unconnected_gradients` determines the value returned for each x in xs if it
  is unconnected in the graph to ys. By default this is None to safeguard
  against errors. Mathematically these gradients are zero, which can be
  requested using the `'zero'` option. `tf.UnconnectedGradients` provides the
  following options and behaviors:

  ```python
  a = tf.ones([1, 2])
  b = tf.ones([3, 1])
  g1 = tf.gradients([b], [a], unconnected_gradients='none')
  sess.run(g1)  # [None]

  g2 = tf.gradients([b], [a], unconnected_gradients='zero')
  sess.run(g2)  # [array([[0., 0.]], dtype=float32)]
  ```


  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -611,6 +650,10 @@ def gradients(ys,
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `sum(dy/dx)` for each x in `xs`.
@@ -627,7 +670,8 @@ def gradients(ys,
  # mutating new ops.
  with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
    return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
                            gate_gradients, aggregation_method, stop_gradients)
                            gate_gradients, aggregation_method, stop_gradients,
                            unconnected_gradients)


def _GradientsHelper(ys,
@@ -638,6 +682,7 @@ def _GradientsHelper(ys,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
  """Implementation of gradients()."""
  if context.executing_eagerly():
@@ -645,6 +690,11 @@ def _GradientsHelper(ys,
        "is enabled. Use tf.GradientTape instead.")
  if src_graph is None:
    src_graph = ops.get_default_graph()
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
  # ancestor graphs. This is necessary for correctly handling captured values.
@@ -856,7 +906,7 @@ def _GradientsHelper(ys,

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x) for x in xs]
  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]


def _HasAnyNotNoneGrads(grads, op):
@@ -924,12 +974,19 @@ def _SetGrad(grads, t, grad):
    op_grads[t.value_index] = grad


def _GetGrad(grads, t):
def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    if unconnected_gradients == UnconnectedGradients.ZERO:
      return array_ops.zeros_like(t)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  t_grad = op_grads[t.value_index]
  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
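To make the docstring's distinction concrete — a gradient that is zero at runtime through a connected path versus a statically unconnected input — here is a hedged sketch (TF 1.x graph mode assumed; the ops are illustrative and not part of this change):

```python
import tensorflow as tf

x = tf.constant(-1.0)
y = tf.nn.relu(x)     # connected to x, but the relu is not activated at x = -1
z = tf.constant(2.0)  # no path from x to z

dy_dx = tf.gradients([y], [x])                                      # connected
dz_dx_none = tf.gradients([z], [x])                                 # unconnected -> [None]
dz_dx_zero = tf.gradients([z], [x], unconnected_gradients="zero")   # unconnected -> zeros

with tf.Session() as sess:
  print(sess.run(dy_dx))       # [0.0]  -- a real gradient that happens to be zero
  print(dz_dx_none)            # [None] -- statically unconnected
  print(sess.run(dz_dx_zero))  # [0.0]  -- zero tensor shaped like x
```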
@@ -350,6 +350,40 @@ class GradientsTest(test_util.TensorFlowTestCase):
      for a, b in zip(npgrad1, npgrad2):
        np.testing.assert_allclose(a, b)

  def testUnconnectedGradientsNoneUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="none")
      self.assertIsNone(grad[0])

  def testUnconnectedGradientsZerosUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grads = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], sess.run(grads)[0])

  def testUnconnectedGradientsZeroConnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = x * 3.0
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertEquals(3.0, sess.run(grad)[0])

  def testUnknownUnconnectedGradientsValueGiven(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = constant(1.0)
      with self.assertRaisesRegexp(
          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
        gradients.gradients([y], [x], unconnected_gradients="nonsense")


class FunctionGradientsTest(test_util.TensorFlowTestCase):
@@ -0,0 +1,12 @@
path: "tensorflow.UnconnectedGradients"
tf_class {
  is_instance: "<enum \'UnconnectedGradients\'>"
  member {
    name: "NONE"
    mtype: "<enum \'UnconnectedGradients\'>"
  }
  member {
    name: "ZERO"
    mtype: "<enum \'UnconnectedGradients\'>"
  }
}
@@ -248,6 +248,10 @@ tf_module {
    name: "TextLineReader"
    mtype: "<type \'type\'>"
  }
  member {
    name: "UnconnectedGradients"
    mtype: "<class \'enum.EnumMeta\'>"
  }
  member {
    name: "VERSION"
    mtype: "<type \'str\'>"
@@ -1234,7 +1238,7 @@ tf_module {
  }
  member_method {
    name: "gradients"
    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
  }
  member_method {
    name: "greater"
@@ -0,0 +1,12 @@
path: "tensorflow.UnconnectedGradients"
tf_class {
  is_instance: "<enum \'UnconnectedGradients\'>"
  member {
    name: "NONE"
    mtype: "<enum \'UnconnectedGradients\'>"
  }
  member {
    name: "ZERO"
    mtype: "<enum \'UnconnectedGradients\'>"
  }
}
@@ -220,6 +220,10 @@ tf_module {
    name: "TensorShape"
    mtype: "<type \'type\'>"
  }
  member {
    name: "UnconnectedGradients"
    mtype: "<class \'enum.EnumMeta\'>"
  }
  member {
    name: "VERSION"
    mtype: "<type \'str\'>"
@@ -1134,7 +1138,7 @@ tf_module {
  }
  member_method {
    name: "gradients"
    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
  }
  member_method {
    name: "greater"