Fix `tf.recompute_grad(f)` to work with functions `f` that output a sequence of tensors.
PiperOrigin-RevId: 288851398 Change-Id: If2179deac3bfa0a0c9d881b7b7cf740c680b4d66
This commit is contained in:
parent
d80fda0877
commit
eabf8538a0
@ -455,8 +455,8 @@ def _eager_mode_decorator(f, args, kwargs):
|
||||
def recompute_grad(f):
|
||||
"""An eager-compatible version of recompute_grad.
|
||||
|
||||
For f(*args, **kwargs), this supports gradients with respect to args, or to
|
||||
gradients with respect to any variables residing in the kwarg 'variables'.
|
||||
For f(*args, **kwargs), this supports gradients with respect to args or
|
||||
kwargs, but kwargs are currently only supported in eager-mode.
|
||||
Note that for keras layer and model objects, this is handled automatically.
|
||||
|
||||
Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
|
||||
@ -479,19 +479,20 @@ def recompute_grad(f):
|
||||
"""Inner function closure for calculating gradients."""
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
def grad(dresult, variables=None):
|
||||
def grad(*dresult, **grad_kwargs):
|
||||
"""Gradient function calculation for inner function."""
|
||||
variables = grad_kwargs.get("variables")
|
||||
with backprop.GradientTape() as t:
|
||||
t.watch(args)
|
||||
if variables is not None:
|
||||
t.watch(variables)
|
||||
with ops.control_dependencies([dresult]):
|
||||
with ops.control_dependencies(dresult):
|
||||
result = f(*args, **kwargs)
|
||||
kw_vars = []
|
||||
if variables is not None:
|
||||
kw_vars = list(variables)
|
||||
grads = t.gradient(
|
||||
result, list(args) + kw_vars, output_gradients=[dresult])
|
||||
result, list(args) + kw_vars, output_gradients=dresult)
|
||||
return grads[:len(args)], grads[len(args):]
|
||||
|
||||
return result, grad
|
||||
|
@ -1405,6 +1405,9 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
|
||||
def TestFn(inputs, input_vars):
|
||||
return inputs * input_vars
|
||||
|
||||
def TestFnSeq(inputs, input_vars):
|
||||
return (inputs * input_vars, inputs * input_vars * 2.0)
|
||||
|
||||
with variable_scope.variable_scope("test", use_resource=True):
|
||||
test_var = variable_scope.get_variable(
|
||||
name="test_var",
|
||||
@ -1429,6 +1432,21 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
|
||||
for g, g_re in zip(grads, grads_re):
|
||||
self.assertAllClose(g, g_re)
|
||||
|
||||
# Regression test for wrapping sequence outputting functions.
|
||||
grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
|
||||
test_input)
|
||||
grads_re = self.evaluate(grads_re)
|
||||
grads = self.evaluate(grads)
|
||||
for g, g_re in zip(grads, grads_re):
|
||||
self.assertAllClose(g, g_re)
|
||||
|
||||
grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
|
||||
test_var)
|
||||
grads_re = self.evaluate(grads_re)
|
||||
grads = self.evaluate(grads)
|
||||
for g, g_re in zip(grads, grads_re):
|
||||
self.assertAllClose(g, g_re)
|
||||
|
||||
|
||||
class GradPassThroughTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user