diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 214ab9a393b..33b1651a040 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -238,6 +238,9 @@ def _graph_mode_decorator(f, *args, **kwargs):
   original_tensors = all_tensors
   with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
     all_tensors = array_ops.identity_n(all_tensors)
+
+  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]
+
   # Propagate handle data for happier shape inference for resource variables.
   for i, t in enumerate(original_tensors):
     if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 9b24473f57a..9d6ac46c049 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -1033,6 +1033,42 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
       self.assertAllEqual(g.eval(), [2.0])
       self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
 
+  def testRecursiveCustomGradient(self):
+    @custom_gradient.custom_gradient
+    def F(x):
+      out = core_layers.dense(x, 3, use_bias=False)
+
+      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
+        self.assertEqual(1, len(variables))
+        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
+        return grads[0], [array_ops.ones((4, 3))]
+
+      return out, Grad
+
+    @custom_gradient.custom_gradient
+    def DoubleF(x):
+      out = F(x)
+
+      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
+        self.assertEqual(1, len(variables))
+        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
+        return grads[0], [array_ops.ones((4, 3))]
+
+      return out, Grad
+    with ops.Graph().as_default():
+      x = array_ops.ones((2, 4))
+      with variable_scope.variable_scope("f", use_resource=True) as vs:
+        y = DoubleF(x)
+        all_vars = vs.global_variables()
+        assert len(all_vars) == 1
+      grads = gradients.gradients(y, [x, all_vars[0]])
+      for g in grads:
+        self.assertIsNotNone(g)
+      with session.Session() as sess:
+        self.evaluate(variables.global_variables_initializer())
+        dw = sess.run(math_ops.reduce_sum(grads[1]))
+        self.assertEqual(12., dw)
+
 
 
 class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):