diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index fe81254ef79..da3c56db923 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2152,6 +2152,7 @@ py_library(
         ":array_grad",
         ":array_ops",
         ":bitwise_ops",
+        ":check_ops",
         ":cond_v2_impl",
         ":control_flow_grad",
         ":control_flow_ops",
@@ -2172,8 +2173,11 @@ py_library(
         ":random_grad",
         ":resource_variable_ops",
         ":spectral_grad",
+        ":tensor_array_ops",
+        ":tensor_util",
         ":util",
         ":variable_scope",
+        "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 1dc666e78b2..794465b10e3 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -25,4 +25,5 @@ from tensorflow.python.ops.custom_gradient import custom_gradient
 from tensorflow.python.ops.gradients_impl import AggregationMethod
 from tensorflow.python.ops.gradients_impl import gradients
 from tensorflow.python.ops.gradients_impl import hessians
+from tensorflow.python.ops.gradients_impl import UnconnectedGradients
 # pylint: enable=unused-import
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 056015d6b6a..aac95037dce 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections
 import contextlib
+import enum  # pylint: disable=g-bad-import-order
 import sys
 import warnings
 
@@ -537,6 +538,26 @@ def _Consumers(t, func_graphs):
   return consumers
 
 
+@tf_export("UnconnectedGradients")
+class UnconnectedGradients(enum.Enum):
+  """Controls how gradient computation behaves when y does not depend on x.
+
+  The gradient of y with respect to x can be zero in two different ways: there
+  could be no differentiable path in the graph connecting x to y (and so we can
+  statically prove that the gradient is zero) or it could be that runtime values
+  of tensors in a particular execution lead to a gradient of zero (say, if a
+  relu unit happens to not be activated). To allow you to distinguish between
+  these two cases you can choose what value gets returned for the gradient when
+  there is no path in the graph from x to y:
+
+  * `NONE`: Indicates that [None] will be returned if there is no path from x
+    to y
+  * `ZERO`: Indicates that a zero tensor will be returned in the shape of x.
+  """
+  NONE = "none"
+  ZERO = "zero"
+
+
 @tf_export("gradients")
 def gradients(ys,
               xs,
@@ -545,7 +566,8 @@ def gradients(ys,
               colocate_gradients_with_ops=False,
               gate_gradients=False,
               aggregation_method=None,
-              stop_gradients=None):
+              stop_gradients=None,
+              unconnected_gradients=UnconnectedGradients.NONE):
   """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
 
   `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
@@ -596,6 +618,23 @@ def gradients(ys,
   All integer tensors are considered constant with respect to all `xs`, as if
   they were included in `stop_gradients`.
 
+  `unconnected_gradients` determines the value returned for each x in xs if it
+  is unconnected in the graph to ys. By default this is `None` to safeguard
+  against errors. Mathematically these gradients are zero which can be
+  requested using the `'zero'` option.
+  `tf.UnconnectedGradients` provides the following options and behaviors:
+
+  ```python
+  a = tf.ones([1, 2])
+  b = tf.ones([3, 1])
+  g1 = tf.gradients([b], [a], unconnected_gradients='none')
+  sess.run(g1)  # [None]
+
+  g2 = tf.gradients([b], [a], unconnected_gradients='zero')
+  sess.run(g2)  # [array([[0., 0.]], dtype=float32)]
+  ```
+
+
   Args:
     ys: A `Tensor` or list of tensors to be differentiated.
     xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -611,6 +650,10 @@ def gradients(ys,
       Accepted values are constants defined in the class `AggregationMethod`.
     stop_gradients: Optional. A `Tensor` or list of tensors not to
       differentiate through.
+    unconnected_gradients: Optional. Specifies the gradient value returned when
+      the given input tensors are unconnected. Accepted values are constants
+      defined in the class `tf.UnconnectedGradients` and the default value is
+      `none`.
 
   Returns:
     A list of `sum(dy/dx)` for each x in `xs`.
@@ -627,7 +670,8 @@ def gradients(ys,
   # mutating new ops.
   with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
     return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
-                            gate_gradients, aggregation_method, stop_gradients)
+                            gate_gradients, aggregation_method, stop_gradients,
+                            unconnected_gradients)
 
 
 def _GradientsHelper(ys,
@@ -638,6 +682,7 @@ def _GradientsHelper(ys,
                      gate_gradients=False,
                      aggregation_method=None,
                      stop_gradients=None,
+                     unconnected_gradients=UnconnectedGradients.NONE,
                      src_graph=None):
   """Implementation of gradients()."""
   if context.executing_eagerly():
@@ -645,6 +690,11 @@ def _GradientsHelper(ys,
                        "is enabled. Use tf.GradientTape instead.")
   if src_graph is None:
     src_graph = ops.get_default_graph()
+  try:
+    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
+  except ValueError:
+    raise ValueError(
+        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
 
   # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
   # ancestor graphs. This is necessary for correctly handling captured values.
@@ -856,7 +906,7 @@ def _GradientsHelper(ys,
 
   if loop_state:
     loop_state.PostProcessing()
-  return [_GetGrad(grads, x) for x in xs]
+  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
 
 
 def _HasAnyNotNoneGrads(grads, op):
@@ -924,12 +974,19 @@
     op_grads[t.value_index] = grad
 
 
-def _GetGrad(grads, t):
+def _GetGrad(grads, t, unconnected_gradients):
   """Gets gradient for tensor "t"."""
   op = t.op
   op_grads = grads.get(op)
   if not op_grads:
-    return None
+    if unconnected_gradients == UnconnectedGradients.ZERO:
+      return array_ops.zeros_like(t)
+    elif unconnected_gradients == UnconnectedGradients.NONE:
+      return None
+    else:
+      raise ValueError(
+          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
+
   t_grad = op_grads[t.value_index]
   assert not isinstance(
       t_grad, list), ("gradients list should have been aggregated by now.")
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 3c9b7a01c70..c93e2493ee7 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -350,6 +350,40 @@ class GradientsTest(test_util.TensorFlowTestCase):
       for a, b in zip(npgrad1, npgrad2):
         np.testing.assert_allclose(a, b)
 
+  def testUnconnectedGradientsNoneUnconnectedGradients(self):
+    with ops.Graph().as_default():
+      x = constant(1.0, shape=[2, 2])
+      y = constant(3.0, shape=[3, 1])
+      grad = gradients.gradients(
+          [y], [x], unconnected_gradients="none")
+      self.assertIsNone(grad[0])
+
+  def testUnconnectedGradientsZerosUnconnectedGradients(self):
+    with ops.Graph().as_default():
+      x = constant(1.0, shape=[2, 2])
+      y = constant(3.0, shape=[3, 1])
+      grads = gradients.gradients(
+          [y], [x], unconnected_gradients="zero")
+      with self.cached_session() as sess:
+        self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], sess.run(grads)[0])
+
+  def testUnconnectedGradientsZeroConnectedGradients(self):
+    with ops.Graph().as_default():
+      x = constant(1.0)
+      y = x * 3.0
+      grad = gradients.gradients(
+          [y], [x], unconnected_gradients="zero")
+      with self.cached_session() as sess:
+        self.assertEquals(3.0, sess.run(grad)[0])
+
+  def testUnknownUnconnectedGradientsValueGiven(self):
+    with ops.Graph().as_default():
+      x = constant(1.0)
+      y = constant(1.0)
+      with self.assertRaisesRegexp(
+          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
+        gradients.gradients([y], [x], unconnected_gradients="nonsense")
+
 
 class FunctionGradientsTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-unconnected-gradients.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-unconnected-gradients.pbtxt
new file mode 100644
index 00000000000..c5eb9594308
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-unconnected-gradients.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.UnconnectedGradients"
+tf_class {
+  is_instance: ""
+  member {
+    name: "NONE"
+    mtype: ""
+  }
+  member {
+    name: "ZERO"
+    mtype: ""
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index a268529c1fa..c1cc7322f00 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -248,6 +248,10 @@ tf_module {
   member {
     name: "TextLineReader"
     mtype: ""
   }
+  member {
+    name: "UnconnectedGradients"
+    mtype: ""
+  }
   member {
     name: "VERSION"
     mtype: ""
   }
@@ -1234,7 +1238,7 @@ tf_module {
   }
   member_method {
     name: "gradients"
-    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
+    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
   }
   member_method {
     name: "greater"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-unconnected-gradients.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-unconnected-gradients.pbtxt
new file mode 100644
index 00000000000..c5eb9594308
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-unconnected-gradients.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.UnconnectedGradients"
+tf_class {
+  is_instance: ""
+  member {
+    name: "NONE"
+    mtype: ""
+  }
+  member {
+    name: "ZERO"
+    mtype: ""
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 5b3ea75bce6..571abc3b194 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -220,6 +220,10 @@ tf_module {
   member {
     name: "TensorShape"
     mtype: ""
   }
+  member {
+    name: "UnconnectedGradients"
+    mtype: ""
+  }
   member {
     name: "VERSION"
     mtype: ""
   }
@@ -1134,7 +1138,7 @@ tf_module {
   }
   member_method {
     name: "gradients"
-    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
+    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
   }
   member_method {
     name: "greater"
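For anyone trying the change out, here is a minimal graph-mode usage sketch. It is not part of the patch; it assumes a TF 1.x build that already includes the `unconnected_gradients` argument, and it simply mirrors the docstring example and the new tests above:

```python
import tensorflow as tf

# b does not depend on a, so there is no path from a to b in the graph.
a = tf.ones([1, 2])
b = tf.ones([3, 1])

# Default behavior ('none'): unconnected inputs yield None in the result list.
g_none = tf.gradients([b], [a], unconnected_gradients='none')
print(g_none)  # [None]

# New behavior: 'zero' (or tf.UnconnectedGradients.ZERO) returns a zero
# tensor shaped like the unconnected input instead of None.
g_zero = tf.gradients([b], [a],
                      unconnected_gradients=tf.UnconnectedGradients.ZERO)

with tf.Session() as sess:
  print(sess.run(g_zero))  # [array([[0., 0.]], dtype=float32)]
```

The string values and the enum members are interchangeable because `_GradientsHelper` coerces the argument with `UnconnectedGradients(unconnected_gradients)`, raising `ValueError` for anything else.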