Fix spurious tf.custom_gradient error when the grad_fn uses varargs.

The documentation recommends the signature `g(*grad_ys, variables=None)`, but using it raises a TypeError, as does `g(dy, *, variables=None)`. The cause is that the code raising the error checks the grad_fn's positional arguments and **kwargs for `variables` but never looks in its keyword-only arguments (kwonlyargs).

PiperOrigin-RevId: 318966853
Change-Id: Icb06a9e47499bc62d3dc88a24847cda81f98a543
Parent: 3eb7f91013
Commit: 6e69fc67a0
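For context, a minimal sketch of the failure this change addresses (a hypothetical repro using the public TF 2.x API; not part of the commit):

import tensorflow as tf

v = tf.Variable(3.)

@tf.custom_gradient
def f(x):
  # Documented signature: 'variables' is keyword-only after *grad_ys.
  def grad(*grad_ys, variables=None):
    dy = grad_ys[0]
    return dy * v, [dy * x]  # grads w.r.t. x and w.r.t. the captured variable
  return x * v, grad

x = tf.constant(2.)
with tf.GradientTape() as tape:
  tape.watch(x)
  # Before this fix, this call raised in eager mode:
  #   TypeError: @tf.custom_gradient grad_fn must accept keyword
  #   argument 'variables', ...
  # because the signature check ignored kwonlyargs.
  y = f(x)
print(tape.gradient(y, x))  # 3.0 after the fix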
@@ -372,6 +372,7 @@ def _graph_mode_decorator(f, args, kwargs):
   grad_argspec = tf_inspect.getfullargspec(grad_fn)
   variables_in_signature = ("variables" in grad_argspec.args or
+                            "variables" in grad_argspec.kwonlyargs or
                             grad_argspec.varkw)
   if variables and not variables_in_signature:
     raise TypeError(
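To see why the new clause is needed: for a keyword-only parameter, the name appears in getfullargspec's kwonlyargs field, not in args. A quick illustration with the stdlib inspect module, which tf_inspect mirrors here:

import inspect

def g1(dy, variables=None): pass        # positional-or-keyword
def g2(*grad_ys, variables=None): pass  # keyword-only (after *args)
def g3(dy, *, variables=None): pass     # keyword-only (after bare *)
def g4(dy, **kwargs): pass              # absorbed by **kwargs

for g in (g1, g2, g3, g4):
  spec = inspect.getfullargspec(g)
  print(g.__name__, spec.args, spec.kwonlyargs, spec.varkw)

# g1 ['dy', 'variables'] [] None  -> caught by the 'args' clause
# g2 [] ['variables'] None        -> needs the new 'kwonlyargs' clause
# g3 ['dy'] ['variables'] None    -> needs the new 'kwonlyargs' clause
# g4 ['dy'] [] kwargs             -> caught by the 'varkw' clause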
@@ -440,6 +441,7 @@ def _eager_mode_decorator(f, args, kwargs):
   ]
   grad_argspec = tf_inspect.getfullargspec(grad_fn)
   if (variables and ("variables" not in grad_argspec.args) and
+      ("variables" not in grad_argspec.kwonlyargs) and
       not grad_argspec.varkw):
     raise TypeError(
         "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
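The eager-mode check is the same predicate in negated form. As a standalone sketch (stdlib inspect standing in for tf_inspect, with a hypothetical helper name), the accepted grad_fn shapes are:

import inspect

def accepts_variables(grad_fn):
  # Mirrors the patched check: 'variables' may be a regular argument,
  # a keyword-only argument, or absorbed by **kwargs.
  spec = inspect.getfullargspec(grad_fn)
  return ("variables" in spec.args or
          "variables" in spec.kwonlyargs or
          spec.varkw is not None)

assert accepts_variables(lambda dy, variables=None: None)
assert accepts_variables(lambda *dys, variables=None: None)   # fixed by this commit
assert accepts_variables(lambda dy, *, variables=None: None)  # fixed by this commit
assert accepts_variables(lambda dy, **kwargs: None)
assert not accepts_variables(lambda dy: None)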
@@ -1293,6 +1293,32 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       with variable_scope.variable_scope("vs2", use_resource=False):
         FResource(x)
 
+  @parameterized.parameters(True, False)
+  def testCustomGradientVariablesKwonlyArgs(self, anonymous_varargs):
+    with context.eager_mode():
+      x_captured = variables.Variable(3.)  # Used by FuncMult
+      @custom_gradient.custom_gradient
+      def FuncMult(x):
+        def ActualGrad(dy, variables):  # pylint: disable=redefined-outer-name
+          self.assertLen(variables, 1)
+          self.assertIs(variables[0], x_captured)
+          x_captured_grad = 5. * x * dy
+          return (4. * x_captured * dy, [x_captured_grad])
+        # Define the returned GradMult, using varargs; "variables" is kwonlyarg
+        if anonymous_varargs:
+          def GradMult(dy, *, variables=None):  # pylint: disable=redefined-outer-name
+            return ActualGrad(dy, variables)
+        else:
+          def GradMult(*dys, variables=None):  # pylint: disable=redefined-outer-name
+            return ActualGrad(dys[0], variables)
+
+        return x * x_captured, GradMult
+
+      x = variables.Variable(6.)
+      with backprop.GradientTape(persistent=True) as g:
+        y = FuncMult(x)
+      self.assertAllEqual(g.gradient(y, x), 4. * 3.)
+
   def testWithNumpyInputs(self):
     with context.eager_mode():