Add a "primals" argument to tf.custom_gradient
Makes higher-order custom gradients easier to define. They were possible before, but only by awkwardly wrapping a nested custom_gradient decorator: the outer function takes output gradients, the custom_gradient-decorated function takes the primal inputs, and the inner function uses the output gradients while typically ignoring the primal inputs threaded through custom_gradient, since they are easier to capture from the enclosing scope (see the usage sketch below).

PiperOrigin-RevId: 265982476
This commit is contained in:
parent a1e5191049
commit f3ceeb53eb
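Below is a minimal usage sketch of the pattern this commit enables (not part of the diff). It assumes the public `tf.custom_gradient` export of the new `primals` argument and a TF 2.x eager environment; the toy op and the constants 2.1/2.2 are made up for illustration, mirroring the tests added in this commit.

```python
import tensorflow as tf


# Hypothetical toy op: y = 2 * x, with hand-written first- and second-order
# gradients. The constants 2.1 and 2.2 are arbitrary, chosen only so the two
# gradient orders are distinguishable below.
@tf.custom_gradient
def op(x):
  y = 2. * x

  # `primals=x` tells the inner decorator to differentiate the returned
  # gradient function with respect to x (the original input), not dy.
  @tf.custom_gradient(primals=x)
  def grad_fn(dy):
    def grad_grad_fn(ddy):  # intended as the second-order gradient w.r.t. x
      return ddy * 2.2
    return dy * 2.1, grad_grad_fn

  return y, grad_fn


x = tf.constant(1.)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = op(x)
  dy_dx = inner.gradient(y, x)      # 2.1, from grad_fn
d2y_dx2 = outer.gradient(dy_dx, x)  # 2.2, from grad_grad_fn
```

Without `primals`, the inner decorator would treat `dy` as the input to differentiate against, which is the forward-mode/transpose behavior described in the updated docstring rather than a true second derivative.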
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.layers.pooling import max_pooling3d
from tensorflow.python.ops import array_ops
@@ -44,9 +45,21 @@ from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training
from tensorflow.python.util import nest


def _chain_grads(primals, grad_fns):
  if len(grad_fns) == 1:
    return grad_fns[-1]
  @custom_gradient.custom_gradient(primals=primals)
  def grad(*args, **kwargs):
    return (grad_fns[0](*args, **kwargs),
            _chain_grads(primals, grad_fns[1:]))
  return grad


class BackpropTest(test.TestCase, parameterized.TestCase):
@@ -1351,6 +1364,143 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
      g = f(c)
    self.assertAllEqual(self.evaluate(t.gradient(g, c)), 4.0)

  @test_util.run_in_graph_and_eager_modes
  def testNthOrderCustomGradientsTape(self):

    def _all_grads_tape(f, primals, doutputs):
      primals = nest.map_structure(ops.convert_to_tensor, primals)
      with backprop.GradientTape(persistent=True) as t:
        t.watch(primals)
        with variable_scope.variable_scope(
            # Required when graph building
            variable_scope.get_variable_scope(), use_resource=True):
          current = f(primals)
      ret = [current]
      for doutput in doutputs:
        current = t.gradient(current, primals, output_gradients=doutput,
                             unconnected_gradients='zero')
        ret.append(current)
      return ret

    @custom_gradient.custom_gradient
    def f(x):
      y = 2. * x
      return y, _chain_grads(x, [lambda dy: dy * 2.1,
                                 lambda ddy: ddy * 2.2,
                                 lambda dddy: dddy * x * 2.3])

    self.assertAllClose(
        [6., 4.2, 22.], _all_grads_tape(f, 3., [2., 10.]))
    self.assertAllClose(
        [6., 2.1, 2.2, 6.9, 2.3, 0.],
        _all_grads_tape(f, 3., [1., 1., 1., 1., 1.]))

    traced_tape_grads = def_function.function(_all_grads_tape)
    self.assertAllClose(
        [6., 4.2, 22.], traced_tape_grads(f, 3., [2., 10.]))
    self.assertAllClose(
        [6., 2.1, 2.2, 6.9, 2.3, 0.],
        traced_tape_grads(f, 3., [1., 1., 1., 1., 1.]))

  @test_util.run_in_graph_and_eager_modes
  def testNthOrderCustomGradientsTFGradients(self):

    @def_function.function
    def _all_grads_tf_gradients(f, primals, doutputs):
      primals = nest.map_structure(ops.convert_to_tensor, primals)
      current = f(primals)
      ret = [current]
      for doutput in doutputs:
        current, = gradients.gradients(current, primals, grad_ys=doutput,
                                       unconnected_gradients='zero')
        ret.append(current)
      return ret

    @custom_gradient.custom_gradient
    def f(x):
      y = 2. * x
      return y, _chain_grads(x, [lambda dy: dy * 2.1,
                                 lambda ddy: ddy * 2.2,
                                 lambda dddy: dddy * x * 2.3])

    self.assertAllClose(
        [6., 4.2, 22.], _all_grads_tf_gradients(f, 3., [2., 10.]))
    self.assertAllClose(
        [6., 2.1, 2.2, 6.9, 2.3, 0.], _all_grads_tf_gradients(
            f, 3., [1., 1., 1., 1., 1.]))

  @test_util.run_in_graph_and_eager_modes
  def testCustomGradientManualNesting(self):
    @custom_gradient.custom_gradient
    def f(x, y):
      z = 2. * x * y

      @custom_gradient.custom_gradient(primals=(x, y))
      def g(unused_dz):

        def h(unused_dz, unused_dydz):
          return (2.2, 3.2)

        return (2.1, 3.1), h

      return z, g

    with backprop.GradientTape(persistent=True) as t:
      with backprop.GradientTape(persistent=True) as tt:
        c = constant_op.constant(1.)
        d = constant_op.constant(-1.)
        t.watch(c)
        tt.watch(c)
        t.watch(d)
        tt.watch(d)
        output = f(c, d)
        self.assertAllClose(-2., output)
      gc = tt.gradient(output, c)
      self.assertAllClose(2.1, gc)
      gd = tt.gradient(output, d)
      self.assertAllClose(3.1, gd)
    gcgc = t.gradient(gc, c)
    self.assertAllClose(2.2, gcgc)
    gcgd = t.gradient(gc, d)
    self.assertAllClose(3.2, gcgd)
    gdgc = t.gradient(gd, c)
    self.assertAllClose(2.2, gdgc)
    gdgd = t.gradient(gd, d)
    self.assertAllClose(3.2, gdgd)

  @test_util.run_in_graph_and_eager_modes
  def testCustomGradientForwardprop(self):
    @custom_gradient.custom_gradient
    def f(x):
      z = 2. * tensor_util.constant_value(x)
      def g(dz):
        @custom_gradient.custom_gradient
        def first_order(unused_x, unused_dz):
          def second_order_and_transpose(unused_ddz):
            return 2.2, 3.1
          return 2.1, second_order_and_transpose
        return first_order(x, dz)
      return z, g

    with backprop.GradientTape(persistent=True) as t:
      with backprop.GradientTape() as tt:
        c = constant_op.constant(1.)
        t.watch(c)
        tt.watch(c)
        output_grad = array_ops.ones([])
        t.watch(output_grad)
        output = f(c)
        self.assertAllClose(2., output)
      gc = tt.gradient(output, c, output_gradients=output_grad)
      self.assertAllClose(2.1, gc)
    ggc = t.gradient(gc, c)
    self.assertAllClose(2.2, ggc)
    # Note that executed eagerly this kind of transpose is not efficient. But
    # from a tf.function we could prune out the first-order gradient
    # computation.
    transpose = t.gradient(gc, output_grad)
    self.assertAllClose(3.1, transpose)

  @test_util.run_in_graph_and_eager_modes
  def testMaxPooling3DGradient(self):

@@ -1684,3 +1834,4 @@ class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):

if __name__ == '__main__':
  test.main()

@@ -83,7 +83,7 @@ def copy_handle_data(source_t, target_t):


@tf_export("custom_gradient")
def custom_gradient(f):
def custom_gradient(f=None, primals=None):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
@@ -122,6 +122,71 @@ def custom_gradient(f):
  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` will be calculating the first order gradient
  of `grad_fn` with respect to `dy`, which is used to generate forward-mode
  gradient graphs from backward-mode gradient graphs, but is not the same as
  the second order gradient of `op` with respect to `x`.

  Instead, when overriding `n`-th order gradients, specify a `primals` argument
  to the inner decorator(s). For example overriding both first- and second-order
  gradients is necessary when making an operation with a fused forward and
  backward pass infinitely differentiable:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    @tf.custom_gradient(primals=x)
    def grad_fn(dy):
      def grad_grad_fn(ddy):
        return infinitely_differentiable_second_order_grad_for_x(x, y, ddy)
      return x_grad, grad_grad_fn
    return y, grad_fn
  ```

  Likewise when also overriding third or higher-order gradients, `primals` will
  typically be the original zeroth-order inputs.

  You can achieve the same effect by wrapping nested `@tf.custom_gradients` in
  another function. For example you may need to override gradients with respect
  to output gradients in addition to second-order gradients. Gradients with
  respect to output gradients are used for generating forward-mode gradient
  graphs from backward graphs, transposing the gradient function.

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x, unused_dy):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return first_order_custom(x, dy)
    return y, first_order_gradient
  ```

  With the additional layer of nesting, `primals` is no longer
  necessary. Additional arguments to the inner `@tf.custom_gradient`-decorated
  function control the expected return values of the innermost function.

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
@@ -154,20 +219,30 @@ def custom_gradient(f):
      `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
      with the derivatives of `Tensor`s in `y` with respect to the variables
      (that is, grad_vars has one Tensor per variable in variables).
    primals: A `Tensor` or list of `Tensor`. The tensors with respect to which
      the returned gradient function computes gradients. When nesting custom
      gradients, specifying `primals` allows you to control which original
      tensors the higher-order gradients are taken with respect to. See the
      examples above.

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.executing_eagerly():
      return _eager_mode_decorator(f, *args, **kwargs)
    else:
      return _graph_mode_decorator(f, *args, **kwargs)
  def decorator(f):
    def decorated(*args, **kwargs):
      """Decorated function with custom gradient."""
      if context.executing_eagerly():
        return _eager_mode_decorator(f, primals, *args, **kwargs)
      else:
        return _graph_mode_decorator(f, primals, *args, **kwargs)

  return tf_decorator.make_decorator(f, decorated)
    return tf_decorator.make_decorator(f, decorated)

  if f is None:
    return decorator
  else:
    return decorator(f)


def get_variable_by_name(var_name):
@@ -210,7 +285,7 @@ def get_dependent_variables(input_ops, output_ops):
  return tf_vars


def _graph_mode_decorator(f, *args, **kwargs):
def _graph_mode_decorator(f, primals, *args, **kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
@@ -269,7 +344,12 @@ def _graph_mode_decorator(f, *args, **kwargs):
        "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)
  all_tensors = flat_result + args + variables

  if primals is None:
    all_tensors = flat_result + args + variables
  else:
    primals = [ops.convert_to_tensor(x) for x in nest.flatten(primals)]
    all_tensors = flat_result + primals + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
@@ -312,7 +392,7 @@ def _graph_mode_decorator(f, *args, **kwargs):
      structure=result, flat_sequence=all_tensors[:flat_result_len])

def _eager_mode_decorator(f, *args, **kwargs):
def _eager_mode_decorator(f, primals, *args, **kwargs):
  """Implement custom gradient decorator for eager mode."""
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
@@ -336,7 +416,14 @@ def _eager_mode_decorator(f, *args, **kwargs):

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  arg_count = len(args)

  if primals is None:
    recorded_inputs = input_tensors
    arg_count = len(args)
  else:
    recorded_inputs = [ops.convert_to_tensor(x) for x in nest.flatten(primals)]
    arg_count = len(recorded_inputs)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
@@ -354,7 +441,7 @@ def _eager_mode_decorator(f, *args, **kwargs):
          "gradients but returned", len(flat_grads), "instead.")
    return nest.flatten(input_grads) + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
  tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)

@@ -1070,7 +1070,7 @@ tf_module {
  }
  member_method {
    name: "custom_gradient"
    argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'f\', \'primals\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "decode_base64"

@@ -574,7 +574,7 @@ tf_module {
  }
  member_method {
    name: "custom_gradient"
    argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'f\', \'primals\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "device"