Modernize tf.gradients docstring.
PiperOrigin-RevId: 308424217
Change-Id: I53a8a78bdb8a99edf1e456185b6df8953887e43e
parent dac6d6ae7c
commit ea8087efe5
@@ -184,6 +184,10 @@ def gradients_v2(ys,  # pylint: disable=invalid-name
                  unconnected_gradients=UnconnectedGradients.NONE):
   """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
 
+  `tf.gradients` is only valid in a graph context. In particular,
+  it is valid in the context of a `tf.function` wrapper, where code
+  is executing as a graph.
+
   `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
   is a list of `Tensor`, holding the gradients received by the
   `ys`. The list must be the same length as `ys`.
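The added paragraph above is the substantive change in this hunk: `tf.gradients` is now documented as graph-only. A minimal sketch of what that contract means in practice, assuming TF 2.x with eager execution enabled (the function name `compute_grads` is illustrative, not from the commit):

```python
import tensorflow as tf

@tf.function
def compute_grads():
  x = tf.constant(3.0)
  y = x * x
  # Inside tf.function the code executes as a graph, so tf.gradients is valid.
  return tf.gradients(y, [x])

print(compute_grads())  # [<tf.Tensor: shape=(), dtype=float32, numpy=6.0>]

# Calling tf.gradients directly in eager mode instead raises:
# RuntimeError: tf.gradients is not supported when eager execution is enabled.
# Use tf.GradientTape instead.
```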
@@ -206,22 +210,28 @@ def gradients_v2(ys,  # pylint: disable=invalid-name
   other things, this allows computation of partial derivatives as opposed to
   total derivatives. For example:
 
-  ```python
-  a = tf.constant(0.)
-  b = 2 * a
-  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
-  ```
+  >>> @tf.function
+  ... def example():
+  ...   a = tf.constant(0.)
+  ...   b = 2 * a
+  ...   return tf.gradients(a + b, [a, b], stop_gradients=[a, b])
+  >>> example()
+  [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
+  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
 
   Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
   total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
   influence of `a` on `b` and evaluate to `[3.0, 1.0]`. Note that the above is
   equivalent to:
 
-  ```python
-  a = tf.stop_gradient(tf.constant(0.))
-  b = tf.stop_gradient(2 * a)
-  g = tf.gradients(a + b, [a, b])
-  ```
+  >>> @tf.function
+  ... def example():
+  ...   a = tf.stop_gradient(tf.constant(0.))
+  ...   b = tf.stop_gradient(2 * a)
+  ...   return tf.gradients(a + b, [a, b])
+  >>> example()
+  [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
+  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
 
   `stop_gradients` provides a way of stopping gradient after the graph has
   already been constructed, as compared to `tf.stop_gradient` which is used
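For reference, the two behaviors this doctest contrasts (partial derivatives via `stop_gradients` versus total derivatives) can be checked side by side. A sketch assuming TF 2.x; the name `total_vs_partial` is mine, not the commit's:

```python
import tensorflow as tf

@tf.function
def total_vs_partial():
  a = tf.constant(0.)
  b = 2 * a
  # Total derivatives follow the influence of a on b: d(a+b)/da = 3.
  total = tf.gradients(a + b, [a, b])
  # stop_gradients treats a and b as constants: both partials are 1.
  partial = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  return total, partial

total, partial = total_vs_partial()
print([t.numpy() for t in total])    # [3.0, 1.0]
print([p.numpy() for p in partial])  # [1.0, 1.0]
```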
@@ -238,29 +248,35 @@ def gradients_v2(ys,  # pylint: disable=invalid-name
   using the `'zero'` option. `tf.UnconnectedGradients` provides the
   following options and behaviors:
 
-  ```python
-  a = tf.ones([1, 2])
-  b = tf.ones([3, 1])
-  g1 = tf.gradients([b], [a], unconnected_gradients='none')
-  sess.run(g1)  # [None]
-
-  g2 = tf.gradients([b], [a], unconnected_gradients='zero')
-  sess.run(g2)  # [array([[0., 0.]], dtype=float32)]
-  ```
+  >>> @tf.function
+  ... def example(use_zero):
+  ...   a = tf.ones([1, 2])
+  ...   b = tf.ones([3, 1])
+  ...   if use_zero:
+  ...     return tf.gradients([b], [a], unconnected_gradients='zero')
+  ...   else:
+  ...     return tf.gradients([b], [a], unconnected_gradients='none')
+  >>> example(False)
+  [None]
+  >>> example(True)
+  [<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0., 0.]], ...)>]
 
   Let us take one practical example which comes up during the backpropagation
   phase. This function is used to evaluate the derivatives of the cost function
   with respect to weights `Ws` and biases `bs`. The sample implementation below
   illustrates what it is actually used for:
 
-  ```python
-  Ws = tf.constant(0.)
-  bs = 2 * Ws
-  cost = Ws + bs  # This is just an example. So, please ignore the formulas.
-  g = tf.gradients(cost, [Ws, bs])
-  dCost_dW, dCost_db = g
-  ```
+  >>> @tf.function
+  ... def example():
+  ...   Ws = tf.constant(0.)
+  ...   bs = 2 * Ws
+  ...   cost = Ws + bs  # This is just an example. Please ignore the formulas.
+  ...   g = tf.gradients(cost, [Ws, bs])
+  ...   dCost_dW, dCost_db = g
+  ...   return dCost_dW, dCost_db
+  >>> example()
+  (<tf.Tensor: shape=(), dtype=float32, numpy=3.0>,
+  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
 
   Args:
     ys: A `Tensor` or list of tensors to be differentiated.
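Since the docstring now steers eager code toward other APIs, the cost example above has an eager-mode counterpart with `tf.GradientTape`. This is not part of the commit; a sketch under the assumption of TF 2.x defaults:

```python
import tensorflow as tf

Ws = tf.constant(0.)
with tf.GradientTape() as tape:
  tape.watch(Ws)   # constants are not watched automatically
  bs = 2 * Ws      # bs is computed on the tape, so it can be a gradient source
  cost = Ws + bs

# Same results as the doctest: dCost/dWs = 3.0, dCost/dbs = 1.0.
dCost_dW, dCost_db = tape.gradient(cost, [Ws, bs])
print(dCost_dW.numpy(), dCost_db.numpy())  # 3.0 1.0
```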