Merge pull request #42048 from abhichou4:tests/accumulator
PiperOrigin-RevId: 326740421 Change-Id: I7e3ae1089c036e5b91cccca2e92b25ce620f9779
This commit is contained in:
commit
7b56de0366
@ -1041,6 +1041,27 @@ class BatchTests(test.TestCase, parameterized.TestCase):
|
||||
z = x * y
|
||||
self.assertAllClose(acc.jvp(z), constant_op.constant([5.0, 2.0, 7.0]))
|
||||
|
||||
@parameterized.named_parameters([("ForwardPropFirst", True),
|
||||
("TapeFirst", False)])
|
||||
def testBatchBackwardOverForward(self, forward_prop_first):
|
||||
x = constant_op.constant(1.)
|
||||
tangents = random_ops.random_normal(shape=[10], seed=1)
|
||||
expected = [-t * math_ops.cos(1.) for t in tangents]
|
||||
if forward_prop_first:
|
||||
batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
|
||||
gradient_tape = backprop.GradientTape(persistent=True)
|
||||
else:
|
||||
gradient_tape = backprop.GradientTape(persistent=True)
|
||||
batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
|
||||
with gradient_tape as tape:
|
||||
with batch_acc as acc:
|
||||
tape.watch(x)
|
||||
y = math_ops.cos(x)
|
||||
self.assertTrue(tape_lib.should_record_backprop((acc.jvp(y),)))
|
||||
jvps = acc.jvp(y)
|
||||
d2y_dx2 = [tape.gradient(dy_dx, x) for dy_dx in jvps]
|
||||
self.assertAllClose(expected, d2y_dx2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO(allenl): Also test with 1.x-style graph mode.
|
||||
|
Loading…
x
Reference in New Issue
Block a user