Fix tf.recompute_grad(f()) to work with f() that outputs sequence of tensors.
PiperOrigin-RevId: 288851398 Change-Id: If2179deac3bfa0a0c9d881b7b7cf740c680b4d66
This commit is contained in:
parent
d80fda0877
commit
eabf8538a0
@ -455,8 +455,8 @@ def _eager_mode_decorator(f, args, kwargs):
|
|||||||
def recompute_grad(f):
|
def recompute_grad(f):
|
||||||
"""An eager-compatible version of recompute_grad.
|
"""An eager-compatible version of recompute_grad.
|
||||||
|
|
||||||
For f(*args, **kwargs), this supports gradients with respect to args, or to
|
For f(*args, **kwargs), this supports gradients with respect to args or
|
||||||
gradients with respect to any variables residing in the kwarg 'variables'.
|
kwargs, but kwargs are currently only supported in eager-mode.
|
||||||
Note that for keras layer and model objects, this is handled automatically.
|
Note that for keras layer and model objects, this is handled automatically.
|
||||||
|
|
||||||
Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
|
Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
|
||||||
@ -479,19 +479,20 @@ def recompute_grad(f):
|
|||||||
"""Inner function closure for calculating gradients."""
|
"""Inner function closure for calculating gradients."""
|
||||||
result = f(*args, **kwargs)
|
result = f(*args, **kwargs)
|
||||||
|
|
||||||
def grad(dresult, variables=None):
|
def grad(*dresult, **grad_kwargs):
|
||||||
"""Gradient function calculation for inner function."""
|
"""Gradient function calculation for inner function."""
|
||||||
|
variables = grad_kwargs.get("variables")
|
||||||
with backprop.GradientTape() as t:
|
with backprop.GradientTape() as t:
|
||||||
t.watch(args)
|
t.watch(args)
|
||||||
if variables is not None:
|
if variables is not None:
|
||||||
t.watch(variables)
|
t.watch(variables)
|
||||||
with ops.control_dependencies([dresult]):
|
with ops.control_dependencies(dresult):
|
||||||
result = f(*args, **kwargs)
|
result = f(*args, **kwargs)
|
||||||
kw_vars = []
|
kw_vars = []
|
||||||
if variables is not None:
|
if variables is not None:
|
||||||
kw_vars = list(variables)
|
kw_vars = list(variables)
|
||||||
grads = t.gradient(
|
grads = t.gradient(
|
||||||
result, list(args) + kw_vars, output_gradients=[dresult])
|
result, list(args) + kw_vars, output_gradients=dresult)
|
||||||
return grads[:len(args)], grads[len(args):]
|
return grads[:len(args)], grads[len(args):]
|
||||||
|
|
||||||
return result, grad
|
return result, grad
|
||||||
|
@ -1405,6 +1405,9 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
|
|||||||
def TestFn(inputs, input_vars):
|
def TestFn(inputs, input_vars):
|
||||||
return inputs * input_vars
|
return inputs * input_vars
|
||||||
|
|
||||||
|
def TestFnSeq(inputs, input_vars):
|
||||||
|
return (inputs * input_vars, inputs * input_vars * 2.0)
|
||||||
|
|
||||||
with variable_scope.variable_scope("test", use_resource=True):
|
with variable_scope.variable_scope("test", use_resource=True):
|
||||||
test_var = variable_scope.get_variable(
|
test_var = variable_scope.get_variable(
|
||||||
name="test_var",
|
name="test_var",
|
||||||
@ -1429,6 +1432,21 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
|
|||||||
for g, g_re in zip(grads, grads_re):
|
for g, g_re in zip(grads, grads_re):
|
||||||
self.assertAllClose(g, g_re)
|
self.assertAllClose(g, g_re)
|
||||||
|
|
||||||
|
# Regression test for wrapping sequence outputting functions.
|
||||||
|
grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
|
||||||
|
test_input)
|
||||||
|
grads_re = self.evaluate(grads_re)
|
||||||
|
grads = self.evaluate(grads)
|
||||||
|
for g, g_re in zip(grads, grads_re):
|
||||||
|
self.assertAllClose(g, g_re)
|
||||||
|
|
||||||
|
grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
|
||||||
|
test_var)
|
||||||
|
grads_re = self.evaluate(grads_re)
|
||||||
|
grads = self.evaluate(grads)
|
||||||
|
for g, g_re in zip(grads, grads_re):
|
||||||
|
self.assertAllClose(g, g_re)
|
||||||
|
|
||||||
|
|
||||||
class GradPassThroughTest(test_util.TensorFlowTestCase):
|
class GradPassThroughTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user