Fix tf.custom_gradient spurious error when the grad_fn uses varargs.

The documentation recommends using the signature `g(*grad_ys, variables=None)`
but I get an error with that signature. Same with `g(dy, *, variables=None)`.
The error is that the code that triggers the error does not look for `variables`
in the kwonlyargs.

PiperOrigin-RevId: 318966853
Change-Id: Icb06a9e47499bc62d3dc88a24847cda81f98a543
This commit is contained in:
A. Unique TensorFlower 2020-06-29 23:45:38 -07:00 committed by TensorFlower Gardener
parent 3eb7f91013
commit 6e69fc67a0
2 changed files with 28 additions and 0 deletions

View File

@ -372,6 +372,7 @@ def _graph_mode_decorator(f, args, kwargs):
grad_argspec = tf_inspect.getfullargspec(grad_fn)
variables_in_signature = ("variables" in grad_argspec.args or
"variables" in grad_argspec.kwonlyargs or
grad_argspec.varkw)
if variables and not variables_in_signature:
raise TypeError(
@ -440,6 +441,7 @@ def _eager_mode_decorator(f, args, kwargs):
]
grad_argspec = tf_inspect.getfullargspec(grad_fn)
if (variables and ("variables" not in grad_argspec.args) and
("variables" not in grad_argspec.kwonlyargs) and
not grad_argspec.varkw):
raise TypeError(
"@tf.custom_gradient grad_fn must accept keyword argument 'variables', "

View File

@ -1293,6 +1293,32 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with variable_scope.variable_scope("vs2", use_resource=False):
FResource(x)
@parameterized.parameters(True, False)
def testCustomGradientVariablesKwonlyArgs(self, anonymous_varargs):
with context.eager_mode():
x_captured = variables.Variable(3.) # Used by FuncMult
@custom_gradient.custom_gradient
def FuncMult(x):
def ActualGrad(dy, variables): # pylint: disable=redefined-outer-name
self.assertLen(variables, 1)
self.assertIs(variables[0], x_captured)
x_captured_grad = 5. * x * dy
return (4. * x_captured * dy, [x_captured_grad])
# Define the returned GradMult, using varargs; "variables" is kwonlyarg
if anonymous_varargs:
def GradMult(dy, *, variables=None): # pylint: disable=redefined-outer-name
return ActualGrad(dy, variables)
else:
def GradMult(*dys, variables=None): # pylint: disable=redefined-outer-name
return ActualGrad(dys[0], variables)
return x * x_captured, GradMult
x = variables.Variable(6.)
with backprop.GradientTape(persistent=True) as g:
y = FuncMult(x)
self.assertAllEqual(g.gradient(y, x), 4. * 3.)
def testWithNumpyInputs(self):
with context.eager_mode():