Merge pull request #44373 from fsx950223:fix_warning
PiperOrigin-RevId: 339946530 Change-Id: I9ff4d4a380b6f2c2181edfc76cbb41af09abf0e2
This commit is contained in:
commit
662c0c2ece
@ -1594,6 +1594,35 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertIn('gradient_tape/my_scope/', op.name)
|
self.assertIn('gradient_tape/my_scope/', op.name)
|
||||||
self.assertEqual(num_sin_ops_found, 2)
|
self.assertEqual(num_sin_ops_found, 2)
|
||||||
|
|
||||||
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
|
def testRecomputeGradWithDifferentShape(self):
|
||||||
|
|
||||||
|
@custom_gradient.recompute_grad
|
||||||
|
def outer(x):
|
||||||
|
return [x[0] + 1, x[1] + 1]
|
||||||
|
|
||||||
|
x = [
|
||||||
|
variables.Variable([1.0, 2.0], name='a'),
|
||||||
|
variables.Variable(1.0, name='b')
|
||||||
|
]
|
||||||
|
with backprop.GradientTape():
|
||||||
|
y = outer(x)
|
||||||
|
self.assertAllEqual(y[0], [2.0, 3.0])
|
||||||
|
self.assertAllEqual(y[1], 2.0)
|
||||||
|
|
||||||
|
@custom_gradient.recompute_grad
|
||||||
|
def outer_dict(x):
|
||||||
|
for key in x.keys():
|
||||||
|
x[key] = x[key] + 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
x = {x[0].ref(): x[0], x[1].ref(): x[1]}
|
||||||
|
with backprop.GradientTape():
|
||||||
|
y = outer_dict(x)
|
||||||
|
y = list(y.values())
|
||||||
|
self.assertAllEqual(y[0], [2.0, 3.0])
|
||||||
|
self.assertAllEqual(y[1], 2.0)
|
||||||
|
|
||||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
|
def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
|
||||||
|
|
||||||
|
@ -524,7 +524,7 @@ def recompute_grad(f):
|
|||||||
# Gradient calculation for reverse mode autodiff.
|
# Gradient calculation for reverse mode autodiff.
|
||||||
variables = grad_kwargs.get("variables")
|
variables = grad_kwargs.get("variables")
|
||||||
with backprop.GradientTape() as t:
|
with backprop.GradientTape() as t:
|
||||||
id_args = [gen_array_ops.identity(x) for x in args]
|
id_args = nest.map_structure(gen_array_ops.identity, args)
|
||||||
t.watch(id_args)
|
t.watch(id_args)
|
||||||
if variables is not None:
|
if variables is not None:
|
||||||
t.watch(variables)
|
t.watch(variables)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user