Merge pull request #42048 from abhichou4:tests/accumulator

PiperOrigin-RevId: 326740421
Change-Id: I7e3ae1089c036e5b91cccca2e92b25ce620f9779
This commit is contained in:
TensorFlower Gardener 2020-08-14 15:27:11 -07:00
commit 7b56de0366

View File

@ -1041,6 +1041,27 @@ class BatchTests(test.TestCase, parameterized.TestCase):
z = x * y
self.assertAllClose(acc.jvp(z), constant_op.constant([5.0, 2.0, 7.0]))
@parameterized.named_parameters([("ForwardPropFirst", True),
("TapeFirst", False)])
def testBatchBackwardOverForward(self, forward_prop_first):
x = constant_op.constant(1.)
tangents = random_ops.random_normal(shape=[10], seed=1)
expected = [-t * math_ops.cos(1.) for t in tangents]
if forward_prop_first:
batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
gradient_tape = backprop.GradientTape(persistent=True)
else:
gradient_tape = backprop.GradientTape(persistent=True)
batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
with gradient_tape as tape:
with batch_acc as acc:
tape.watch(x)
y = math_ops.cos(x)
self.assertTrue(tape_lib.should_record_backprop((acc.jvp(y),)))
jvps = acc.jvp(y)
d2y_dx2 = [tape.gradient(dy_dx, x) for dy_dx in jvps]
self.assertAllClose(expected, d2y_dx2)
if __name__ == "__main__":
# TODO(allenl): Also test with 1.x-style graph mode.