Remove @test_util.run_deprecated_v1 in gradient_correctness_test.py

Also add `persistent` arg to test_util.AbstractGradientTape

PiperOrigin-RevId: 324306126
Change-Id: I4f30cff141d725b9c0dd9d2c71e353df4923ca2e
This commit is contained in:
Kibeom Kim 2020-07-31 16:24:03 -07:00 committed by TensorFlower Gardener
parent 5594fde93b
commit 29b576a718
2 changed files with 88 additions and 65 deletions

View File

@ -3318,16 +3318,16 @@ class AbstractGradientTape:
duplicating tests.
"""
def __init__(self, use_tape):
def __init__(self, use_tape, persistent=False):
self._use_tape = use_tape
self._persistent = persistent
def __enter__(self):
if self._use_tape:
self._tape_impl = backprop.GradientTape()
self._tape_impl = backprop.GradientTape(persistent=self._persistent)
else:
self._tape_impl = _fake_gradient_tape_context_manager()
return self._tape_impl.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
self._tape_impl.__exit__(exc_type, exc_val, exc_tb)

View File

@ -25,20 +25,21 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
@test_util.run_deprecated_v1
def testMultipleOutputChainedGradients(self):
with self.cached_session() as sess:
@parameterized.parameters(set((True, context.executing_eagerly())))
def testMultipleOutputChainedGradients(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
x = constant_op.constant(1.0, dtype=dtypes.float32)
tape.watch(x)
yexp = math_ops.exp(x)
yexplog = math_ops.log(yexp)
grads = gradients_impl.gradients([yexp, yexplog], [x])
grads = tape.gradient([yexp, yexplog], [x])
grad_vals = self.evaluate(grads)
exp1_plus_one = (1.0 + np.exp(1.0)).astype(np.float32)
# [dexp(x)/dx + d(log(exp(x)))/dx] @ x=1 == exp(1) + 1
@ -52,72 +53,94 @@ class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
dx_dx = tape.gradient(x, x)
self.assertAllClose(1., self.evaluate(dx_dx))
@test_util.run_deprecated_v1
def testIntegerIdentityGradient(self):
@parameterized.parameters(set((True, context.executing_eagerly())))
def testIntegerIdentityGradient(self, use_tape):
x = constant_op.constant(3)
dx_dx, = gradients_impl.gradients(x, x)
with self.cached_session() as sess:
self.assertAllClose(1, self.evaluate(dx_dx))
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
tape.watch(x)
dx_dx = tape.gradient(x, x)
self.assertAllClose(1, self.evaluate(dx_dx))
@test_util.run_deprecated_v1
def testGradientWithIntegerPath(self):
x = constant_op.constant([3.9, 4.1])
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
y = x * k
dy_dx, = gradients_impl.gradients(y, x)
with self.cached_session() as sess:
@parameterized.parameters(set((True, context.executing_eagerly())))
def testGradientWithIntegerPath(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
x = constant_op.constant([3.9, 4.1])
tape.watch(x)
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
y = x * k
dy_dx = tape.gradient(y, x)
self.assertAllClose([3., 4.], self.evaluate(dy_dx))
@test_util.run_deprecated_v1
def testNoIntegerGradient1(self):
x = constant_op.constant([3.9, 4.1])
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
y = k * k
dy_dx, = gradients_impl.gradients(y, x)
self.assertIsNone(dy_dx)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient1(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
x = constant_op.constant([3.9, 4.1])
tape.watch(x)
@test_util.run_deprecated_v1
def testNoIntegerGradient2(self):
k = constant_op.constant([3, 4])
x = math_ops.cast(k, dtypes.float32)
y = x * x
dy_dk, = gradients_impl.gradients(y, k)
self.assertIsNone(dy_dk)
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
y = k * k
dy_dx = tape.gradient(y, x)
self.assertIsNone(dy_dx)
@test_util.run_deprecated_v1
def testNoIntegerGradient3(self):
k = constant_op.constant([3, 4])
m = k * k
dm_dk, = gradients_impl.gradients(m, k)
self.assertIsNone(dm_dk)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient2(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
k = constant_op.constant([3, 4])
x = math_ops.cast(k, dtypes.float32)
tape.watch([k, x])
@test_util.run_deprecated_v1
def testNoIntegerGradient4(self):
k = constant_op.constant([3, 4])
m = k * k * k
dm_dk, = gradients_impl.gradients(m, k)
self.assertIsNone(dm_dk)
y = x * x
dy_dk = tape.gradient(y, k)
self.assertIsNone(dy_dk)
@test_util.run_deprecated_v1
def testNoIntegerGradient5(self):
k = constant_op.constant([3, 4])
m = k * k
n = m * m
dn_dk, = gradients_impl.gradients(n, k)
self.assertIsNone(dn_dk)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient3(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
k = constant_op.constant([3, 4])
tape.watch(k)
@test_util.run_deprecated_v1
def testNoIntegerGradient6(self):
k = constant_op.constant(3)
x = math_ops.cast(k, dtypes.float32)
grad_1, = gradients_impl.gradients(k * k, k)
grad_2, = gradients_impl.gradients(x * x, k)
grad_3, = gradients_impl.gradients(math_ops.square(k), k)
grad_4, = gradients_impl.gradients(math_ops.square(x), k)
self.assertIsNone(grad_1)
self.assertIsNone(grad_2)
self.assertIsNone(grad_3)
self.assertIsNone(grad_4)
m = k * k
dm_dk = tape.gradient(m, k)
self.assertIsNone(dm_dk)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient4(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
k = constant_op.constant([3, 4])
tape.watch(k)
m = k * k * k
dm_dk = tape.gradient(m, k)
self.assertIsNone(dm_dk)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient5(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
k = constant_op.constant([3, 4])
tape.watch(k)
m = k * k
n = m * m
dn_dk = tape.gradient(n, k)
self.assertIsNone(dn_dk)
@parameterized.parameters(set((True, context.executing_eagerly())))
def testNoIntegerGradient6(self, use_tape):
with test_util.AbstractGradientTape(
use_tape=use_tape, persistent=True) as tape:
k = constant_op.constant(3)
tape.watch(k)
x = math_ops.cast(k, dtypes.float32)
grad_1 = tape.gradient(k * k, k)
grad_2 = tape.gradient(x * x, k)
grad_3 = tape.gradient(math_ops.square(k), k)
grad_4 = tape.gradient(math_ops.square(x), k)
self.assertIsNone(grad_1)
self.assertIsNone(grad_2)
self.assertIsNone(grad_3)
self.assertIsNone(grad_4)
if __name__ == '__main__':