custom_gradient functions should be able to return their inputs
PiperOrigin-RevId: 173723462
This commit is contained in:
parent
78bac7290c
commit
02f55400f8
@ -569,5 +569,17 @@ class BackpropTest(test.TestCase):
|
||||
var.assign_sub(lr*grad)
|
||||
self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
|
||||
|
||||
def testCustomGradientIdentity(self):
|
||||
|
||||
@custom_gradient.custom_gradient
|
||||
def my_identity(x):
|
||||
|
||||
def grad(dresult):
|
||||
return [2 * dresult]
|
||||
|
||||
return x, grad
|
||||
|
||||
self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -22,6 +22,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
@ -72,17 +73,19 @@ def custom_gradient(f):
|
||||
|
||||
with tape.stop_recording():
|
||||
result, grad_fn = f(*args, **kwargs)
|
||||
flat_result = nest.flatten(result)
|
||||
# TODO(apassos) consider removing the identity below.
|
||||
flat_result = [gen_array_ops.identity(x) for x in flat_result]
|
||||
|
||||
def actual_grad_fn(*outputs):
|
||||
return nest.flatten(grad_fn(*outputs))
|
||||
|
||||
flat_result = nest.flatten(result)
|
||||
tape.record_operation(
|
||||
f.__name__,
|
||||
flat_result,
|
||||
input_tensors,
|
||||
actual_grad_fn)
|
||||
flat_result = list(flat_result)
|
||||
return result
|
||||
return nest.pack_sequence_as(result, flat_result)
|
||||
|
||||
return tf_decorator.make_decorator(f, decorated)
|
||||
|
Loading…
x
Reference in New Issue
Block a user