Adds `custom_grad` and `vjp` to tf_numpy/extensions and trax/math.

Also changes tf.custom_gradient to allow nested structures as inputs (currently it only allows a list of tensors).

PiperOrigin-RevId: 313808985
Change-Id: Ibf58547ec938c31324156d45b49771ceeafd10ab
Christian Sigg authored 2020-05-29 10:44:36 -07:00, committed by TensorFlower Gardener
parent f36a1a090e
commit 55987dbb42
2 changed files with 22 additions and 65 deletions
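
For context, a minimal sketch (not part of this change) of the kind of `tf.custom_gradient` usage the message above refers to. With a flat `Tensor` input this is the long-standing behavior; the commit message describes extending it so that inputs, outputs, and gradients may be arbitrary nested structures rather than a flat list of tensors:

import tensorflow as tf

@tf.custom_gradient
def clip_grad(x):
  # Forward pass is the identity; grad_fn overrides the default backward
  # pass by clipping the incoming gradient dy.
  def grad_fn(dy):
    return tf.clip_by_value(dy, -1.0, 1.0)
  return tf.identity(x), grad_fn

x = tf.constant([3.0, -5.0])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = clip_grad(x) * 10.0
dx = tape.gradient(y, x)  # the incoming gradient of 10.0 is clipped to 1.0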


@@ -175,23 +175,20 @@ def custom_gradient(f=None):
Args:
f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
- `x` is a sequence of (nested structures of) `Tensor` inputs to the
function.
- `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
operations in `f` to `x`.
- `x` is a sequence of `Tensor` inputs to the function.
- `y` is a `Tensor` or sequence of `Tensor` outputs of applying
TensorFlow operations in `f` to `x`.
- `grad_fn` is a function with the signature `g(*grad_ys)` which returns
a list of `Tensor`s the same size as (flattened) `x` - the derivatives
of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is
a sequence of `Tensor`s the same size as (flattened) `y` holding the
initial value gradients for each `Tensor` in `y`.
In a pure mathematical sense, a vector-argument vector-valued function
`f`'s derivatives should be its Jacobian matrix `J`. Here we are
expressing the Jacobian `J` as a function `grad_fn` which defines how
`J` will transform a vector `grad_ys` when left-multiplied with it
(`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
representation of a matrix is convenient to use for chain-rule
calculation (in e.g. the back-propagation algorithm).
a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
`Tensor`s the same size as `y` holding the initial value gradients for
each `Tensor` in `y`. In a pure mathematical sense, a vector-argument
vector-valued function `f`'s derivatives should be its Jacobian matrix
`J`. Here we are expressing the Jacobian `J` as a function `grad_fn`
which defines how `J` will transform a vector `grad_ys` when
left-multiplied with it (`grad_ys * J`). This functional representation
of a matrix is convenient to use for chain-rule calculation
(in e.g. the back-propagation algorithm).
If `f` uses `Variable`s (that are not part of the
inputs), i.e. through `get_variable`, then `grad_fn` should have
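
To make the `grad_fn`/VJP contract described above concrete, here is a short sketch (illustration only, not part of the diff) in the spirit of the example used elsewhere in the `tf.custom_gradient` documentation: `grad_fn` receives the incoming cotangent `dy` (the `grad_ys`) and returns `dy * J`, one gradient per input tensor.

import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad_fn(dy):
    # dy * J, where J = d/dx log(1 + e^x) = 1 - 1 / (1 + e^x).
    # Writing the VJP by hand avoids the inf * 0 the autodiff'd
    # expression produces for large x.
    return dy * (1.0 - 1.0 / (1.0 + e))
  return tf.math.log(1.0 + e), grad_fn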
@@ -310,7 +307,7 @@ def _graph_mode_decorator(f, args, kwargs):
"The custom_gradient decorator currently supports keywords "
"arguments only when eager execution is enabled.")
name = "CustomGradient-%s" % ops.uid()
args = nest.map_structure(ops.convert_to_tensor, args)
args = [ops.convert_to_tensor(x) for x in args]
# Checking global and local variables attempts to ensure that no non-resource
# Variables are added to the graph.
@@ -321,7 +318,6 @@ def _graph_mode_decorator(f, args, kwargs):
])
with tape_lib.VariableWatcher() as variable_watcher:
result, grad_fn = f(*args)
args = nest.flatten(args)
after_vars = set([
v.ref() for v in current_var_scope.global_variables() +
current_var_scope.local_variables()
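
The swapped conversion line above is where nested-structure support enters the graph-mode path: `nest.map_structure` converts every leaf while preserving the container, whereas the plain list comprehension only handles a flat sequence. A rough illustration (not the decorator's code) using the public `tf.nest` API:

import tensorflow as tf

args = ({"w": 1.0, "b": 2.0}, [3.0])
# Structure-preserving: every leaf becomes a Tensor, the dict/list/tuple
# skeleton is kept intact.
nested = tf.nest.map_structure(tf.convert_to_tensor, args)
# Flat conversion: valid only when args is already a flat sequence of
# tensor-like values; the dict element above would make this fail.
# flat = [tf.convert_to_tensor(x) for x in args]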
@@ -408,7 +404,6 @@ def _eager_mode_decorator(f, args, kwargs):
"""Implement custom gradient decorator for eager mode."""
with tape_lib.VariableWatcher() as variable_watcher:
result, grad_fn = f(*args, **kwargs)
args = nest.flatten(args)
all_inputs = list(args) + list(kwargs.values())
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
@@ -448,7 +443,7 @@ def _eager_mode_decorator(f, args, kwargs):
raise ValueError(
"custom_gradient function expected to return", arg_count,
"gradients but returned", len(flat_grads), "instead.")
return flat_grads + variable_grads
return nest.flatten(input_grads) + variable_grads
tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
actual_grad_fn)
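
The length check and the `flat_grads` / `nest.flatten(input_grads)` returns above both rely on flattening: `grad_fn` must produce one gradient per flattened input leaf. A small illustration with the public API (illustrative names only):

import tensorflow as tf

x_struct = ({"a": tf.constant(1.0)}, [tf.constant(2.0), tf.constant(3.0)])
flat = tf.nest.flatten(x_struct)                 # 3 leaves
grads = [tf.zeros_like(t) for t in flat]         # one gradient per leaf
dx = tf.nest.pack_sequence_as(x_struct, grads)   # restore the input structure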


@@ -60,7 +60,6 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.util import nest
class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@@ -1040,7 +1039,7 @@ class GetDependentVariablesTest(test_util.TensorFlowTestCase):
self.assertEqual(dependent_vars, [var])
class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
class CustomGradientTest(test_util.TensorFlowTestCase):
def testCustomGradientTrivial(self):
@@ -1120,7 +1119,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
out = core_layers.dense(x, 3, use_bias=False)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
self.assertEqual(1, len(variables))
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((4, 3))]
@@ -1147,7 +1146,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
out = core_layers.dense(x, 3, use_bias=False)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
self.assertEqual(1, len(variables))
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((3, 3))]
@@ -1186,7 +1185,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
del out_grad
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
self.assertEqual(1, len(variables))
return (array_ops.ones((3, 2)),
[array_ops.ones((2, 4))])
@@ -1210,7 +1209,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
del out_grad
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
self.assertEqual(1, len(variables))
return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
return out, Grad
@@ -1274,7 +1273,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
out = core_layers.dense(x, 3, use_bias=False)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
self.assertEqual(1, len(variables))
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((4, 3))]
@@ -1285,7 +1284,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
out = F(x)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
self.assertEqual(1, len(variables))
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((4, 3))]
@@ -1304,43 +1303,6 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
dw = sess.run(math_ops.reduce_sum(grads[1]))
self.assertEqual(12., dw)
@parameterized.named_parameters(
[(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""), # pylint: disable=g-complex-comprehension
x_struct, y_struct)
for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
])
@test_util.run_in_graph_and_eager_modes
def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
"""Tests that custom_gradient can handle structured inputs/outputs."""
def Zeros(x):
return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
def GetStruct(x):
return nest.map_structure(lambda _: None, x)
def MakeVjp(f, *x):
with backprop.GradientTape(persistent=True) as tape:
tape.watch(nest.flatten(x))
y = f(*x)
def Vjp(dy):
return tape.gradient(y, x, output_gradients=dy)
return y, Vjp
@custom_gradient.custom_gradient
def F(*x):
self.assertEqual(x_struct, GetStruct(x))
def Vjp(*dy):
self.assertEqual(len(nest.flatten(y_struct)),
len(nest.flatten(dy)))
return nest.flatten(Zeros(x_struct))
return Zeros(y_struct), Vjp
x, dy = Zeros([x_struct, y_struct])
y, vjp = MakeVjp(F, *x)
dx = vjp(dy)
self.assertEqual(x_struct, GetStruct(dx))
self.assertEqual(y_struct, GetStruct(y))
class TensorListGradientsTest(test_util.TensorFlowTestCase):
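
The removed test builds its VJPs with a small `MakeVjp` helper around `GradientTape`. For reference, a standalone sketch of that pattern (hypothetical function name, simplified to flat tensor inputs):

import tensorflow as tf

def make_vjp(f, *x):
  # Run f under a tape, then expose a function mapping an output
  # cotangent dy to input gradients via output_gradients.
  with tf.GradientTape(persistent=True) as tape:
    tape.watch(tf.nest.flatten(x))
    y = f(*x)
  def vjp(dy):
    return tape.gradient(y, x, output_gradients=dy)
  return y, vjp

y, vjp = make_vjp(lambda a, b: a * b, tf.constant(2.0), tf.constant(3.0))
dx = vjp(tf.constant(1.0))  # (3.0, 2.0): d(ab)/da = b, d(ab)/db = a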