diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 87fad9b9b1f..5be9a533dc4 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1594,6 +1594,35 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
     self.assertIn('gradient_tape/my_scope/', op.name)
     self.assertEqual(num_sin_ops_found, 2)
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testRecomputeGradWithDifferentShape(self):
+
+    @custom_gradient.recompute_grad
+    def outer(x):
+      return [x[0] + 1, x[1] + 1]
+
+    x = [
+        variables.Variable([1.0, 2.0], name='a'),
+        variables.Variable(1.0, name='b')
+    ]
+    with backprop.GradientTape():
+      y = outer(x)
+      self.assertAllEqual(y[0], [2.0, 3.0])
+      self.assertAllEqual(y[1], 2.0)
+
+    @custom_gradient.recompute_grad
+    def outer_dict(x):
+      for key in x.keys():
+        x[key] = x[key] + 1
+      return x
+
+    x = {x[0].ref(): x[0], x[1].ref(): x[1]}
+    with backprop.GradientTape():
+      y = outer_dict(x)
+      y = list(y.values())
+      self.assertAllEqual(y[0], [2.0, 3.0])
+      self.assertAllEqual(y[1], 2.0)
+
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
 
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 3e38f68a0f7..5d7f605b884 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -524,7 +524,7 @@ def recompute_grad(f):
         # Gradient calculation for reverse mode autodiff.
         variables = grad_kwargs.get("variables")
         with backprop.GradientTape() as t:
-          id_args = [gen_array_ops.identity(x) for x in args]
+          id_args = nest.map_structure(gen_array_ops.identity, args)
          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
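
Note on the fix (annotation, not part of the patch): the old list comprehension in `recompute_grad` mapped `gen_array_ops.identity` over only the top level of `args`. When an argument is itself a nested structure, as with the list of differently shaped variables or the dict in the new tests, `identity` receives the whole structure and fails while trying to convert it into a single tensor. `nest.map_structure` instead recurses to the leaf tensors and preserves the nesting. A minimal standalone sketch of the difference, using the public `tf.identity` and `tf.nest` APIs in place of the internal `gen_array_ops` and `nest` modules; the `args` value here is a made-up example:

```python
import tensorflow as tf

# One positional argument that is itself a nested structure, mirroring the
# new test case: a vector and a scalar, so it cannot be packed into a
# single rectangular tensor.
args = ([tf.constant([1.0, 2.0]), tf.constant(1.0)],)

# Old behavior: the list comprehension only iterates the top level, so the
# inner list itself is handed to identity, which tries (and fails) to
# convert the ragged list into one tensor:
#   id_args = [tf.identity(x) for x in args]   # raises
#
# New behavior: map_structure recurses to the leaf tensors and returns a
# result with the same nesting, which GradientTape.watch can then traverse.
id_args = tf.nest.map_structure(tf.identity, args)
print(tf.nest.flatten(id_args))  # two identity tensors, nesting preserved
```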