Fix tf.recompute_grad(f) to work when f outputs a sequence of tensors.

PiperOrigin-RevId: 288851398
Change-Id: If2179deac3bfa0a0c9d881b7b7cf740c680b4d66
This commit is contained in:
A. Unique TensorFlower 2020-01-09 00:57:44 -08:00 committed by TensorFlower Gardener
parent d80fda0877
commit eabf8538a0
2 changed files with 24 additions and 5 deletions

View File

@ -455,8 +455,8 @@ def _eager_mode_decorator(f, args, kwargs):
def recompute_grad(f):
"""An eager-compatible version of recompute_grad.
For f(*args, **kwargs), this supports gradients with respect to args, or to
gradients with respect to any variables residing in the kwarg 'variables'.
For f(*args, **kwargs), this supports gradients with respect to args or
kwargs, but kwargs are currently only supported in eager-mode.
Note that for keras layer and model objects, this is handled automatically.
Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
@ -479,19 +479,20 @@ def recompute_grad(f):
"""Inner function closure for calculating gradients."""
result = f(*args, **kwargs)
def grad(dresult, variables=None):
def grad(*dresult, **grad_kwargs):
"""Gradient function calculation for inner function."""
variables = grad_kwargs.get("variables")
with backprop.GradientTape() as t:
t.watch(args)
if variables is not None:
t.watch(variables)
with ops.control_dependencies([dresult]):
with ops.control_dependencies(dresult):
result = f(*args, **kwargs)
kw_vars = []
if variables is not None:
kw_vars = list(variables)
grads = t.gradient(
result, list(args) + kw_vars, output_gradients=[dresult])
result, list(args) + kw_vars, output_gradients=dresult)
return grads[:len(args)], grads[len(args):]
return result, grad

View File

@ -1405,6 +1405,9 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
def TestFn(inputs, input_vars):
return inputs * input_vars
def TestFnSeq(inputs, input_vars):
return (inputs * input_vars, inputs * input_vars * 2.0)
with variable_scope.variable_scope("test", use_resource=True):
test_var = variable_scope.get_variable(
name="test_var",
@ -1429,6 +1432,21 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
for g, g_re in zip(grads, grads_re):
self.assertAllClose(g, g_re)
# Regression test for wrapping sequence outputting functions.
grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
test_input)
grads_re = self.evaluate(grads_re)
grads = self.evaluate(grads)
for g, g_re in zip(grads, grads_re):
self.assertAllClose(g, g_re)
grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
test_var)
grads_re = self.evaluate(grads_re)
grads = self.evaluate(grads)
for g, g_re in zip(grads, grads_re):
self.assertAllClose(g, g_re)
class GradPassThroughTest(test_util.TensorFlowTestCase):