Remove @test_util.run_deprecated_v1 in gradient_correctness_test.py
Also add `persistent` arg to test_util.AbstractGradientTape PiperOrigin-RevId: 324306126 Change-Id: I4f30cff141d725b9c0dd9d2c71e353df4923ca2e
This commit is contained in:
parent
5594fde93b
commit
29b576a718
@ -3318,16 +3318,16 @@ class AbstractGradientTape:
|
||||
duplicating tests.
|
||||
"""
|
||||
|
||||
def __init__(self, use_tape):
|
||||
def __init__(self, use_tape, persistent=False):
|
||||
self._use_tape = use_tape
|
||||
self._persistent = persistent
|
||||
|
||||
def __enter__(self):
|
||||
if self._use_tape:
|
||||
self._tape_impl = backprop.GradientTape()
|
||||
self._tape_impl = backprop.GradientTape(persistent=self._persistent)
|
||||
else:
|
||||
self._tape_impl = _fake_gradient_tape_context_manager()
|
||||
return self._tape_impl.__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._tape_impl.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
@ -25,20 +25,21 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleOutputChainedGradients(self):
|
||||
with self.cached_session() as sess:
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testMultipleOutputChainedGradients(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
x = constant_op.constant(1.0, dtype=dtypes.float32)
|
||||
tape.watch(x)
|
||||
|
||||
yexp = math_ops.exp(x)
|
||||
yexplog = math_ops.log(yexp)
|
||||
grads = gradients_impl.gradients([yexp, yexplog], [x])
|
||||
grads = tape.gradient([yexp, yexplog], [x])
|
||||
grad_vals = self.evaluate(grads)
|
||||
exp1_plus_one = (1.0 + np.exp(1.0)).astype(np.float32)
|
||||
# [dexp(x)/dx + d(log(exp(x)))/dx] @ x=1 == exp(1) + 1
|
||||
@ -52,72 +53,94 @@ class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
|
||||
dx_dx = tape.gradient(x, x)
|
||||
self.assertAllClose(1., self.evaluate(dx_dx))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testIntegerIdentityGradient(self):
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testIntegerIdentityGradient(self, use_tape):
|
||||
x = constant_op.constant(3)
|
||||
dx_dx, = gradients_impl.gradients(x, x)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllClose(1, self.evaluate(dx_dx))
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
tape.watch(x)
|
||||
dx_dx = tape.gradient(x, x)
|
||||
self.assertAllClose(1, self.evaluate(dx_dx))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradientWithIntegerPath(self):
|
||||
x = constant_op.constant([3.9, 4.1])
|
||||
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
|
||||
y = x * k
|
||||
dy_dx, = gradients_impl.gradients(y, x)
|
||||
with self.cached_session() as sess:
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testGradientWithIntegerPath(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
x = constant_op.constant([3.9, 4.1])
|
||||
tape.watch(x)
|
||||
|
||||
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
|
||||
y = x * k
|
||||
dy_dx = tape.gradient(y, x)
|
||||
self.assertAllClose([3., 4.], self.evaluate(dy_dx))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNoIntegerGradient1(self):
|
||||
x = constant_op.constant([3.9, 4.1])
|
||||
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
|
||||
y = k * k
|
||||
dy_dx, = gradients_impl.gradients(y, x)
|
||||
self.assertIsNone(dy_dx)
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testNoIntegerGradient1(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
x = constant_op.constant([3.9, 4.1])
|
||||
tape.watch(x)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNoIntegerGradient2(self):
|
||||
k = constant_op.constant([3, 4])
|
||||
x = math_ops.cast(k, dtypes.float32)
|
||||
y = x * x
|
||||
dy_dk, = gradients_impl.gradients(y, k)
|
||||
self.assertIsNone(dy_dk)
|
||||
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
|
||||
y = k * k
|
||||
dy_dx = tape.gradient(y, x)
|
||||
self.assertIsNone(dy_dx)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNoIntegerGradient3(self):
|
||||
k = constant_op.constant([3, 4])
|
||||
m = k * k
|
||||
dm_dk, = gradients_impl.gradients(m, k)
|
||||
self.assertIsNone(dm_dk)
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testNoIntegerGradient2(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
k = constant_op.constant([3, 4])
|
||||
x = math_ops.cast(k, dtypes.float32)
|
||||
tape.watch([k, x])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNoIntegerGradient4(self):
|
||||
k = constant_op.constant([3, 4])
|
||||
m = k * k * k
|
||||
dm_dk, = gradients_impl.gradients(m, k)
|
||||
self.assertIsNone(dm_dk)
|
||||
y = x * x
|
||||
dy_dk = tape.gradient(y, k)
|
||||
self.assertIsNone(dy_dk)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNoIntegerGradient5(self):
|
||||
k = constant_op.constant([3, 4])
|
||||
m = k * k
|
||||
n = m * m
|
||||
dn_dk, = gradients_impl.gradients(n, k)
|
||||
self.assertIsNone(dn_dk)
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testNoIntegerGradient3(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
k = constant_op.constant([3, 4])
|
||||
tape.watch(k)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNoIntegerGradient6(self):
|
||||
k = constant_op.constant(3)
|
||||
x = math_ops.cast(k, dtypes.float32)
|
||||
grad_1, = gradients_impl.gradients(k * k, k)
|
||||
grad_2, = gradients_impl.gradients(x * x, k)
|
||||
grad_3, = gradients_impl.gradients(math_ops.square(k), k)
|
||||
grad_4, = gradients_impl.gradients(math_ops.square(x), k)
|
||||
self.assertIsNone(grad_1)
|
||||
self.assertIsNone(grad_2)
|
||||
self.assertIsNone(grad_3)
|
||||
self.assertIsNone(grad_4)
|
||||
m = k * k
|
||||
dm_dk = tape.gradient(m, k)
|
||||
self.assertIsNone(dm_dk)
|
||||
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testNoIntegerGradient4(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
k = constant_op.constant([3, 4])
|
||||
tape.watch(k)
|
||||
|
||||
m = k * k * k
|
||||
dm_dk = tape.gradient(m, k)
|
||||
self.assertIsNone(dm_dk)
|
||||
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testNoIntegerGradient5(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
k = constant_op.constant([3, 4])
|
||||
tape.watch(k)
|
||||
|
||||
m = k * k
|
||||
n = m * m
|
||||
dn_dk = tape.gradient(n, k)
|
||||
self.assertIsNone(dn_dk)
|
||||
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testNoIntegerGradient6(self, use_tape):
|
||||
with test_util.AbstractGradientTape(
|
||||
use_tape=use_tape, persistent=True) as tape:
|
||||
k = constant_op.constant(3)
|
||||
tape.watch(k)
|
||||
|
||||
x = math_ops.cast(k, dtypes.float32)
|
||||
grad_1 = tape.gradient(k * k, k)
|
||||
grad_2 = tape.gradient(x * x, k)
|
||||
grad_3 = tape.gradient(math_ops.square(k), k)
|
||||
grad_4 = tape.gradient(math_ops.square(x), k)
|
||||
self.assertIsNone(grad_1)
|
||||
self.assertIsNone(grad_2)
|
||||
self.assertIsNone(grad_3)
|
||||
self.assertIsNone(grad_4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user