custom_gradient functions should be able to return their inputs

PiperOrigin-RevId: 173723462
Author: Alexandre Passos, 2017-10-27 15:08:01 -07:00 (committed by TensorFlower Gardener)
parent 78bac7290c
commit 02f55400f8
2 changed files with 17 additions and 2 deletions
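
Why the change is needed: custom_gradient records the function's outputs on the gradient tape via tape.record_operation. If f returns one of its inputs unchanged, the recorded output is the very same tensor as the input, so backprop cannot tell where the custom operation begins and the user's grad_fn is bypassed (or the backward walk cycles on itself). Wrapping every flattened result in gen_array_ops.identity gives the tape a fresh output tensor to attach the custom gradient to. The toy tape below is a sketch of the hazard only, not TensorFlow's actual implementation; Tensor, record_operation, grad, and ops_by_output here are stand-ins invented for illustration.

    ops_by_output = {}  # id(output tensor) -> (input tensors, grad_fn)

    class Tensor(object):

      def __init__(self, value):
        self.value = value

    def record_operation(outputs, inputs, grad_fn):
      for out in outputs:
        ops_by_output[id(out)] = (inputs, grad_fn)

    def grad(t, seed=1.0):
      # Walk backwards from t; a tensor no recorded op produced is a leaf.
      if id(t) not in ops_by_output:
        return seed
      inputs, grad_fn = ops_by_output[id(t)]
      dinput, = grad_fn(seed)
      return grad(inputs[0], dinput)

    x = Tensor(1.0)

    # Returning the input itself: the op's output *is* its input, so the
    # backward walk from x immediately revisits x and never reaches a leaf.
    record_operation([x], [x], lambda d: [2 * d])
    # grad(x) would recurse forever here.

    # With a fresh tensor as the output (what gen_array_ops.identity
    # provides), grad_fn is applied once and the walk ends at the leaf x.
    ops_by_output.clear()
    y = Tensor(x.value)
    record_operation([y], [x], lambda d: [2 * d])
    assert grad(y) == 2.0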

@@ -569,5 +569,17 @@ class BackpropTest(test.TestCase):
       var.assign_sub(lr*grad)
     self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
 
+  def testCustomGradientIdentity(self):
+
+    @custom_gradient.custom_gradient
+    def my_identity(x):
+
+      def grad(dresult):
+        return [2 * dresult]
+
+      return x, grad
+
+    self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
+
 if __name__ == '__main__':
   test.main()
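
The test exercises exactly the case named in the commit title: my_identity returns its argument x untouched, yet the gradient reported for the input is 2.0, confirming that the tape invokes the custom grad_fn rather than falling back to a pass-through gradient of 1.0.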

@@ -22,6 +22,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 
@@ -72,17 +73,19 @@ def custom_gradient(f):
 
     with tape.stop_recording():
       result, grad_fn = f(*args, **kwargs)
+    flat_result = nest.flatten(result)
+    # TODO(apassos) consider removing the identity below.
+    flat_result = [gen_array_ops.identity(x) for x in flat_result]
 
     def actual_grad_fn(*outputs):
      return nest.flatten(grad_fn(*outputs))
 
-    flat_result = nest.flatten(result)
     tape.record_operation(
         f.__name__,
         flat_result,
         input_tensors,
         actual_grad_fn)
     flat_result = list(flat_result)
-    return result
+    return nest.pack_sequence_as(result, flat_result)
 
   return tf_decorator.make_decorator(f, decorated)
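
Two details of the new code are worth noting. First, the identity copies, not the original tensors, are what tape.record_operation registers, so the final return now uses nest.pack_sequence_as(result, flat_result) to hand those recorded tensors back to the caller in the original nested structure; returning result would hand back tensors the tape never saw. Second, this makes forward functions that pass an input straight through while customizing the backward pass safe to write. A hypothetical sketch of that pattern, mirroring the test above (scale_grad and the 0.5 factor are invented for illustration; it assumes eager execution is enabled, as in the test):

    from tensorflow.python.eager import backprop
    from tensorflow.python.eager import custom_gradient

    @custom_gradient.custom_gradient
    def scale_grad(x):
      """Forward pass: return x unchanged. Backward pass: halve the gradient."""

      def grad(dresult):
        return [0.5 * dresult]

      return x, grad

    # gradients_function differentiates with respect to the first argument.
    # Before this commit, returning x unchanged bypassed grad(); with the
    # identity wrapping above, this prints a gradient of 0.5.
    print(backprop.gradients_function(scale_grad)(3.0)[0])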