Merge pull request #44373 from fsx950223:fix_warning
PiperOrigin-RevId: 339946530 Change-Id: I9ff4d4a380b6f2c2181edfc76cbb41af09abf0e2
This commit is contained in:
commit
662c0c2ece
@ -1594,6 +1594,35 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertIn('gradient_tape/my_scope/', op.name)
|
||||
self.assertEqual(num_sin_ops_found, 2)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testRecomputeGradWithDifferentShape(self):
|
||||
|
||||
@custom_gradient.recompute_grad
|
||||
def outer(x):
|
||||
return [x[0] + 1, x[1] + 1]
|
||||
|
||||
x = [
|
||||
variables.Variable([1.0, 2.0], name='a'),
|
||||
variables.Variable(1.0, name='b')
|
||||
]
|
||||
with backprop.GradientTape():
|
||||
y = outer(x)
|
||||
self.assertAllEqual(y[0], [2.0, 3.0])
|
||||
self.assertAllEqual(y[1], 2.0)
|
||||
|
||||
@custom_gradient.recompute_grad
|
||||
def outer_dict(x):
|
||||
for key in x.keys():
|
||||
x[key] = x[key] + 1
|
||||
return x
|
||||
|
||||
x = {x[0].ref(): x[0], x[1].ref(): x[1]}
|
||||
with backprop.GradientTape():
|
||||
y = outer_dict(x)
|
||||
y = list(y.values())
|
||||
self.assertAllEqual(y[0], [2.0, 3.0])
|
||||
self.assertAllEqual(y[1], 2.0)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
|
||||
|
||||
|
@ -524,7 +524,7 @@ def recompute_grad(f):
|
||||
# Gradient calculation for reverse mode autodiff.
|
||||
variables = grad_kwargs.get("variables")
|
||||
with backprop.GradientTape() as t:
|
||||
id_args = [gen_array_ops.identity(x) for x in args]
|
||||
id_args = nest.map_structure(gen_array_ops.identity, args)
|
||||
t.watch(id_args)
|
||||
if variables is not None:
|
||||
t.watch(variables)
|
||||
|
Loading…
x
Reference in New Issue
Block a user