From eabf8538a081b97e0d5eb06df9558afca4463c3f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 9 Jan 2020 00:57:44 -0800 Subject: [PATCH] Fix tf.recompute_grad(f()) to work with f() that outputs sequence of tensors. PiperOrigin-RevId: 288851398 Change-Id: If2179deac3bfa0a0c9d881b7b7cf740c680b4d66 --- tensorflow/python/ops/custom_gradient.py | 11 ++++++----- tensorflow/python/ops/gradients_test.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index 6421661d615..a5bdba123ef 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -455,8 +455,8 @@ def _eager_mode_decorator(f, args, kwargs): def recompute_grad(f): """An eager-compatible version of recompute_grad. - For f(*args, **kwargs), this supports gradients with respect to args, or to - gradients with respect to any variables residing in the kwarg 'variables'. + For f(*args, **kwargs), this supports gradients with respect to args or + kwargs, but kwargs are currently only supported in eager-mode. Note that for keras layer and model objects, this is handled automatically. Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not @@ -479,19 +479,20 @@ def recompute_grad(f): """Inner function closure for calculating gradients.""" result = f(*args, **kwargs) - def grad(dresult, variables=None): + def grad(*dresult, **grad_kwargs): """Gradient function calculation for inner function.""" + variables = grad_kwargs.get("variables") with backprop.GradientTape() as t: t.watch(args) if variables is not None: t.watch(variables) - with ops.control_dependencies([dresult]): + with ops.control_dependencies(dresult): result = f(*args, **kwargs) kw_vars = [] if variables is not None: kw_vars = list(variables) grads = t.gradient( - result, list(args) + kw_vars, output_gradients=[dresult]) + result, list(args) + kw_vars, output_gradients=dresult) return grads[:len(args)], grads[len(args):] return result, grad diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index e3886dd7ca2..139f7afc47f 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -1405,6 +1405,9 @@ class VariablesGradientTest(test_util.TensorFlowTestCase): def TestFn(inputs, input_vars): return inputs * input_vars + def TestFnSeq(inputs, input_vars): + return (inputs * input_vars, inputs * input_vars * 2.0) + with variable_scope.variable_scope("test", use_resource=True): test_var = variable_scope.get_variable( name="test_var", @@ -1429,6 +1432,21 @@ class VariablesGradientTest(test_util.TensorFlowTestCase): for g, g_re in zip(grads, grads_re): self.assertAllClose(g, g_re) + # Regression test for wrapping sequence outputting functions. + grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq, + test_input) + grads_re = self.evaluate(grads_re) + grads = self.evaluate(grads) + for g, g_re in zip(grads, grads_re): + self.assertAllClose(g, g_re) + + grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq, + test_var) + grads_re = self.evaluate(grads_re) + grads = self.evaluate(grads) + for g, g_re in zip(grads, grads_re): + self.assertAllClose(g, g_re) + class GradPassThroughTest(test_util.TensorFlowTestCase):