Adds `custom_grad` and `vjp` to tf_numpy/extensions and trax/math.
Also changes tf.custom_gradient to allow nested structures as inputs (currently it only allows a list of tensors).
PiperOrigin-RevId: 313808985
Change-Id: Ibf58547ec938c31324156d45b49771ceeafd10ab
parent f36a1a090e
commit 55987dbb42
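For orientation only (this sketch is mine, not part of the patch): the `vjp` named in the description is a vector-Jacobian-product helper, which JAX exposes as `jax.vjp`; the trax/math version is assumed here to mirror that convention.

    import jax
    import jax.numpy as jnp

    def f(x, y):
      return jnp.sin(x) * y

    # vjp returns the primal output plus a function mapping an output
    # cotangent to one cotangent per input (grad_ys * J).
    primal_out, vjp_fun = jax.vjp(f, jnp.float32(1.0), jnp.float32(2.0))
    dx, dy = vjp_fun(jnp.float32(1.0))   # cos(1.0) * 2.0 and sin(1.0)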
@@ -175,23 +175,20 @@ def custom_gradient(f=None):
 
   Args:
     f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
 
-       - `x` is a sequence of `Tensor` inputs to the function.
-       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
-         TensorFlow operations in `f` to `x`.
+       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
+         function.
+       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
+         operations in `f` to `x`.
        - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
-         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
-         to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
-         `Tensor`s the same size as `y` holding the initial value gradients for
-         each `Tensor` in `y`. In a pure mathematical sense, a vector-argument
-         vector-valued function `f`'s derivatives should be its Jacobian matrix
-         `J`. Here we are expressing the Jacobian `J` as a function `grad_fn`
-         which defines how `J` will transform a vector `grad_ys` when
-         left-multiplied with it (`grad_ys * J`). This functional representation
-         of a matrix is convenient to use for chain-rule calculation
-         (in e.g. the back-propagation algorithm).
+         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
+         of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is
+         a sequence of `Tensor`s the same size as (flattened) `y` holding the
+         initial value gradients for each `Tensor` in `y`.
+
+         In a pure mathematical sense, a vector-argument vector-valued function
+         `f`'s derivatives should be its Jacobian matrix `J`. Here we are
+         expressing the Jacobian `J` as a function `grad_fn` which defines how
+         `J` will transform a vector `grad_ys` when left-multiplied with it
+         (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
+         representation of a matrix is convenient to use for chain-rule
+         calculation (in e.g. the back-propagation algorithm).
 
          If `f` uses `Variable`s (that are not part of the
          inputs), i.e. through `get_variable`, then `grad_fn` should have
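To make the updated contract concrete, here is a minimal sketch (mine, not from the patch; it assumes a TensorFlow release that includes this change) of a custom gradient whose argument is a nested structure, with `grad_fn` returning one gradient per flattened input tensor:

    import tensorflow as tf

    @tf.custom_gradient
    def scaled_sum(xs):                 # xs is a nested structure: (scalar, [vector])
      a, (b,) = xs
      y = 2. * a + 2. * tf.reduce_sum(b)

      def grad(dy):
        # One VJP component per tensor in the flattened input, i.e. for `a` and `b`.
        return [2. * dy, 2. * dy * tf.ones_like(b)]

      return y, grad

    x = (tf.constant(3.), [tf.constant([1., 2.])])
    with tf.GradientTape() as tape:
      tape.watch(tf.nest.flatten(x))
      y = scaled_sum(x)
    dx = tape.gradient(y, x)            # gradients come back with the nesting of `x`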
@@ -310,7 +307,7 @@ def _graph_mode_decorator(f, args, kwargs):
         "The custom_gradient decorator currently supports keywords "
         "arguments only when eager execution is enabled.")
   name = "CustomGradient-%s" % ops.uid()
-  args = [ops.convert_to_tensor(x) for x in args]
+  args = nest.map_structure(ops.convert_to_tensor, args)
 
   # Checking global and local variables attempts to ensure that no non-resource
   # Variables are added to the graph.
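For readers unfamiliar with `nest`: switching from a flat list comprehension to `nest.map_structure` is what lets the conversion preserve arbitrary nesting, and the `nest.flatten` calls in the next two hunks recover a flat tensor list for the parts of the decorator that still expect one. A toy sketch using the public `tf.nest` API (not taken from the patch):

    import tensorflow as tf

    args = ((1.0, 2.0), [3.0])
    # Convert every leaf while keeping the structure...
    tensors = tf.nest.map_structure(tf.convert_to_tensor, args)
    # ...then flatten when a plain list of tensors is needed downstream.
    flat = tf.nest.flatten(tensors)     # three scalar tensors, in flatten order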
@@ -321,7 +318,6 @@ def _graph_mode_decorator(f, args, kwargs):
   ])
   with tape_lib.VariableWatcher() as variable_watcher:
     result, grad_fn = f(*args)
+  args = nest.flatten(args)
   after_vars = set([
       v.ref() for v in current_var_scope.global_variables() +
       current_var_scope.local_variables()
@@ -408,7 +404,6 @@ def _eager_mode_decorator(f, args, kwargs):
   """Implement custom gradient decorator for eager mode."""
   with tape_lib.VariableWatcher() as variable_watcher:
     result, grad_fn = f(*args, **kwargs)
+  args = nest.flatten(args)
   all_inputs = list(args) + list(kwargs.values())
   # The variables that grad_fn needs to return gradients for are the set of
   # variables used that are *not* part of the inputs.
@@ -448,7 +443,7 @@ def _eager_mode_decorator(f, args, kwargs):
       raise ValueError(
           "custom_gradient function expected to return", arg_count,
           "gradients but returned", len(flat_grads), "instead.")
-    return flat_grads + variable_grads
+    return nest.flatten(input_grads) + variable_grads
 
   tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
                             actual_grad_fn)
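A small illustration (mine, not from the patch) of the invariant the surrounding check enforces: whatever nesting `grad_fn` returns, its flattened form must line up one-to-one with the flattened inputs.

    import tensorflow as tf

    # Hypothetical grad_fn output for inputs shaped like ((a, b), c):
    input_grads = ((tf.constant(0.1), tf.constant(0.2)), tf.constant(0.3))
    flat_grads = tf.nest.flatten(input_grads)
    assert len(flat_grads) == 3   # must equal the number of flattened input tensors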
@@ -60,7 +60,6 @@ from tensorflow.python.ops import variables
 from tensorflow.python.ops.nn_ops import bias_add
 from tensorflow.python.platform import googletest
 from tensorflow.python.ops import gradient_checker_v2
+from tensorflow.python.util import nest
 
 
 class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@@ -1040,7 +1039,7 @@ class GetDependentVariablesTest(test_util.TensorFlowTestCase):
       self.assertEqual(dependent_vars, [var])
 
 
-class CustomGradientTest(test_util.TensorFlowTestCase):
+class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
   def testCustomGradientTrivial(self):
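Context for the one-line change above: inheriting from `parameterized.TestCase` is what enables the `@parameterized.named_parameters` decorator used by the new structural test further down. A self-contained sketch of that absl.testing pattern (illustrative names only, not from the patch):

    from absl.testing import parameterized

    class ExampleTest(parameterized.TestCase):

      @parameterized.named_parameters(
          ("scalar", 0),
          ("pair", (0, 0)),
      )
      def testStruct(self, struct):
        self.assertIsNotNone(struct)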
@@ -1120,7 +1119,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       out = core_layers.dense(x, 3, use_bias=False)
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
-        self.assertEqual(1, len(variables))
+        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
         grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
         return grads[0], [array_ops.ones((4, 3))]
@@ -1147,7 +1146,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       out = core_layers.dense(x, 3, use_bias=False)
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
-        self.assertEqual(1, len(variables))
+        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
         grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
         return grads[0], [array_ops.ones((3, 3))]
@@ -1186,7 +1185,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
         del out_grad
-        self.assertEqual(1, len(variables))
+        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
         return (array_ops.ones((3, 2)),
                 [array_ops.ones((2, 4))])
@@ -1210,7 +1209,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
         del out_grad
-        self.assertEqual(1, len(variables))
+        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
         return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
 
       return out, Grad
@@ -1274,7 +1273,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       out = core_layers.dense(x, 3, use_bias=False)
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
-        self.assertEqual(1, len(variables))
+        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
         grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
         return grads[0], [array_ops.ones((4, 3))]
@@ -1285,7 +1284,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       out = F(x)
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
-        self.assertEqual(1, len(variables))
+        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
         grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
         return grads[0], [array_ops.ones((4, 3))]
@@ -1304,43 +1303,6 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       dw = sess.run(math_ops.reduce_sum(grads[1]))
       self.assertEqual(12., dw)
 
+  @parameterized.named_parameters(
+      [(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""),  # pylint: disable=g-complex-comprehension
+        x_struct, y_struct)
+       for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
+       for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
+      ])
+  @test_util.run_in_graph_and_eager_modes
+  def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
+    """Tests that custom_gradient can handle structured inputs/outputs."""
+    def Zeros(x):
+      return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
+    def GetStruct(x):
+      return nest.map_structure(lambda _: None, x)
+
+    def MakeVjp(f, *x):
+      with backprop.GradientTape(persistent=True) as tape:
+        tape.watch(nest.flatten(x))
+        y = f(*x)
+      def Vjp(dy):
+        return tape.gradient(y, x, output_gradients=dy)
+      return y, Vjp
+
+    @custom_gradient.custom_gradient
+    def F(*x):
+      self.assertEqual(x_struct, GetStruct(x))
+      def Vjp(*dy):
+        self.assertEqual(len(nest.flatten(y_struct)),
+                         len(nest.flatten(dy)))
+        return nest.flatten(Zeros(x_struct))
+      return Zeros(y_struct), Vjp
+
+    x, dy = Zeros([x_struct, y_struct])
+    y, vjp = MakeVjp(F, *x)
+    dx = vjp(dy)
+    self.assertEqual(x_struct, GetStruct(dx))
+    self.assertEqual(y_struct, GetStruct(y))
+
 
 class TensorListGradientsTest(test_util.TensorFlowTestCase):
 
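The `MakeVjp` helper in the new test is an ordinary GradientTape pattern; outside the test harness it might look like the following sketch (my paraphrase, using only the public `tf.GradientTape` and `tf.nest` APIs):

    import tensorflow as tf

    def make_vjp(f, *x):
      """Returns f(*x) and a function mapping output cotangents to input cotangents."""
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(tf.nest.flatten(x))
        y = f(*x)

      def vjp(dy):
        return tape.gradient(y, x, output_gradients=dy)

      return y, vjp

    y, vjp = make_vjp(lambda a, b: a * b, tf.constant(2.), tf.constant(3.))
    da, db = vjp(tf.constant(1.))   # 3.0 and 2.0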