custom_gradient functions should be able to return their inputs
PiperOrigin-RevId: 173723462
This commit is contained in:
parent
78bac7290c
commit
02f55400f8
@ -569,5 +569,17 @@ class BackpropTest(test.TestCase):
|
|||||||
var.assign_sub(lr*grad)
|
var.assign_sub(lr*grad)
|
||||||
self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
|
self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
|
||||||
|
|
||||||
|
def testCustomGradientIdentity(self):
|
||||||
|
|
||||||
|
@custom_gradient.custom_gradient
|
||||||
|
def my_identity(x):
|
||||||
|
|
||||||
|
def grad(dresult):
|
||||||
|
return [2 * dresult]
|
||||||
|
|
||||||
|
return x, grad
|
||||||
|
|
||||||
|
self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -22,6 +22,7 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.framework import ops as tf_ops
|
from tensorflow.python.framework import ops as tf_ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gen_array_ops
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
|
|
||||||
@ -72,17 +73,19 @@ def custom_gradient(f):
|
|||||||
|
|
||||||
with tape.stop_recording():
|
with tape.stop_recording():
|
||||||
result, grad_fn = f(*args, **kwargs)
|
result, grad_fn = f(*args, **kwargs)
|
||||||
|
flat_result = nest.flatten(result)
|
||||||
|
# TODO(apassos) consider removing the identity below.
|
||||||
|
flat_result = [gen_array_ops.identity(x) for x in flat_result]
|
||||||
|
|
||||||
def actual_grad_fn(*outputs):
|
def actual_grad_fn(*outputs):
|
||||||
return nest.flatten(grad_fn(*outputs))
|
return nest.flatten(grad_fn(*outputs))
|
||||||
|
|
||||||
flat_result = nest.flatten(result)
|
|
||||||
tape.record_operation(
|
tape.record_operation(
|
||||||
f.__name__,
|
f.__name__,
|
||||||
flat_result,
|
flat_result,
|
||||||
input_tensors,
|
input_tensors,
|
||||||
actual_grad_fn)
|
actual_grad_fn)
|
||||||
flat_result = list(flat_result)
|
flat_result = list(flat_result)
|
||||||
return result
|
return nest.pack_sequence_as(result, flat_result)
|
||||||
|
|
||||||
return tf_decorator.make_decorator(f, decorated)
|
return tf_decorator.make_decorator(f, decorated)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user