Adds custom_grad and vjp to tf_numpy/extensions and trax/math.

Also changes tf.custom_gradient to allow nested structures as inputs (currently it only allows a list of tensors).

PiperOrigin-RevId: 314445524
Change-Id: I4b34f62a8b8e06db5bfefec6fcb2e79670c35bc3
This commit is contained in:
Peng Wang 2020-06-02 18:36:54 -07:00 committed by TensorFlower Gardener
parent 78390fab9a
commit 40c6e8d755
2 changed files with 67 additions and 22 deletions

View File

@ -175,20 +175,23 @@ def custom_gradient(f=None):
Args:
f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
- `x` is a sequence of `Tensor` inputs to the function.
- `y` is a `Tensor` or sequence of `Tensor` outputs of applying
TensorFlow operations in `f` to `x`.
- `x` is a sequence of (nested structures of) `Tensor` inputs to the
function.
- `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
operations in `f` to `x`.
- `grad_fn` is a function with the signature `g(*grad_ys)` which returns
a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
`Tensor`s the same size as `y` holding the initial value gradients for
each `Tensor` in `y`. In a pure mathematical sense, a vector-argument
vector-valued function `f`'s derivatives should be its Jacobian matrix
`J`. Here we are expressing the Jacobian `J` as a function `grad_fn`
which defines how `J` will transform a vector `grad_ys` when
left-multiplied with it (`grad_ys * J`). This functional representation
of a matrix is convenient to use for chain-rule calculation
(in e.g. the back-propagation algorithm).
a list of `Tensor`s the same size as (flattened) `x` - the derivatives
of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is
a sequence of `Tensor`s the same size as (flattened) `y` holding the
initial value gradients for each `Tensor` in `y`.
In a pure mathematical sense, a vector-argument vector-valued function
`f`'s derivatives should be its Jacobian matrix `J`. Here we are
expressing the Jacobian `J` as a function `grad_fn` which defines how
`J` will transform a vector `grad_ys` when left-multiplied with it
(`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
representation of a matrix is convenient to use for chain-rule
calculation (in e.g. the back-propagation algorithm).
If `f` uses `Variable`s (that are not part of the
inputs), i.e. through `get_variable`, then `grad_fn` should have
@ -209,6 +212,8 @@ def custom_gradient(f=None):
@Bind.decorator
def decorated(wrapped, args, kwargs):
"""Decorated function with custom gradient."""
# raise ValueError("PW: trap")
if context.executing_eagerly():
return _eager_mode_decorator(wrapped, args, kwargs)
else:
@ -307,7 +312,7 @@ def _graph_mode_decorator(f, args, kwargs):
"The custom_gradient decorator currently supports keywords "
"arguments only when eager execution is enabled.")
name = "CustomGradient-%s" % ops.uid()
args = [ops.convert_to_tensor(x) for x in args]
args = nest.map_structure(ops.convert_to_tensor, args)
# Checking global and local variables attempts to ensure that no non-resource
# Variables are added to the graph.
@ -318,6 +323,7 @@ def _graph_mode_decorator(f, args, kwargs):
])
with tape_lib.VariableWatcher() as variable_watcher:
result, grad_fn = f(*args)
args = nest.flatten(args)
after_vars = set([
v.ref() for v in current_var_scope.global_variables() +
current_var_scope.local_variables()
@ -404,6 +410,7 @@ def _eager_mode_decorator(f, args, kwargs):
"""Implement custom gradient decorator for eager mode."""
with tape_lib.VariableWatcher() as variable_watcher:
result, grad_fn = f(*args, **kwargs)
args = nest.flatten(args)
all_inputs = list(args) + list(kwargs.values())
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
@ -443,7 +450,7 @@ def _eager_mode_decorator(f, args, kwargs):
raise ValueError(
"custom_gradient function expected to return", arg_count,
"gradients but returned", len(flat_grads), "instead.")
return nest.flatten(input_grads) + variable_grads
return flat_grads + variable_grads
tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
actual_grad_fn)

View File

@ -60,6 +60,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.util import nest
class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@ -1039,7 +1040,7 @@ class GetDependentVariablesTest(test_util.TensorFlowTestCase):
self.assertEqual(dependent_vars, [var])
class CustomGradientTest(test_util.TensorFlowTestCase):
class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def testCustomGradientTrivial(self):
@ -1119,7 +1120,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
out = core_layers.dense(x, 3, use_bias=False)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables))
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((4, 3))]
@ -1146,7 +1147,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
out = core_layers.dense(x, 3, use_bias=False)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables))
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((3, 3))]
@ -1185,7 +1186,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
del out_grad
self.assertEqual(1, len(variables))
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
return (array_ops.ones((3, 2)),
[array_ops.ones((2, 4))])
@ -1209,7 +1210,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
del out_grad
self.assertEqual(1, len(variables))
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
return out, Grad
@ -1273,7 +1274,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
out = core_layers.dense(x, 3, use_bias=False)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables))
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((4, 3))]
@ -1284,7 +1285,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
out = F(x)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables))
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((4, 3))]
@ -1303,6 +1304,43 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
dw = sess.run(math_ops.reduce_sum(grads[1]))
self.assertEqual(12., dw)
@parameterized.named_parameters(
[(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""), # pylint: disable=g-complex-comprehension
x_struct, y_struct)
for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
])
@test_util.run_in_graph_and_eager_modes
def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
"""Tests that custom_gradient can handle structured inputs/outputs."""
def Zeros(x):
return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
def GetStruct(x):
return nest.map_structure(lambda _: None, x)
def MakeVjp(f, *x):
with backprop.GradientTape(persistent=True) as tape:
tape.watch(nest.flatten(x))
y = f(*x)
def Vjp(dy):
return tape.gradient(y, x, output_gradients=dy)
return y, Vjp
@custom_gradient.custom_gradient
def F(*x):
self.assertEqual(x_struct, GetStruct(x))
def Vjp(*dy):
self.assertEqual(len(nest.flatten(y_struct)),
len(nest.flatten(dy)))
return nest.flatten(Zeros(x_struct))
return Zeros(y_struct), Vjp
x, dy = Zeros([x_struct, y_struct])
y, vjp = MakeVjp(F, *x)
dx = vjp(dy)
self.assertEqual(x_struct, GetStruct(dx))
self.assertEqual(y_struct, GetStruct(y))
class TensorListGradientsTest(test_util.TensorFlowTestCase):