Remove @test_util.run_deprecated_v1 in gradient_correctness_test.py
Also add `persistent` arg to test_util.AbstractGradientTape PiperOrigin-RevId: 324306126 Change-Id: I4f30cff141d725b9c0dd9d2c71e353df4923ca2e
This commit is contained in:
parent
5594fde93b
commit
29b576a718
@ -3318,16 +3318,16 @@ class AbstractGradientTape:
|
|||||||
duplicating tests.
|
duplicating tests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, use_tape):
|
def __init__(self, use_tape, persistent=False):
|
||||||
self._use_tape = use_tape
|
self._use_tape = use_tape
|
||||||
|
self._persistent = persistent
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
if self._use_tape:
|
if self._use_tape:
|
||||||
self._tape_impl = backprop.GradientTape()
|
self._tape_impl = backprop.GradientTape(persistent=self._persistent)
|
||||||
else:
|
else:
|
||||||
self._tape_impl = _fake_gradient_tape_context_manager()
|
self._tape_impl = _fake_gradient_tape_context_manager()
|
||||||
return self._tape_impl.__enter__()
|
return self._tape_impl.__enter__()
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self._tape_impl.__exit__(exc_type, exc_val, exc_tb)
|
self._tape_impl.__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
@ -25,20 +25,21 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gradients_impl
|
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
|
class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
def testMultipleOutputChainedGradients(self):
|
def testMultipleOutputChainedGradients(self, use_tape):
|
||||||
with self.cached_session() as sess:
|
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||||
x = constant_op.constant(1.0, dtype=dtypes.float32)
|
x = constant_op.constant(1.0, dtype=dtypes.float32)
|
||||||
|
tape.watch(x)
|
||||||
|
|
||||||
yexp = math_ops.exp(x)
|
yexp = math_ops.exp(x)
|
||||||
yexplog = math_ops.log(yexp)
|
yexplog = math_ops.log(yexp)
|
||||||
grads = gradients_impl.gradients([yexp, yexplog], [x])
|
grads = tape.gradient([yexp, yexplog], [x])
|
||||||
grad_vals = self.evaluate(grads)
|
grad_vals = self.evaluate(grads)
|
||||||
exp1_plus_one = (1.0 + np.exp(1.0)).astype(np.float32)
|
exp1_plus_one = (1.0 + np.exp(1.0)).astype(np.float32)
|
||||||
# [dexp(x)/dx + d(log(exp(x)))/dx] @ x=1 == exp(1) + 1
|
# [dexp(x)/dx + d(log(exp(x)))/dx] @ x=1 == exp(1) + 1
|
||||||
@ -52,72 +53,94 @@ class GradientCorrectnessTest(test.TestCase, parameterized.TestCase):
|
|||||||
dx_dx = tape.gradient(x, x)
|
dx_dx = tape.gradient(x, x)
|
||||||
self.assertAllClose(1., self.evaluate(dx_dx))
|
self.assertAllClose(1., self.evaluate(dx_dx))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
def testIntegerIdentityGradient(self):
|
def testIntegerIdentityGradient(self, use_tape):
|
||||||
x = constant_op.constant(3)
|
x = constant_op.constant(3)
|
||||||
dx_dx, = gradients_impl.gradients(x, x)
|
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||||
with self.cached_session() as sess:
|
tape.watch(x)
|
||||||
self.assertAllClose(1, self.evaluate(dx_dx))
|
dx_dx = tape.gradient(x, x)
|
||||||
|
self.assertAllClose(1, self.evaluate(dx_dx))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
def testGradientWithIntegerPath(self):
|
def testGradientWithIntegerPath(self, use_tape):
|
||||||
x = constant_op.constant([3.9, 4.1])
|
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||||
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
|
x = constant_op.constant([3.9, 4.1])
|
||||||
y = x * k
|
tape.watch(x)
|
||||||
dy_dx, = gradients_impl.gradients(y, x)
|
|
||||||
with self.cached_session() as sess:
|
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
|
||||||
|
y = x * k
|
||||||
|
dy_dx = tape.gradient(y, x)
|
||||||
self.assertAllClose([3., 4.], self.evaluate(dy_dx))
|
self.assertAllClose([3., 4.], self.evaluate(dy_dx))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
def testNoIntegerGradient1(self):
|
def testNoIntegerGradient1(self, use_tape):
|
||||||
x = constant_op.constant([3.9, 4.1])
|
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||||
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
|
x = constant_op.constant([3.9, 4.1])
|
||||||
y = k * k
|
tape.watch(x)
|
||||||
dy_dx, = gradients_impl.gradients(y, x)
|
|
||||||
self.assertIsNone(dy_dx)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
k = math_ops.cast(math_ops.cast(x, dtypes.int32), dtypes.float32)
|
||||||
def testNoIntegerGradient2(self):
|
y = k * k
|
||||||
k = constant_op.constant([3, 4])
|
dy_dx = tape.gradient(y, x)
|
||||||
x = math_ops.cast(k, dtypes.float32)
|
self.assertIsNone(dy_dx)
|
||||||
y = x * x
|
|
||||||
dy_dk, = gradients_impl.gradients(y, k)
|
|
||||||
self.assertIsNone(dy_dk)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
def testNoIntegerGradient3(self):
|
def testNoIntegerGradient2(self, use_tape):
|
||||||
k = constant_op.constant([3, 4])
|
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||||
m = k * k
|
k = constant_op.constant([3, 4])
|
||||||
dm_dk, = gradients_impl.gradients(m, k)
|
x = math_ops.cast(k, dtypes.float32)
|
||||||
self.assertIsNone(dm_dk)
|
tape.watch([k, x])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
y = x * x
|
||||||
def testNoIntegerGradient4(self):
|
dy_dk = tape.gradient(y, k)
|
||||||
k = constant_op.constant([3, 4])
|
self.assertIsNone(dy_dk)
|
||||||
m = k * k * k
|
|
||||||
dm_dk, = gradients_impl.gradients(m, k)
|
|
||||||
self.assertIsNone(dm_dk)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
def testNoIntegerGradient5(self):
|
def testNoIntegerGradient3(self, use_tape):
|
||||||
k = constant_op.constant([3, 4])
|
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||||
m = k * k
|
k = constant_op.constant([3, 4])
|
||||||
n = m * m
|
tape.watch(k)
|
||||||
dn_dk, = gradients_impl.gradients(n, k)
|
|
||||||
self.assertIsNone(dn_dk)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
m = k * k
|
||||||
def testNoIntegerGradient6(self):
|
dm_dk = tape.gradient(m, k)
|
||||||
k = constant_op.constant(3)
|
self.assertIsNone(dm_dk)
|
||||||
x = math_ops.cast(k, dtypes.float32)
|
|
||||||
grad_1, = gradients_impl.gradients(k * k, k)
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
grad_2, = gradients_impl.gradients(x * x, k)
|
def testNoIntegerGradient4(self, use_tape):
|
||||||
grad_3, = gradients_impl.gradients(math_ops.square(k), k)
|
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||||
grad_4, = gradients_impl.gradients(math_ops.square(x), k)
|
k = constant_op.constant([3, 4])
|
||||||
self.assertIsNone(grad_1)
|
tape.watch(k)
|
||||||
self.assertIsNone(grad_2)
|
|
||||||
self.assertIsNone(grad_3)
|
m = k * k * k
|
||||||
self.assertIsNone(grad_4)
|
dm_dk = tape.gradient(m, k)
|
||||||
|
self.assertIsNone(dm_dk)
|
||||||
|
|
||||||
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
|
def testNoIntegerGradient5(self, use_tape):
|
||||||
|
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||||
|
k = constant_op.constant([3, 4])
|
||||||
|
tape.watch(k)
|
||||||
|
|
||||||
|
m = k * k
|
||||||
|
n = m * m
|
||||||
|
dn_dk = tape.gradient(n, k)
|
||||||
|
self.assertIsNone(dn_dk)
|
||||||
|
|
||||||
|
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||||
|
def testNoIntegerGradient6(self, use_tape):
|
||||||
|
with test_util.AbstractGradientTape(
|
||||||
|
use_tape=use_tape, persistent=True) as tape:
|
||||||
|
k = constant_op.constant(3)
|
||||||
|
tape.watch(k)
|
||||||
|
|
||||||
|
x = math_ops.cast(k, dtypes.float32)
|
||||||
|
grad_1 = tape.gradient(k * k, k)
|
||||||
|
grad_2 = tape.gradient(x * x, k)
|
||||||
|
grad_3 = tape.gradient(math_ops.square(k), k)
|
||||||
|
grad_4 = tape.gradient(math_ops.square(x), k)
|
||||||
|
self.assertIsNone(grad_1)
|
||||||
|
self.assertIsNone(grad_2)
|
||||||
|
self.assertIsNone(grad_3)
|
||||||
|
self.assertIsNone(grad_4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user