Fix tf.recompute_grad(f) to work with an f that outputs a sequence of tensors.

PiperOrigin-RevId: 288851398
Change-Id: If2179deac3bfa0a0c9d881b7b7cf740c680b4d66
A. Unique TensorFlower 2020-01-09 00:57:44 -08:00 committed by TensorFlower Gardener
parent d80fda0877
commit eabf8538a0
2 changed files with 24 additions and 5 deletions
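
For context, a minimal sketch of the behavior this change enables (assumes TF 2.x eager execution; the names below are illustrative, not from the commit):

import tensorflow as tf

# f returns a sequence of tensors. Before this fix, the inner grad()
# produced by recompute_grad accepted only a single upstream gradient,
# so wrapping a multi-output f broke during backprop.
@tf.recompute_grad
def two_outputs(x):
  return x * 2.0, x * 3.0

x = tf.constant(4.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  a, b = two_outputs(x)
  loss = a + b
print(tape.gradient(loss, x))  # 5.0 = d(2x)/dx + d(3x)/dx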

tensorflow/python/ops/custom_gradient.py

@@ -455,8 +455,8 @@ def _eager_mode_decorator(f, args, kwargs):
 def recompute_grad(f):
   """An eager-compatible version of recompute_grad.
 
-  For f(*args, **kwargs), this supports gradients with respect to args, or to
-  gradients with respect to any variables residing in the kwarg 'variables'.
+  For f(*args, **kwargs), this supports gradients with respect to args or
+  kwargs, but kwargs are currently only supported in eager-mode.
   Note that for keras layer and model objects, this is handled automatically.
 
   Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
@@ -479,19 +479,20 @@ def recompute_grad(f):
     """Inner function closure for calculating gradients."""
     result = f(*args, **kwargs)
 
-    def grad(dresult, variables=None):
+    def grad(*dresult, **grad_kwargs):
       """Gradient function calculation for inner function."""
+      variables = grad_kwargs.get("variables")
       with backprop.GradientTape() as t:
         t.watch(args)
         if variables is not None:
           t.watch(variables)
-        with ops.control_dependencies([dresult]):
+        with ops.control_dependencies(dresult):
           result = f(*args, **kwargs)
       kw_vars = []
       if variables is not None:
         kw_vars = list(variables)
       grads = t.gradient(
-          result, list(args) + kw_vars, output_gradients=[dresult])
+          result, list(args) + kw_vars, output_gradients=dresult)
       return grads[:len(args)], grads[len(args):]
 
     return result, grad
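
The core of the fix is the inner grad signature: tf.custom_gradient passes one upstream gradient per output tensor, so a multi-output f needs *dresult instead of a single dresult, and the tuple can then be forwarded unchanged as output_gradients. A stripped-down illustration of that convention (a sketch of the documented custom_gradient behavior, not the TensorFlow source):

import tensorflow as tf

@tf.custom_gradient
def pair(x):
  def grad(*dresult):
    # One upstream gradient per output: (d_loss/d_y, d_loss/d_z).
    # recompute_grad's old grad(dresult, ...) could only bind the first.
    dy, dz = dresult
    return dy * 2.0 + dz * 3.0
  return (x * 2.0, x * 3.0), grad

x = tf.constant(1.0)
with tf.GradientTape() as t:
  t.watch(x)
  y, z = pair(x)
  s = y + z
print(t.gradient(s, x))  # 5.0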

tensorflow/python/ops/gradients_test.py

@@ -1405,6 +1405,9 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
     def TestFn(inputs, input_vars):
       return inputs * input_vars
 
+    def TestFnSeq(inputs, input_vars):
+      return (inputs * input_vars, inputs * input_vars * 2.0)
+
     with variable_scope.variable_scope("test", use_resource=True):
       test_var = variable_scope.get_variable(
           name="test_var",
@@ -1429,6 +1432,21 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
     for g, g_re in zip(grads, grads_re):
       self.assertAllClose(g, g_re)
 
+    # Regression test for wrapping sequence outputting functions.
+    grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
+                                                    test_input)
+    grads_re = self.evaluate(grads_re)
+    grads = self.evaluate(grads)
+    for g, g_re in zip(grads, grads_re):
+      self.assertAllClose(g, g_re)
+
+    grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
+                                                    test_var)
+    grads_re = self.evaluate(grads_re)
+    grads = self.evaluate(grads)
+    for g, g_re in zip(grads, grads_re):
+      self.assertAllClose(g, g_re)
+
 
 class GradPassThroughTest(test_util.TensorFlowTestCase):
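
The new test covers gradients with respect to both the input and the variable. A rough standalone equivalent of what it asserts (an assumption-laden sketch; the real _TestFnVariablesGradient helper differs in detail):

import tensorflow as tf

def test_fn_seq(inputs, input_vars):
  return (inputs * input_vars, inputs * input_vars * 2.0)

x = tf.constant(3.0)
v = tf.constant(2.0)

def grads_for(fn):
  with tf.GradientTape() as tape:
    tape.watch([x, v])
    out = fn(x, v)
    loss = tf.add_n([tf.reduce_sum(o) for o in out])
  return tape.gradient(loss, [x, v])

# Gradients through the recomputed wrapper should match the plain function.
for g, g_re in zip(grads_for(test_fn_seq),
                   grads_for(tf.recompute_grad(test_fn_seq))):
  tf.debugging.assert_near(g, g_re)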