Adds `custom_grad` and `vjp` to tf_numpy/extensions and trax/math.
Also changes tf.custom_gradient to allow nested structures as inputs (currently it only allows a list of tensors). PiperOrigin-RevId: 313808985 Change-Id: Ibf58547ec938c31324156d45b49771ceeafd10ab
This commit is contained in:
parent
f36a1a090e
commit
55987dbb42
|
@ -175,23 +175,20 @@ def custom_gradient(f=None):
|
|||
|
||||
Args:
|
||||
f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
|
||||
- `x` is a sequence of (nested structures of) `Tensor` inputs to the
|
||||
function.
|
||||
- `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
|
||||
operations in `f` to `x`.
|
||||
- `x` is a sequence of `Tensor` inputs to the function.
|
||||
- `y` is a `Tensor` or sequence of `Tensor` outputs of applying
|
||||
TensorFlow operations in `f` to `x`.
|
||||
- `grad_fn` is a function with the signature `g(*grad_ys)` which returns
|
||||
a list of `Tensor`s the same size as (flattened) `x` - the derivatives
|
||||
of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is
|
||||
a sequence of `Tensor`s the same size as (flattened) `y` holding the
|
||||
initial value gradients for each `Tensor` in `y`.
|
||||
|
||||
In a pure mathematical sense, a vector-argument vector-valued function
|
||||
`f`'s derivatives should be its Jacobian matrix `J`. Here we are
|
||||
expressing the Jacobian `J` as a function `grad_fn` which defines how
|
||||
`J` will transform a vector `grad_ys` when left-multiplied with it
|
||||
(`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
|
||||
representation of a matrix is convenient to use for chain-rule
|
||||
calculation (in e.g. the back-propagation algorithm).
|
||||
a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
|
||||
to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
|
||||
`Tensor`s the same size as `y` holding the initial value gradients for
|
||||
each `Tensor` in `y`. In a pure mathematical sense, a vector-argument
|
||||
vector-valued function `f`'s derivatives should be its Jacobian matrix
|
||||
`J`. Here we are expressing the Jacobian `J` as a function `grad_fn`
|
||||
which defines how `J` will transform a vector `grad_ys` when
|
||||
left-multiplied with it (`grad_ys * J`). This functional representation
|
||||
of a matrix is convenient to use for chain-rule calculation
|
||||
(in e.g. the back-propagation algorithm).
|
||||
|
||||
If `f` uses `Variable`s (that are not part of the
|
||||
inputs), i.e. through `get_variable`, then `grad_fn` should have
|
||||
|
@ -310,7 +307,7 @@ def _graph_mode_decorator(f, args, kwargs):
|
|||
"The custom_gradient decorator currently supports keywords "
|
||||
"arguments only when eager execution is enabled.")
|
||||
name = "CustomGradient-%s" % ops.uid()
|
||||
args = nest.map_structure(ops.convert_to_tensor, args)
|
||||
args = [ops.convert_to_tensor(x) for x in args]
|
||||
|
||||
# Checking global and local variables attempts to ensure that no non-resource
|
||||
# Variables are added to the graph.
|
||||
|
@ -321,7 +318,6 @@ def _graph_mode_decorator(f, args, kwargs):
|
|||
])
|
||||
with tape_lib.VariableWatcher() as variable_watcher:
|
||||
result, grad_fn = f(*args)
|
||||
args = nest.flatten(args)
|
||||
after_vars = set([
|
||||
v.ref() for v in current_var_scope.global_variables() +
|
||||
current_var_scope.local_variables()
|
||||
|
@ -408,7 +404,6 @@ def _eager_mode_decorator(f, args, kwargs):
|
|||
"""Implement custom gradient decorator for eager mode."""
|
||||
with tape_lib.VariableWatcher() as variable_watcher:
|
||||
result, grad_fn = f(*args, **kwargs)
|
||||
args = nest.flatten(args)
|
||||
all_inputs = list(args) + list(kwargs.values())
|
||||
# The variables that grad_fn needs to return gradients for are the set of
|
||||
# variables used that are *not* part of the inputs.
|
||||
|
@ -448,7 +443,7 @@ def _eager_mode_decorator(f, args, kwargs):
|
|||
raise ValueError(
|
||||
"custom_gradient function expected to return", arg_count,
|
||||
"gradients but returned", len(flat_grads), "instead.")
|
||||
return flat_grads + variable_grads
|
||||
return nest.flatten(input_grads) + variable_grads
|
||||
|
||||
tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
|
||||
actual_grad_fn)
|
||||
|
|
|
@ -60,7 +60,6 @@ from tensorflow.python.ops import variables
|
|||
from tensorflow.python.ops.nn_ops import bias_add
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
@ -1040,7 +1039,7 @@ class GetDependentVariablesTest(test_util.TensorFlowTestCase):
|
|||
self.assertEqual(dependent_vars, [var])
|
||||
|
||||
|
||||
class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
class CustomGradientTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testCustomGradientTrivial(self):
|
||||
|
||||
|
@ -1120,7 +1119,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
out = core_layers.dense(x, 3, use_bias=False)
|
||||
|
||||
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
|
||||
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
|
||||
self.assertEqual(1, len(variables))
|
||||
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
|
||||
return grads[0], [array_ops.ones((4, 3))]
|
||||
|
||||
|
@ -1147,7 +1146,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
out = core_layers.dense(x, 3, use_bias=False)
|
||||
|
||||
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
|
||||
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
|
||||
self.assertEqual(1, len(variables))
|
||||
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
|
||||
return grads[0], [array_ops.ones((3, 3))]
|
||||
|
||||
|
@ -1186,7 +1185,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
|
||||
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
|
||||
del out_grad
|
||||
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
|
||||
self.assertEqual(1, len(variables))
|
||||
return (array_ops.ones((3, 2)),
|
||||
[array_ops.ones((2, 4))])
|
||||
|
||||
|
@ -1210,7 +1209,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
|
||||
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
|
||||
del out_grad
|
||||
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
|
||||
self.assertEqual(1, len(variables))
|
||||
return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
|
||||
|
||||
return out, Grad
|
||||
|
@ -1274,7 +1273,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
out = core_layers.dense(x, 3, use_bias=False)
|
||||
|
||||
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
|
||||
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
|
||||
self.assertEqual(1, len(variables))
|
||||
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
|
||||
return grads[0], [array_ops.ones((4, 3))]
|
||||
|
||||
|
@ -1285,7 +1284,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
out = F(x)
|
||||
|
||||
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
|
||||
self.assertEqual(1, len(variables)) # pylint: disable=g-generic-assert
|
||||
self.assertEqual(1, len(variables))
|
||||
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
|
||||
return grads[0], [array_ops.ones((4, 3))]
|
||||
|
||||
|
@ -1304,43 +1303,6 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
dw = sess.run(math_ops.reduce_sum(grads[1]))
|
||||
self.assertEqual(12., dw)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
[(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""), # pylint: disable=g-complex-comprehension
|
||||
x_struct, y_struct)
|
||||
for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
|
||||
for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
|
||||
])
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
|
||||
"""Tests that custom_gradient can handle structured inputs/outputs."""
|
||||
def Zeros(x):
|
||||
return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
|
||||
def GetStruct(x):
|
||||
return nest.map_structure(lambda _: None, x)
|
||||
|
||||
def MakeVjp(f, *x):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
tape.watch(nest.flatten(x))
|
||||
y = f(*x)
|
||||
def Vjp(dy):
|
||||
return tape.gradient(y, x, output_gradients=dy)
|
||||
return y, Vjp
|
||||
|
||||
@custom_gradient.custom_gradient
|
||||
def F(*x):
|
||||
self.assertEqual(x_struct, GetStruct(x))
|
||||
def Vjp(*dy):
|
||||
self.assertEqual(len(nest.flatten(y_struct)),
|
||||
len(nest.flatten(dy)))
|
||||
return nest.flatten(Zeros(x_struct))
|
||||
return Zeros(y_struct), Vjp
|
||||
|
||||
x, dy = Zeros([x_struct, y_struct])
|
||||
y, vjp = MakeVjp(F, *x)
|
||||
dx = vjp(dy)
|
||||
self.assertEqual(x_struct, GetStruct(dx))
|
||||
self.assertEqual(y_struct, GetStruct(y))
|
||||
|
||||
|
||||
class TensorListGradientsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue