GradientTape.jacobian/batch_jacobian: allow calls inside the GradientTape's scope
Since they eventually call GradientTape.gradient anyway, they inherit its existing warning: a persistent tape that records its own gradient computations uses extra memory.

Fixes #34260. Fixes #41510.

PiperOrigin-RevId: 357251843
Change-Id: I913d5dee058b01a1e3fbb830c246d4e7a2133cd5
parent 2300294ad4
commit 8e4ba814c2

tensorflow/python/eager
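For orientation, a minimal sketch (not part of this commit) of the call pattern the change enables, mirroring the new tests added below; the values are illustrative only:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = x * x   # x**2
  z = y * y   # x**4
  # With this change, jacobian() can be called while the tape's scope is
  # still open, instead of only after exiting the `with` block.
  # On a persistent tape the jacobian computation is itself recorded,
  # which takes extra memory (the warning mentioned above).
  j = tape.jacobian(z, x)   # dz/dx = 4 * x**3 = 108.0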
@@ -873,6 +873,18 @@ class GradientTape(object):
     tape.pop_tape(self._tape)
     self._recording = False

+  @tf_contextlib.contextmanager
+  def _ensure_recording(self):
+    """Ensures that this tape is recording."""
+    if not self._recording:
+      try:
+        self._push_tape()
+        yield
+      finally:
+        self._pop_tape()
+    else:
+      yield
+
   def watch(self, tensor):
     """Ensures that `tensor` is being traced by this tape.

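The new helper pushes the tape only when it is not already recording, so the jacobian/batch_jacobian code paths behave the same whether they run inside or outside the tape's `with` block. A rough standalone analogue of the pattern (plain contextlib, hypothetical Recorder class, not TensorFlow API):

import contextlib

class Recorder:
  """Toy stand-in for GradientTape's recording state (illustrative only)."""

  def __init__(self):
    self._recording = False

  def _push(self):
    self._recording = True

  def _pop(self):
    self._recording = False

  @contextlib.contextmanager
  def _ensure_recording(self):
    # Push/pop only when not already recording, so the context manager is
    # safe to use both inside and outside an active recording scope.
    if not self._recording:
      try:
        self._push()
        yield
      finally:
        self._pop()
    else:
      yield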
@@ -1144,14 +1156,12 @@ class GradientTape(object):
     target_shape = array_ops.shape(target)
     # Note that we push and pop the tape here and below. This is needed since we
     # need gradients through the enclosed operations.
-    self._push_tape()
-    target = array_ops.reshape(target, [-1])
-    self._pop_tape()
+    with self._ensure_recording():
+      target = array_ops.reshape(target, [-1])

     def loop_fn(i):
-      self._push_tape()
-      y = array_ops.gather(target, i)
-      self._pop_tape()
+      with self._ensure_recording():
+        y = array_ops.gather(target, i)
       return self.gradient(y, flat_sources,
                            unconnected_gradients=unconnected_gradients)

@@ -1285,16 +1295,14 @@ class GradientTape(object):
     # Flatten target to 2-D.
     # Note that we push and pop the tape here and below. This is needed since we
     # need gradients through the enclosed operations.
-    self._push_tape()
-    with ops.control_dependencies(
-        [check_ops.assert_equal(batch_size, source_shape[0])]):
-      target = array_ops.reshape(target, [batch_size, target_row_size])
-    self._pop_tape()
+    with self._ensure_recording():
+      with ops.control_dependencies(
+          [check_ops.assert_equal(batch_size, source_shape[0])]):
+        target = array_ops.reshape(target, [batch_size, target_row_size])

     def loop_fn(i):
-      self._push_tape()
-      y = array_ops.gather(target, i, axis=1)
-      self._pop_tape()
+      with self._ensure_recording():
+        y = array_ops.gather(target, i, axis=1)
       return self.gradient(y, source,
                            unconnected_gradients=unconnected_gradients)

@@ -854,6 +854,24 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
         RuntimeError, 'A non-persistent GradientTape can only'):
       g.jacobian(y, [x])

+  @test_util.assert_no_new_tensors
+  def testJacobianInsideGradientTapeScope(self):
+    with backprop.GradientTape() as g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = x * x
+      z = y * y
+      self.assertAllClose(4. * 3. ** 3., g.jacobian(z, x))
+
+  @test_util.assert_no_new_tensors
+  def testBatchJacobianInsideGradientTapeScope(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant([[3.0]])
+      g.watch(x)
+      y = x * x
+      z = y * y
+      self.assertAllClose([[[4. * 3. ** 3.]]], g.batch_jacobian(z, x))
+
   @test_util.assert_no_new_tensors
   def testGradientTapeBatchJacobianCalledMultipleTimes(self):
     with backprop.GradientTape() as g: