diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index f42e000b4b9..185f0d04782 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -372,6 +372,7 @@ def _graph_mode_decorator(f, args, kwargs):
   grad_argspec = tf_inspect.getfullargspec(grad_fn)
   variables_in_signature = ("variables" in grad_argspec.args or
+                            "variables" in grad_argspec.kwonlyargs or
                             grad_argspec.varkw)
   if variables and not variables_in_signature:
     raise TypeError(
@@ -440,6 +441,7 @@ def _eager_mode_decorator(f, args, kwargs):
   ]
   grad_argspec = tf_inspect.getfullargspec(grad_fn)
   if (variables and ("variables" not in grad_argspec.args) and
+      ("variables" not in grad_argspec.kwonlyargs) and
       not grad_argspec.varkw):
     raise TypeError(
         "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 81140af6e79..78fbcdd6e6f 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -1293,6 +1293,32 @@ class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         with variable_scope.variable_scope("vs2", use_resource=False):
           FResource(x)
 
+  @parameterized.parameters(True, False)
+  def testCustomGradientVariablesKwonlyArgs(self, anonymous_varargs):
+    with context.eager_mode():
+      x_captured = variables.Variable(3.)  # Used by FuncMult
+      @custom_gradient.custom_gradient
+      def FuncMult(x):
+        def ActualGrad(dy, variables):  # pylint: disable=redefined-outer-name
+          self.assertLen(variables, 1)
+          self.assertIs(variables[0], x_captured)
+          x_captured_grad = 5. * x * dy
+          return (4. * x_captured * dy, [x_captured_grad])
+        # Define the returned GradMult, using varargs; "variables" is kwonlyarg
+        if anonymous_varargs:
+          def GradMult(dy, *, variables=None):  # pylint: disable=redefined-outer-name
+            return ActualGrad(dy, variables)
+        else:
+          def GradMult(*dys, variables=None):  # pylint: disable=redefined-outer-name
+            return ActualGrad(dys[0], variables)
+
+        return x * x_captured, GradMult
+
+      x = variables.Variable(6.)
+      with backprop.GradientTape(persistent=True) as g:
+        y = FuncMult(x)
+      self.assertAllEqual(g.gradient(y, x), 4. * 3.)
+
  def testWithNumpyInputs(self):
    with context.eager_mode():
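For reference, the behavior this patch enables: `inspect.getfullargspec` reports keyword-only parameters in `kwonlyargs`, not `args`, so a grad_fn that declared `variables` as keyword-only previously tripped the `TypeError` above even though the decorator passes it as the keyword argument `variables=...`. A minimal sketch of the now-accepted signature (the names `scale` and `w` are illustrative, not from the patch):

```python
import tensorflow as tf

w = tf.Variable(2.0)  # captured by the forward function, not an input

@tf.custom_gradient
def scale(x):
  def grad(dy, *, variables=None):  # "variables" is keyword-only
    # With captured variables, grad_fn must return a pair:
    # (gradients w.r.t. the inputs, list of gradients for `variables`).
    return dy * w, [dy * x]
  return x * w, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = scale(x)
print(tape.gradient(y, x))  # 2.0, i.e. d(x * w)/dx == w
```

Before this change, calling `scale` in eager mode raised the quoted `TypeError`, because the signature check only consulted `grad_argspec.args` and `grad_argspec.varkw`.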