Remove @test_util.run_deprecated_v1 in gradient_correctness_test.py

Also add `persistent` arg to test_util.AbstractGradientTape

PiperOrigin-RevId: 324306126
Change-Id: I4f30cff141d725b9c0dd9d2c71e353df4923ca2e
This commit is contained in:
Kibeom Kim 2020-07-31 16:24:03 -07:00 committed by TensorFlower Gardener
parent 5594fde93b
commit 29b576a718
2 changed files with 88 additions and 65 deletions

View File

@ -3318,16 +3318,16 @@ class AbstractGradientTape:
duplicating tests. duplicating tests.
""" """
def __init__(self, use_tape): def __init__(self, use_tape, persistent=False):
self._use_tape = use_tape self._use_tape = use_tape
self._persistent = persistent
def __enter__(self): def __enter__(self):
if self._use_tape: if self._use_tape:
self._tape_impl = backprop.GradientTape() self._tape_impl = backprop.GradientTape(persistent=self._persistent)
else: else:
self._tape_impl = _fake_gradient_tape_context_manager() self._tape_impl = _fake_gradient_tape_context_manager()
return self._tape_impl.__enter__() return self._tape_impl.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self._tape_impl.__exit__(exc_type, exc_val, exc_tb) self._tape_impl.__exit__(exc_type, exc_val, exc_tb)

View File

@ -25,20 +25,21 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
class GradientCorrectnessTest(test.TestCase, parameterized.TestCase): class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
@test_util.run_deprecated_v1 @parameterized.parameters(set((True, context.executing_eagerly())))
def testMultipleOutputChainedGradients(self): def testMultipleOutputChainedGradients(self, use_tape):
with self.cached_session() as sess: with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
x = constant_op.constant(1.0, dtype=dtypes.float32) x = constant_op.constant(1.0, dtype=dtypes.float32)
tape.watch(x)
yexp = math_ops.exp(x) yexp = math_ops.exp(x)
yexplog = math_ops.log(yexp) yexplog = math_ops.log(yexp)
grads = gradients_impl.gradients([yexp, yexplog], [x]) grads = tape.gradient([yexp, yexplog], [x])
grad_vals = self.evaluate(grads) grad_vals = self.evaluate(grads)
exp1_plus_one = (1.0 + np.exp(1.0)).astype(np.float32) exp1_plus_one = (1.0 + np.exp(1.0)).astype(np.float32)
# [dexp(x)/dx + d(log(exp(x)))/dx] @ x=1 == exp(1) + 1 # [dexp(x)/dx + d(log(exp(x)))/dx] @ x=1 == exp(1) + 1
@ -52,72 +53,94 @@ class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
dx_dx = tape.gradient(x, x) dx_dx = tape.gradient(x, x)
self.assertAllClose(1., self.evaluate(dx_dx)) self.assertAllClose(1., self.evaluate(dx_dx))
@test_util.run_deprecated_v1 @parameterized.parameters(set((True, context.executing_eagerly())))
def testIntegerIdentityGradient(self): def testIntegerIdentityGradient(self, use_tape):
x = constant_op.constant(3) x = constant_op.constant(3)
dx_dx, = gradients_impl.gradients(x, x) with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
with self.cached_session() as sess: tape.watch(x)
self.assertAllClose(1, self.evaluate(dx_dx)) dx_dx = tape.gradient(x, x)
self.assertAllClose(1, self.evaluate(dx_dx))
@test_util.run_deprecated_v1 @parameterized.parameters(set((True, context.executing_eagerly())))
def testGradientWithIntegerPath(self): def testGradientWithIntegerPath(self, use_tape):
x = constant_op.constant([3.9, 4.1]) with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32) x = constant_op.constant([3.9, 4.1])
y = x * k tape.watch(x)
dy_dx, = gradients_impl.gradients(y, x)
with self.cached_session() as sess: k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
y = x * k
dy_dx = tape.gradient(y, x)
self.assertAllClose([3., 4.], self.evaluate(dy_dx)) self.assertAllClose([3., 4.], self.evaluate(dy_dx))
@test_util.run_deprecated_v1 @parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient1(self): def testNoIntegerGradient1(self, use_tape):
x = constant_op.constant([3.9, 4.1]) with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32) x = constant_op.constant([3.9, 4.1])
y = k * k tape.watch(x)
dy_dx, = gradients_impl.gradients(y, x)
self.assertIsNone(dy_dx)
@test_util.run_deprecated_v1 k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
def testNoIntegerGradient2(self): y = k * k
k = constant_op.constant([3, 4]) dy_dx = tape.gradient(y, x)
x = math_ops.cast(k, dtypes.float32) self.assertIsNone(dy_dx)
y = x * x
dy_dk, = gradients_impl.gradients(y, k)
self.assertIsNone(dy_dk)
@test_util.run_deprecated_v1 @parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient3(self): def testNoIntegerGradient2(self, use_tape):
k = constant_op.constant([3, 4]) with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
m = k * k k = constant_op.constant([3, 4])
dm_dk, = gradients_impl.gradients(m, k) x = math_ops.cast(k, dtypes.float32)
self.assertIsNone(dm_dk) tape.watch([k, x])
@test_util.run_deprecated_v1 y = x * x
def testNoIntegerGradient4(self): dy_dk = tape.gradient(y, k)
k = constant_op.constant([3, 4]) self.assertIsNone(dy_dk)
m = k * k * k
dm_dk, = gradients_impl.gradients(m, k)
self.assertIsNone(dm_dk)
@test_util.run_deprecated_v1 @parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient5(self): def testNoIntegerGradient3(self, use_tape):
k = constant_op.constant([3, 4]) with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
m = k * k k = constant_op.constant([3, 4])
n = m * m tape.watch(k)
dn_dk, = gradients_impl.gradients(n, k)
self.assertIsNone(dn_dk)
@test_util.run_deprecated_v1 m = k * k
def testNoIntegerGradient6(self): dm_dk = tape.gradient(m, k)
k = constant_op.constant(3) self.assertIsNone(dm_dk)
x = math_ops.cast(k, dtypes.float32)
grad_1, = gradients_impl.gradients(k * k, k) @parameterized.parameters(set((True, context.executing_eagerly())))
grad_2, = gradients_impl.gradients(x * x, k) def testNoIntegerGradient4(self, use_tape):
grad_3, = gradients_impl.gradients(math_ops.square(k), k) with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
grad_4, = gradients_impl.gradients(math_ops.square(x), k) k = constant_op.constant([3, 4])
self.assertIsNone(grad_1) tape.watch(k)
self.assertIsNone(grad_2)
self.assertIsNone(grad_3) m = k * k * k
self.assertIsNone(grad_4) dm_dk = tape.gradient(m, k)
self.assertIsNone(dm_dk)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient5(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
k = constant_op.constant([3, 4])
tape.watch(k)
m = k * k
n = m * m
dn_dk = tape.gradient(n, k)
self.assertIsNone(dn_dk)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient6(self, use_tape):
with test_util.AbstractGradientTape(
use_tape=use_tape, persistent=True) as tape:
k = constant_op.constant(3)
tape.watch(k)
x = math_ops.cast(k, dtypes.float32)
grad_1 = tape.gradient(k * k, k)
grad_2 = tape.gradient(x * x, k)
grad_3 = tape.gradient(math_ops.square(k), k)
grad_4 = tape.gradient(math_ops.square(x), k)
self.assertIsNone(grad_1)
self.assertIsNone(grad_2)
self.assertIsNone(grad_3)
self.assertIsNone(grad_4)
if __name__ == '__main__': if __name__ == '__main__':