Add a "primals" argument to tf.custom_gradient

Makes higher-order custom gradients easier to define. They were possible before, but only by awkwardly wrapping the nested custom_gradient decorator: the outer function takes the output gradients, the custom_gradient-decorated function takes the primal inputs, and the inner function uses the output gradients while typically ignoring the primal inputs passed through custom_gradient, since they are easier to capture from the enclosing scope.
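A minimal end-to-end sketch of the new usage with the public API (it mirrors testCustomGradientManualNesting and testNthOrderCustomGradientsTape below; the overridden gradient values 2.1 and 2.2 are arbitrary illustrations, not real derivatives of 2*x):

```python
import tensorflow as tf

@tf.custom_gradient
def f(x):
  y = 2. * x

  # New in this change: `primals` records the first-order gradient against x,
  # so differentiating it again yields a gradient with respect to x.
  @tf.custom_gradient(primals=x)
  def grad_fn(dy):
    def grad_grad_fn(ddy):      # overridden second-order gradient w.r.t. x
      return ddy * 2.2
    return dy * 2.1, grad_grad_fn

  return y, grad_fn

x = tf.constant(3.)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = f(x)
  dy_dx = inner.gradient(y, x)      # 2.1 (first-order override)
d2y_dx2 = outer.gradient(dy_dx, x)  # 2.2 (second-order override)
```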

PiperOrigin-RevId: 265982476
Allen Lavoie 2019-08-28 13:39:03 -07:00 committed by TensorFlower Gardener
parent a1e5191049
commit f3ceeb53eb
4 changed files with 253 additions and 15 deletions


@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.layers.pooling import max_pooling3d
from tensorflow.python.ops import array_ops
@@ -44,9 +45,21 @@ from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training
from tensorflow.python.util import nest
def _chain_grads(primals, grad_fns):
if len(grad_fns) == 1:
return grad_fns[-1]
@custom_gradient.custom_gradient(primals=primals)
def grad(*args, **kwargs):
return (grad_fns[0](*args, **kwargs),
_chain_grads(primals, grad_fns[1:]))
return grad
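For readability, the recursion in `_chain_grads` expands, for two gradient functions, to roughly the following hand-written form (a sketch reusing the same `custom_gradient` import as the helper above; `g1` and `g2` are placeholder gradient functions):

```python
# Rough expansion of _chain_grads(x, [g1, g2]): g1 supplies the gradient,
# and g2 is registered (against the primals x) as the gradient of g1's output.
def chain_two(x, g1, g2):
  @custom_gradient.custom_gradient(primals=x)
  def grad(*args, **kwargs):
    return g1(*args, **kwargs), g2
  return grad
```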
class BackpropTest(test.TestCase, parameterized.TestCase):
@@ -1351,6 +1364,143 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
g = f(c)
self.assertAllEqual(self.evaluate(t.gradient(g, c)), 4.0)
@test_util.run_in_graph_and_eager_modes
def testNthOrderCustomGradientsTape(self):
def _all_grads_tape(f, primals, doutputs):
primals = nest.map_structure(ops.convert_to_tensor, primals)
with backprop.GradientTape(persistent=True) as t:
t.watch(primals)
with variable_scope.variable_scope(
# Required when graph building
variable_scope.get_variable_scope(), use_resource=True):
current = f(primals)
ret = [current]
for doutput in doutputs:
current = t.gradient(current, primals, output_gradients=doutput,
unconnected_gradients='zero')
ret.append(current)
return ret
@custom_gradient.custom_gradient
def f(x):
y = 2. * x
return y, _chain_grads(x, [lambda dy: dy * 2.1,
lambda ddy: ddy * 2.2,
lambda dddy: dddy * x * 2.3])
self.assertAllClose(
[6., 4.2, 22.], _all_grads_tape(f, 3., [2., 10.]))
self.assertAllClose(
[6., 2.1, 2.2, 6.9, 2.3, 0.],
_all_grads_tape(f, 3., [1., 1., 1., 1., 1.]))
traced_tape_grads = def_function.function(_all_grads_tape)
self.assertAllClose(
[6., 4.2, 22.], traced_tape_grads(f, 3., [2., 10.]))
self.assertAllClose(
[6., 2.1, 2.2, 6.9, 2.3, 0.],
traced_tape_grads(f, 3., [1., 1., 1., 1., 1.]))
@test_util.run_in_graph_and_eager_modes
def testNthOrderCustomGradientsTFGradients(self):
@def_function.function
def _all_grads_tf_gradients(f, primals, doutputs):
primals = nest.map_structure(ops.convert_to_tensor, primals)
current = f(primals)
ret = [current]
for doutput in doutputs:
current, = gradients.gradients(current, primals, grad_ys=doutput,
unconnected_gradients='zero')
ret.append(current)
return ret
@custom_gradient.custom_gradient
def f(x):
y = 2. * x
return y, _chain_grads(x, [lambda dy: dy * 2.1,
lambda ddy: ddy * 2.2,
lambda dddy: dddy * x * 2.3])
self.assertAllClose(
[6., 4.2, 22.], _all_grads_tf_gradients(f, 3., [2., 10.]))
self.assertAllClose(
[6., 2.1, 2.2, 6.9, 2.3, 0.], _all_grads_tf_gradients(
f, 3., [1., 1., 1., 1., 1.]))
@test_util.run_in_graph_and_eager_modes
def testCustomGradientManualNesting(self):
@custom_gradient.custom_gradient
def f(x, y):
z = 2. * x * y
@custom_gradient.custom_gradient(primals=(x, y))
def g(unused_dz):
def h(unused_dz, unused_dydz):
return (2.2, 3.2)
return (2.1, 3.1), h
return z, g
with backprop.GradientTape(persistent=True) as t:
with backprop.GradientTape(persistent=True) as tt:
c = constant_op.constant(1.)
d = constant_op.constant(-1.)
t.watch(c)
tt.watch(c)
t.watch(d)
tt.watch(d)
output = f(c, d)
self.assertAllClose(-2., output)
gc = tt.gradient(output, c)
self.assertAllClose(2.1, gc)
gd = tt.gradient(output, d)
self.assertAllClose(3.1, gd)
gcgc = t.gradient(gc, c)
self.assertAllClose(2.2, gcgc)
gcgd = t.gradient(gc, d)
self.assertAllClose(3.2, gcgd)
gdgc = t.gradient(gd, c)
self.assertAllClose(2.2, gdgc)
gdgd = t.gradient(gd, d)
self.assertAllClose(3.2, gdgd)
@test_util.run_in_graph_and_eager_modes
def testCustomGradientForwardprop(self):
@custom_gradient.custom_gradient
def f(x):
z = 2. * tensor_util.constant_value(x)
def g(dz):
@custom_gradient.custom_gradient
def first_order(unused_x, unused_dz):
def second_order_and_transpose(unused_ddz):
return 2.2, 3.1
return 2.1, second_order_and_transpose
return first_order(x, dz)
return z, g
with backprop.GradientTape(persistent=True) as t:
with backprop.GradientTape() as tt:
c = constant_op.constant(1.)
t.watch(c)
tt.watch(c)
output_grad = array_ops.ones([])
t.watch(output_grad)
output = f(c)
self.assertAllClose(2., output)
gc = tt.gradient(output, c, output_gradients=output_grad)
self.assertAllClose(2.1, gc)
ggc = t.gradient(gc, c)
self.assertAllClose(2.2, ggc)
# Note that executed eagerly this kind of transpose is not efficient. But
# from a tf.function we could prune out the first-order gradient
# computation.
transpose = t.gradient(gc, output_grad)
self.assertAllClose(3.1, transpose)
@test_util.run_in_graph_and_eager_modes
def testMaxPooling3DGradient(self):
@@ -1684,3 +1834,4 @@ class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
if __name__ == '__main__':
test.main()
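A side note on the `transpose` assertion in testCustomGradientForwardprop: gradients with respect to output gradients are the building block for constructing forward-mode gradients (JVPs) out of two backward passes. A rough public-API sketch of that construction, not part of this change:

```python
import tensorflow as tf

def jvp(f, primal, tangent):
  """Forward-mode gradient of f at `primal` along `tangent`, via double backprop."""
  with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
      inner.watch(primal)
      y = f(primal)
    seed = tf.ones_like(y)   # placeholder output gradient to differentiate through
    outer.watch(seed)
    vjp = inner.gradient(y, primal, output_gradients=seed)
  # d(vjp)/d(seed) contracted with `tangent` is the JVP of f at `primal`.
  return outer.gradient(vjp, seed, output_gradients=tangent)

# For example, jvp(tf.square, tf.constant(3.), tf.constant(1.)) evaluates to 6.
```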


@@ -83,7 +83,7 @@ def copy_handle_data(source_t, target_t):
@tf_export("custom_gradient")
def custom_gradient(f):
def custom_gradient(f=None, primals=None):
"""Decorator to define a function with a custom gradient.
This decorator allows fine grained control over the gradients of a sequence
@@ -122,6 +122,71 @@ def custom_gradient(f):
With this definition, the gradient at x=100 will be correctly evaluated as
1.0.
Nesting custom gradients can lead to unintuitive results. The default
behavior does not correspond to n-th order derivatives. For example
```python
@tf.custom_gradient
def op(x):
y = op1(x)
@tf.custom_gradient
def grad_fn(dy):
gdy = op2(x, y, dy)
def grad_grad_fn(ddy): # Not the 2nd order gradient of op w.r.t. x.
return op3(x, y, dy, ddy)
return gdy, grad_grad_fn
return y, grad_fn
```
The function `grad_grad_fn` will be calculating the first-order gradient
of `grad_fn` with respect to `dy`, which is used to generate forward-mode
gradient graphs from backward-mode gradient graphs, but is not the same as
the second-order gradient of `op` with respect to `x`.
Instead, when overriding `n`-th order gradients, specify a `primals` argument
to the inner decorator(s). For example, overriding both first- and second-order
gradients is necessary when making an operation with a fused forward and
backward pass infinitely differentiable:
```python
@tf.custom_gradient
def op_with_fused_backprop(x):
y, x_grad = fused_op(x)
@tf.custom_gradient(primals=x)
def grad_fn(dy):
def grad_grad_fn(ddy):
return infinitely_differentiable_second_order_grad_for_x(x, y, ddy)
return x_grad, grad_grad_fn
return y, grad_fn
```
Likewise, when also overriding third- or higher-order gradients, `primals` will
typically be the original zeroth-order inputs.
You can achieve the same effect by wrapping nested `@tf.custom_gradient`
decorators in another function. For example, you may need to override gradients
with respect to output gradients in addition to second-order gradients.
Gradients with respect to output gradients are used for generating forward-mode
gradient graphs from backward graphs, transposing the gradient function.
```python
@tf.custom_gradient
def op_with_fused_backprop(x):
y, x_grad = fused_op(x)
def first_order_gradient(dy):
@tf.custom_gradient
def first_order_custom(unused_x, unused_dy):
def second_order_and_transpose(ddy):
return second_order_for_x(...), gradient_wrt_dy(...)
return x_grad, second_order_and_transpose
return first_order_custom(x, dy)
return y, first_order_gradient
```
With the additional layer of nesting, `primals` is no longer
necessary. Additional arguments to the inner `@tf.custom_gradient`-decorated
function control the expected return values of the innermost function.
See also `tf.RegisterGradient` which registers a gradient function for a
primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
for fine grained control over the gradient computation of a sequence of
@@ -154,20 +219,30 @@ def custom_gradient(f):
`grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
with the derivatives of `Tensor`s in `y` with respect to the variables
(that is, grad_vars has one Tensor per variable in variables).
primals: A `Tensor` or list of `Tensor`. The tensors with respect to which
the returned gradient function computes gradients. When nesting custom
gradients, specifying `primals` allows you to control which original tensors
the higher-order gradients are taken with respect to. See the examples above.
Returns:
A function `h(x)` which returns the same value as `f(x)[0]` and whose
gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
"""
def decorated(*args, **kwargs):
"""Decorated function with custom gradient."""
if context.executing_eagerly():
return _eager_mode_decorator(f, *args, **kwargs)
else:
return _graph_mode_decorator(f, *args, **kwargs)
def decorator(f):
def decorated(*args, **kwargs):
"""Decorated function with custom gradient."""
if context.executing_eagerly():
return _eager_mode_decorator(f, primals, *args, **kwargs)
else:
return _graph_mode_decorator(f, primals, *args, **kwargs)
return tf_decorator.make_decorator(f, decorated)
return tf_decorator.make_decorator(f, decorated)
if f is None:
return decorator
else:
return decorator(f)
def get_variable_by_name(var_name):
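The `f=None` refactor above is the standard optional-argument decorator idiom; it keeps both `@custom_gradient` and `@custom_gradient(primals=...)` working. A generic sketch of the pattern with hypothetical names:

```python
import functools

def configurable_decorator(f=None, option=None):
  """Usable both as @configurable_decorator and @configurable_decorator(option=...)."""
  def decorator(func):
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
      # `option` (if supplied) is in scope here to adjust behavior.
      return func(*args, **kwargs)
    return wrapped
  if f is None:
    return decorator      # called with arguments first: return the real decorator
  return decorator(f)     # applied bare: decorate immediately
```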
@@ -210,7 +285,7 @@ def get_dependent_variables(input_ops, output_ops):
return tf_vars
def _graph_mode_decorator(f, *args, **kwargs):
def _graph_mode_decorator(f, primals, *args, **kwargs):
"""Implement custom gradient decorator for graph mode."""
# TODO(rsepassi): Add support for kwargs
if kwargs:
@@ -269,7 +344,12 @@ def _graph_mode_decorator(f, *args, **kwargs):
"no ResourceVariables were used on the forward pass.")
flat_result = nest.flatten(result)
flat_result_len = len(flat_result)
all_tensors = flat_result + args + variables
if primals is None:
all_tensors = flat_result + args + variables
else:
primals = [ops.convert_to_tensor(x) for x in nest.flatten(primals)]
all_tensors = flat_result + primals + variables
def tape_grad_fn(*result_grads):
"""Custom grad fn wrapper."""
@@ -312,7 +392,7 @@ def _graph_mode_decorator(f, *args, **kwargs):
structure=result, flat_sequence=all_tensors[:flat_result_len])
def _eager_mode_decorator(f, *args, **kwargs):
def _eager_mode_decorator(f, primals, *args, **kwargs):
"""Implement custom gradient decorator for eager mode."""
with backprop.GradientTape() as tape:
result, grad_fn = f(*args, **kwargs)
@@ -336,7 +416,14 @@ def _eager_mode_decorator(f, *args, **kwargs):
input_tensors = [ops.convert_to_tensor(x) for x
in list(args) + list(variables)]
arg_count = len(args)
if primals is None:
recorded_inputs = input_tensors
arg_count = len(args)
else:
recorded_inputs = [ops.convert_to_tensor(x) for x in nest.flatten(primals)]
arg_count = len(recorded_inputs)
def actual_grad_fn(*result_grads):
"""Custom grad fn wrapper."""
if variables:
@@ -354,7 +441,7 @@ def _eager_mode_decorator(f, *args, **kwargs):
"gradients but returned", len(flat_grads), "instead.")
return nest.flatten(input_grads) + variable_grads
tape_lib.record_operation(f.__name__, flat_result, input_tensors,
tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
actual_grad_fn)
flat_result = list(flat_result)
return nest.pack_sequence_as(result, flat_result)


@@ -1070,7 +1070,7 @@ tf_module {
}
member_method {
name: "custom_gradient"
argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'f\', \'primals\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "decode_base64"


@@ -574,7 +574,7 @@ tf_module {
}
member_method {
name: "custom_gradient"
argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'f\', \'primals\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "device"