Fix issue with gradients of functions which return multiple values.
PiperOrigin-RevId: 173730922
commit e1d7615ebc
parent 80374a7b47
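
Illustrative sketch of the behavior this change enables: a function returning multiple values can be differentiated directly, with the outputs added into a single scalar by default. It assumes eager execution is enabled and that the modified module is importable as tensorflow.python.eager.backprop:

```python
from tensorflow.python.eager import backprop  # assumed module path

def f(x):
  return x, 2 * x  # multiple return values

# The outputs are summed before differentiating, so the gradient is
# d(x + 2*x)/dx == 3.0 (mirrors the testReturnSameThing case added below).
print(backprop.gradients_function(f)(1.0)[0])  # expected: 3.0
```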
tensorflow/python/eager/backprop.py

@@ -35,6 +35,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util import nest
@@ -322,7 +323,10 @@ def implicit_val_and_grad(f):
   ```

   Args:
-    f: The function to be differentiated.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar.

   Returns:
     A function which, when called, returns a tuple pair.
@@ -384,7 +388,10 @@ def implicit_grad(f):
   ```

   Args:
-    f: The function to be differentiated.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar.

   Returns:
     A function which, when called, returns a list of (gradient, variable) pairs.
@@ -467,7 +474,12 @@ def gradients_function(f, params=None):
   ```

   Args:
-    f: function to be differentiated.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar. If desired, the tensors can be elementwise multiplied by the
+      tensors passed as the `dy` keyword argument to the returned gradient
+      function.
     params: list of parameter names of f or list of integers indexing the
       parameters with respect to which we'll differentiate. Passing None
       differentiates with respect to all parameters.
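
Illustrative sketch of the `dy` weighting described in the docstring above; the values are made up, and it assumes the returned gradient function accepts the `dy` keyword argument as documented:

```python
from tensorflow.python.eager import backprop  # assumed module path

def f(x):
  return x, 2 * x

# Each output is elementwise multiplied by the matching `dy` entry before
# the sum is differentiated: d(3.0*x + 5.0*(2*x))/dx == 3.0 + 10.0 == 13.0
grad = backprop.gradients_function(f)(1.0, dy=[3.0, 5.0])[0]
```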
@@ -559,7 +571,12 @@ def val_and_grad_function(f, params=None):
   ```

   Args:
-    f: function to be differentiated.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar. If desired, the tensors can be elementwise multiplied by the
+      tensors passed as the `dy` keyword argument to the returned gradient
+      function.
     params: list of parameter names of f or list of integers indexing the
       parameters with respect to which we'll differentiate. Passing `None`
       differentiates with respect to all parameters.
@@ -632,12 +649,17 @@ def make_vjp(f, params=None):
         sources.append(args[i])
         tape.watch(args[i])
       result = f(*args)
+      flat_result = nest.flatten(result)
+      flat_result = [gen_array_ops.identity(x) for x in flat_result]
+      result = nest.pack_sequence_as(result, flat_result)
     finally:
       t = tape.pop_tape()
     def vjp(dy=None):
+      if dy is not None:
+        dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
       return imperative_grad.imperative_grad(
           _default_vspace, t, nest.flatten(result), sources,
-          output_gradients=nest.flatten(dy) if dy is not None else None)
+          output_gradients=dy)
     return result, vjp

   return decorated
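
One reading of the make_vjp change: wrapping each flattened output in an identity op gives every returned value its own tensor on the tape, so a function may return one of its inputs (as in testReturnSameThing) and still be differentiated, while converting `dy` up front allows plain Python numbers as output gradients. A hedged usage sketch, again assuming the tensorflow.python.eager.backprop module path:

```python
from tensorflow.python.eager import backprop  # assumed module path

def f(x):
  return x, 2 * x

# make_vjp returns the structured result plus a vector-Jacobian-product
# function; `dy` weights each output before back-propagation.
(y1, y2), vjp = backprop.make_vjp(f)(1.0)
dx = vjp(dy=[1.0, 1.0])[0]  # d(1.0*y1 + 1.0*y2)/dx == 3.0
```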
@@ -697,7 +719,7 @@ _default_vspace = imperative_grad.VSpace(
     aggregate_fn=_aggregate_grads,
     tensor_id=ops.tensor_id,
     zeros=array_ops.zeros,
-    ones_like=array_ops.ones_like)
+    ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x)))


 class GradientTape(object):
tensorflow/python/eager/backprop_test.py

@@ -401,6 +401,13 @@ class BackpropTest(test.TestCase):
         backprop.gradients_function(part)(constant_op.constant(1.0))[0],
         2.0)

+  def testReturnSameThing(self):
+
+    def f(x):
+      return x, 2 * x
+
+    self.assertAllEqual(backprop.gradients_function(f)(1.0)[0], 3.0)
+
   def testExceptionSafety(self):

     def f(unused_x):