GradientTape.jacobian/batch_jacobian: allow calls inside the GradientTape's scope

Since they eventually call GradientTape.gradient anyway, they'll inherit its warning that a persistent tape recording its own gradient computation takes extra memory.
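
A minimal sketch of the usage this change enables, mirroring the new tests added below (the public tf.GradientTape API is used here for illustration; this snippet is not part of the commit):

    import tensorflow as tf

    with tf.GradientTape() as tape:
      x = tf.constant(3.0)
      tape.watch(x)
      y = x * x
      z = y * y
      # Calling jacobian while the tape is still recording (inside the `with`
      # block) previously tried to re-enter the already-active tape; with this
      # change it is allowed.
      print(tape.jacobian(z, x))  # d(x**4)/dx at x = 3 -> 108.0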

Fixes .
Fixes .

PiperOrigin-RevId: 357251843
Change-Id: I913d5dee058b01a1e3fbb830c246d4e7a2133cd5
Allen Lavoie 2021-02-12 13:01:50 -08:00 committed by TensorFlower Gardener
parent 2300294ad4
commit 8e4ba814c2
2 changed files with 40 additions and 14 deletions
tensorflow/python/eager


@@ -873,6 +873,18 @@ class GradientTape(object):
     tape.pop_tape(self._tape)
     self._recording = False
 
+  @tf_contextlib.contextmanager
+  def _ensure_recording(self):
+    """Ensures that this tape is recording."""
+    if not self._recording:
+      try:
+        self._push_tape()
+        yield
+      finally:
+        self._pop_tape()
+    else:
+      yield
+
   def watch(self, tensor):
     """Ensures that `tensor` is being traced by this tape.
@@ -1144,14 +1156,12 @@ class GradientTape(object):
     target_shape = array_ops.shape(target)
     # Note that we push and pop the tape here and below. This is needed since we
     # need gradients through the enclosed operations.
-    self._push_tape()
-    target = array_ops.reshape(target, [-1])
-    self._pop_tape()
+    with self._ensure_recording():
+      target = array_ops.reshape(target, [-1])
 
     def loop_fn(i):
-      self._push_tape()
-      y = array_ops.gather(target, i)
-      self._pop_tape()
+      with self._ensure_recording():
+        y = array_ops.gather(target, i)
       return self.gradient(y, flat_sources,
                            unconnected_gradients=unconnected_gradients)
@@ -1285,16 +1295,14 @@ class GradientTape(object):
     # Flatten target to 2-D.
     # Note that we push and pop the tape here and below. This is needed since we
     # need gradients through the enclosed operations.
-    self._push_tape()
-    with ops.control_dependencies(
-        [check_ops.assert_equal(batch_size, source_shape[0])]):
-      target = array_ops.reshape(target, [batch_size, target_row_size])
-    self._pop_tape()
+    with self._ensure_recording():
+      with ops.control_dependencies(
+          [check_ops.assert_equal(batch_size, source_shape[0])]):
+        target = array_ops.reshape(target, [batch_size, target_row_size])
 
     def loop_fn(i):
-      self._push_tape()
-      y = array_ops.gather(target, i, axis=1)
-      self._pop_tape()
+      with self._ensure_recording():
+        y = array_ops.gather(target, i, axis=1)
       return self.gradient(y, source,
                            unconnected_gradients=unconnected_gradients)
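
The key to the change is that _ensure_recording is re-entrant: when jacobian or batch_jacobian runs inside the tape's `with` block, _recording is already True, so the guard simply yields instead of pushing the tape a second time. A standalone illustration of the pattern using plain contextlib (the Recorder class below is illustrative only, not part of TensorFlow):

    import contextlib

    class Recorder(object):
      """Minimal stand-in for the recording state machine in GradientTape."""

      def __init__(self):
        self._recording = False

      def _push_tape(self):
        if self._recording:
          raise ValueError("Tape is still recording.")
        self._recording = True

      def _pop_tape(self):
        self._recording = False

      @contextlib.contextmanager
      def _ensure_recording(self):
        """Records for the duration of the block unless already recording."""
        if not self._recording:
          try:
            self._push_tape()
            yield
          finally:
            self._pop_tape()
        else:
          yield

    r = Recorder()
    with r._ensure_recording():    # pushes: nothing was recording yet
      with r._ensure_recording():  # no-op: already recording
        pass
    assert not r._recording        # popped exactly once, by the outer guard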


@@ -854,6 +854,24 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
         RuntimeError, 'A non-persistent GradientTape can only'):
       g.jacobian(y, [x])
 
+  @test_util.assert_no_new_tensors
+  def testJacobianInsideGradientTapeScope(self):
+    with backprop.GradientTape() as g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = x * x
+      z = y * y
+      self.assertAllClose(4. * 3. ** 3., g.jacobian(z, x))
+
+  @test_util.assert_no_new_tensors
+  def testBatchJacobianInsideGradientTapeScope(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant([[3.0]])
+      g.watch(x)
+      y = x * x
+      z = y * y
+      self.assertAllClose([[[4. * 3. ** 3.]]], g.batch_jacobian(z, x))
+
   @test_util.assert_no_new_tensors
   def testGradientTapeBatchJacobianCalledMultipleTimes(self):
     with backprop.GradientTape() as g: