diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 6ede02dbcdc..be733405a39 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -35,6 +35,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util import nest
@@ -322,7 +323,10 @@ def implicit_val_and_grad(f):
   ```
 
   Args:
-    f: The function to be differentiated.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar.
 
   Returns:
     A function which, when called, returns a tuple pair.
@@ -384,7 +388,10 @@ def implicit_grad(f):
   ```
 
   Args:
-    f: The function to be differentiated.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar.
 
   Returns:
     A function which, when called, returns a list of (gradient, variable) pairs.
@@ -467,7 +474,12 @@ def gradients_function(f, params=None):
   ```
 
   Args:
-    f: function to be differentiated.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar. If desired, the tensors can be elementwise multiplied by the
+      tensors passed as the `dy` keyword argument to the returned gradient
+      function.
     params: list of parameter names of f or list of integers indexing the
       parameters with respect to which we'll differentiate. Passing None
       differentiates with respect to all parameters.
@@ -559,7 +571,12 @@ def val_and_grad_function(f, params=None):
   ```
 
   Args:
-    f: function to be differentiated.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar. If desired, the tensors can be elementwise multiplied by the
+      tensors passed as the `dy` keyword argument to the returned gradient
+      function.
     params: list of parameter names of f or list of integers indexing the
       parameters with respect to which we'll differentiate. Passing `None`
       differentiates with respect to all parameters.
@@ -632,12 +649,17 @@ def make_vjp(f, params=None):
         sources.append(args[i])
         tape.watch(args[i])
       result = f(*args)
+      flat_result = nest.flatten(result)
+      flat_result = [gen_array_ops.identity(x) for x in flat_result]
+      result = nest.pack_sequence_as(result, flat_result)
     finally:
       t = tape.pop_tape()
     def vjp(dy=None):
+      if dy is not None:
+        dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
       return imperative_grad.imperative_grad(
           _default_vspace, t, nest.flatten(result), sources,
-          output_gradients=nest.flatten(dy) if dy is not None else None)
+          output_gradients=dy)
     return result, vjp
 
   return decorated
@@ -697,7 +719,7 @@ _default_vspace = imperative_grad.VSpace(
     aggregate_fn=_aggregate_grads,
     tensor_id=ops.tensor_id,
     zeros=array_ops.zeros,
-    ones_like=array_ops.ones_like)
+    ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x)))
 
 
 class GradientTape(object):
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 20532c8ee8e..cf736fcb13e 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -401,6 +401,13 @@ class BackpropTest(test.TestCase):
         backprop.gradients_function(part)(constant_op.constant(1.0))[0],
         2.0)
 
+  def testReturnSameThing(self):
+
+    def f(x):
+      return x, 2 * x
+
+    self.assertAllEqual(backprop.gradients_function(f)(1.0)[0], 3.0)
+
   def testExceptionSafety(self):
 
     def f(unused_x):
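
A minimal usage sketch of the behavior documented and tested above (illustrative only; it assumes eager execution is enabled, as in `backprop_test.py`, and the `dy` weights below are made up for the example):

```
from tensorflow.python.eager import backprop

def f(x):
  # Two outputs; by default their values are summed into the scalar that is
  # differentiated: d/dx (x + 2*x) = 3.
  return x, 2 * x

print(backprop.gradients_function(f)(1.0)[0])  # gradient value 3.0

# Per the updated docstrings, the returned gradient function also accepts a
# `dy` keyword argument that weights each output elementwise before the sum:
# d/dx (1*x + 10*(2*x)) = 21.
print(backprop.gradients_function(f)(1.0, dy=[1.0, 10.0])[0])  # gradient value 21.0
```

The `gen_array_ops.identity` wrapping of each flattened result appears intended to give every returned value its own tensor on the tape even when `f` returns an input unchanged, which is the case `testReturnSameThing` exercises.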