From 55987dbb42b8f066154759785e240c4fe8b825a9 Mon Sep 17 00:00:00 2001
From: Christian Sigg
Date: Fri, 29 May 2020 10:44:36 -0700
Subject: [PATCH] Adds `custom_grad` and `vjp` to tf_numpy/extensions and
 trax/math. Also changes tf.custom_gradient to allow nested structures as
 inputs (currently it only allows a list of tensors).

PiperOrigin-RevId: 313808985
Change-Id: Ibf58547ec938c31324156d45b49771ceeafd10ab
---
 tensorflow/python/ops/custom_gradient.py | 35 +++++++---------
 tensorflow/python/ops/gradients_test.py  | 52 ++++--------------------
 2 files changed, 22 insertions(+), 65 deletions(-)

diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 953bb252729..2a9194fb146 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -175,23 +175,20 @@ def custom_gradient(f=None):
 
   Args:
     f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
-       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
-         function.
-       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
-         operations in `f` to `x`.
+       - `x` is a sequence of `Tensor` inputs to the function.
+       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
+         TensorFlow operations in `f` to `x`.
        - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
-         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
-         of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is
-         a sequence of `Tensor`s the same size as (flattened) `y` holding the
-         initial value gradients for each `Tensor` in `y`.
-
-       In a pure mathematical sense, a vector-argument vector-valued function
-       `f`'s derivatives should be its Jacobian matrix `J`. Here we are
-       expressing the Jacobian `J` as a function `grad_fn` which defines how
-       `J` will transform a vector `grad_ys` when left-multiplied with it
-       (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
-       representation of a matrix is convenient to use for chain-rule
-       calculation (in e.g. the back-propagation algorithm).
+         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
+         to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
+         `Tensor`s the same size as `y` holding the initial value gradients for
+         each `Tensor` in `y`. In a pure mathematical sense, a vector-argument
+         vector-valued function `f`'s derivatives should be its Jacobian matrix
+         `J`. Here we are expressing the Jacobian `J` as a function `grad_fn`
+         which defines how `J` will transform a vector `grad_ys` when
+         left-multiplied with it (`grad_ys * J`). This functional representation
+         of a matrix is convenient to use for chain-rule calculation
+         (in e.g. the back-propagation algorithm).
 
        If `f` uses `Variable`s (that are not part of the inputs), i.e. through
        `get_variable`, then `grad_fn` should have
@@ -310,7 +307,7 @@ def _graph_mode_decorator(f, args, kwargs):
         "The custom_gradient decorator currently supports keywords "
         "arguments only when eager execution is enabled.")
   name = "CustomGradient-%s" % ops.uid()
-  args = nest.map_structure(ops.convert_to_tensor, args)
+  args = [ops.convert_to_tensor(x) for x in args]
 
   # Checking global and local variables attempts to ensure that no non-resource
   # Variables are added to the graph.
@@ -321,7 +318,6 @@ def _graph_mode_decorator(f, args, kwargs):
   ])
   with tape_lib.VariableWatcher() as variable_watcher:
     result, grad_fn = f(*args)
-  args = nest.flatten(args)
   after_vars = set([
       v.ref() for v in current_var_scope.global_variables() +
       current_var_scope.local_variables()
   ])
@@ -408,7 +404,6 @@ def _eager_mode_decorator(f, args, kwargs):
   """Implement custom gradient decorator for eager mode."""
   with tape_lib.VariableWatcher() as variable_watcher:
     result, grad_fn = f(*args, **kwargs)
-  args = nest.flatten(args)
   all_inputs = list(args) + list(kwargs.values())
   # The variables that grad_fn needs to return gradients for are the set of
   # variables used that are *not* part of the inputs.
@@ -448,7 +443,7 @@ def _eager_mode_decorator(f, args, kwargs):
       raise ValueError(
           "custom_gradient function expected to return", arg_count,
           "gradients but returned", len(flat_grads), "instead.")
-    return flat_grads + variable_grads
+    return nest.flatten(input_grads) + variable_grads
 
   tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
                             actual_grad_fn)
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 9a9ce72a557..a06be7af74b 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -60,7 +60,6 @@ from tensorflow.python.ops import variables
 from tensorflow.python.ops.nn_ops import bias_add
 from tensorflow.python.platform import googletest
 from tensorflow.python.ops import gradient_checker_v2
-from tensorflow.python.util import nest
 
 
 class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@@ -1040,7 +1039,7 @@ class GetDependentVariablesTest(test_util.TensorFlowTestCase):
       self.assertEqual(dependent_vars, [var])
 
 
-class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
+class CustomGradientTest(test_util.TensorFlowTestCase):
 
   def testCustomGradientTrivial(self):
 
@@ -1120,7 +1119,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       out = core_layers.dense(x, 3, use_bias=False)
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
-        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
+        self.assertEqual(1, len(variables))
         grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
         return grads[0], [array_ops.ones((4, 3))]
 
@@ -1147,7 +1146,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       out = core_layers.dense(x, 3, use_bias=False)
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
-        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
+        self.assertEqual(1, len(variables))
         grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
         return grads[0], [array_ops.ones((3, 3))]
 
@@ -1186,7 +1185,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
         del out_grad
-        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
+        self.assertEqual(1, len(variables))
         return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
 
@@ -1210,7 +1209,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
         del out_grad
-        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
+        self.assertEqual(1, len(variables))
         return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
 
       return out, Grad
@@ -1274,7 +1273,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       out = core_layers.dense(x, 3, use_bias=False)
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
-        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
+        self.assertEqual(1, len(variables))
         grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
         return grads[0], [array_ops.ones((4, 3))]
 
@@ -1285,7 +1284,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       out = F(x)
 
       def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
-        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
+        self.assertEqual(1, len(variables))
         grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
         return grads[0], [array_ops.ones((4, 3))]
 
@@ -1304,43 +1303,6 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       dw = sess.run(math_ops.reduce_sum(grads[1]))
       self.assertEqual(12., dw)
 
-  @parameterized.named_parameters(
-      [(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""),  # pylint: disable=g-complex-comprehension
-        x_struct, y_struct)
-       for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
-       for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
-      ])
-  @test_util.run_in_graph_and_eager_modes
-  def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
-    """Tests that custom_gradient can handle structured inputs/outputs."""
-    def Zeros(x):
-      return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
-    def GetStruct(x):
-      return nest.map_structure(lambda _: None, x)
-
-    def MakeVjp(f, *x):
-      with backprop.GradientTape(persistent=True) as tape:
-        tape.watch(nest.flatten(x))
-        y = f(*x)
-      def Vjp(dy):
-        return tape.gradient(y, x, output_gradients=dy)
-      return y, Vjp
-
-    @custom_gradient.custom_gradient
-    def F(*x):
-      self.assertEqual(x_struct, GetStruct(x))
-      def Vjp(*dy):
-        self.assertEqual(len(nest.flatten(y_struct)),
-                         len(nest.flatten(dy)))
-        return nest.flatten(Zeros(x_struct))
-      return Zeros(y_struct), Vjp
-
-    x, dy = Zeros([x_struct, y_struct])
-    y, vjp = MakeVjp(F, *x)
-    dx = vjp(dy)
-    self.assertEqual(x_struct, GetStruct(dx))
-    self.assertEqual(y_struct, GetStruct(y))
-
 
 class TensorListGradientsTest(test_util.TensorFlowTestCase):
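
Note (illustrative, not part of this patch): the docstring hunk above describes
`grad_fn` as a vector-Jacobian product (VJP) - it receives the upstream
gradients `grad_ys` and returns `grad_ys * J`. A minimal sketch of that
contract for a single-tensor input, assuming TF 2.x eager execution; it is the
numerically stable log(1 + exp(x)) example commonly used with
`tf.custom_gradient`:

    import tensorflow as tf

    @tf.custom_gradient
    def log1pexp(x):
      e = tf.exp(x)

      def grad(upstream):
        # VJP: upstream gradient (`grad_ys`) times the local derivative,
        # rewritten in a numerically stable form.
        return upstream * (1 - 1 / (1 + e))

      return tf.math.log(1 + e), grad

    x = tf.constant(100.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = log1pexp(x)
    print(tape.gradient(y, x))  # 1.0; the naive gradient would be nan here.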
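Note (illustrative, not part of this patch): the test hunks above exercise the
`grad_fn(out_grad, variables=None)` calling convention, where variables created
inside `f` (e.g. by `core_layers.dense`) are passed to `grad_fn` and their
gradients are returned as a second element. A rough sketch of that pattern,
assuming TF 2.x eager execution; `linear` and `kernel` are hypothetical names:

    import tensorflow as tf

    @tf.custom_gradient
    def linear(x):
      # `kernel` is created inside `f`, so it is not one of the inputs;
      # custom_gradient hands such captured variables to `grad` via the
      # `variables` keyword, and `grad` returns their gradients in a list.
      kernel = tf.Variable(tf.ones((4, 3)), name="kernel")
      out = tf.matmul(x, kernel)

      def grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        dx = tf.matmul(out_grad, variables[0], transpose_b=True)
        dkernel = tf.matmul(x, out_grad, transpose_a=True)
        return dx, [dkernel]

      return out, grad

    x = tf.ones((2, 4))
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = linear(x)
    dx = tape.gradient(y, x)  # shape (2, 4), every entry equal to 3.0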