GradientTape.jacobian/batch_jacobian: allow calls inside the GradientTape's scope
Since they call GradientTape.gradient eventually anyway, they'll inherit the warning about persistent tapes recording themselves taking extra memory.

Fixes #34260. Fixes #41510.

PiperOrigin-RevId: 357251843
Change-Id: I913d5dee058b01a1e3fbb830c246d4e7a2133cd5
commit 8e4ba814c2 (parent 2300294ad4)
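For context, a minimal sketch of the call pattern this change enables, mirroring the new tests below (variable names and values are illustrative, not part of the commit):

    import tensorflow as tf

    x = tf.constant(3.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = x * x
      z = y * y
      # Before this change, calling jacobian() while still inside the tape's
      # `with` block failed because the tape tried to push itself again while
      # it was already recording; _ensure_recording now reuses the active
      # recording instead of pushing a second time.
      jac = tape.jacobian(z, x)

    print(jac.numpy())  # d(x**4)/dx at x = 3 -> 4 * 3**3 = 108.0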
@@ -873,6 +873,18 @@ class GradientTape(object):
     tape.pop_tape(self._tape)
     self._recording = False
 
+  @tf_contextlib.contextmanager
+  def _ensure_recording(self):
+    """Ensures that this tape is recording."""
+    if not self._recording:
+      try:
+        self._push_tape()
+        yield
+      finally:
+        self._pop_tape()
+    else:
+      yield
+
   def watch(self, tensor):
     """Ensures that `tensor` is being traced by this tape.
 
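The new helper pushes the tape only when it is not already recording, so the push/pop pairs inside jacobian/batch_jacobian become no-ops when the call happens inside the tape's own `with` block. A small sketch of that behavior (it pokes at private members purely for illustration and is not a supported usage):

    from tensorflow.python.eager import backprop

    g = backprop.GradientTape()

    with g:                           # outer scope: the tape is recording
      assert g._recording             # private attribute, as in the hunk above
      with g._ensure_recording():     # already recording, so this is a no-op
        assert g._recording
      assert g._recording             # still recording; nothing was popped

    assert not g._recording           # outer scope ended, tape popped
    with g._ensure_recording():       # not recording: the tape is pushed...
      assert g._recording
    assert not g._recording           # ...and popped again on exit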
@@ -1144,14 +1156,12 @@ class GradientTape(object):
     target_shape = array_ops.shape(target)
     # Note that we push and pop the tape here and below. This is needed since we
     # need gradients through the enclosed operations.
-    self._push_tape()
-    target = array_ops.reshape(target, [-1])
-    self._pop_tape()
+    with self._ensure_recording():
+      target = array_ops.reshape(target, [-1])
 
     def loop_fn(i):
-      self._push_tape()
-      y = array_ops.gather(target, i)
-      self._pop_tape()
+      with self._ensure_recording():
+        y = array_ops.gather(target, i)
       return self.gradient(y, flat_sources,
                            unconnected_gradients=unconnected_gradients)
 
@@ -1285,16 +1295,14 @@ class GradientTape(object):
     # Flatten target to 2-D.
     # Note that we push and pop the tape here and below. This is needed since we
     # need gradients through the enclosed operations.
-    self._push_tape()
-    with ops.control_dependencies(
-        [check_ops.assert_equal(batch_size, source_shape[0])]):
-      target = array_ops.reshape(target, [batch_size, target_row_size])
-    self._pop_tape()
+    with self._ensure_recording():
+      with ops.control_dependencies(
+          [check_ops.assert_equal(batch_size, source_shape[0])]):
+        target = array_ops.reshape(target, [batch_size, target_row_size])
 
     def loop_fn(i):
-      self._push_tape()
-      y = array_ops.gather(target, i, axis=1)
-      self._pop_tape()
+      with self._ensure_recording():
+        y = array_ops.gather(target, i, axis=1)
       return self.gradient(y, source,
                            unconnected_gradients=unconnected_gradients)
 
@@ -854,6 +854,24 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
         RuntimeError, 'A non-persistent GradientTape can only'):
       g.jacobian(y, [x])
 
+  @test_util.assert_no_new_tensors
+  def testJacobianInsideGradientTapeScope(self):
+    with backprop.GradientTape() as g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = x * x
+      z = y * y
+      self.assertAllClose(4. * 3. ** 3., g.jacobian(z, x))
+
+  @test_util.assert_no_new_tensors
+  def testBatchJacobianInsideGradientTapeScope(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant([[3.0]])
+      g.watch(x)
+      y = x * x
+      z = y * y
+      self.assertAllClose([[[4. * 3. ** 3.]]], g.batch_jacobian(z, x))
+
   @test_util.assert_no_new_tensors
   def testGradientTapeBatchJacobianCalledMultipleTimes(self):
     with backprop.GradientTape() as g: