Remove @test_util.run_deprecated_v1 in reverse_sequence_op_test.py
PiperOrigin-RevId: 324127808 Change-Id: I64b37fca291e9400f9655bf49e5ba2fa62b5a2ad
This commit is contained in:
parent
ee156c5d85
commit
fafe48d816
@ -24,9 +24,8 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -109,7 +108,6 @@ class ReverseSequenceTest(test.TestCase):
|
||||
def testComplex128Basic(self):
|
||||
self._testBasic(np.complex128)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFloatReverseSequenceGrad(self):
|
||||
x = np.asarray(
|
||||
[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]],
|
||||
@ -123,18 +121,18 @@ class ReverseSequenceTest(test.TestCase):
|
||||
batch_axis = 2
|
||||
seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
|
||||
|
||||
with self.cached_session():
|
||||
input_t = constant_op.constant(x, shape=x.shape)
|
||||
def reverse_sequence(x):
|
||||
seq_lengths_t = constant_op.constant(seq_lengths, shape=seq_lengths.shape)
|
||||
reverse_sequence_out = array_ops.reverse_sequence(
|
||||
input_t,
|
||||
return array_ops.reverse_sequence(
|
||||
x,
|
||||
batch_axis=batch_axis,
|
||||
seq_axis=seq_axis,
|
||||
seq_lengths=seq_lengths_t)
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
input_t, x.shape, reverse_sequence_out, x.shape, x_init_value=x)
|
||||
print("ReverseSequence gradient error = %g" % err)
|
||||
self.assertLess(err, 1e-8)
|
||||
|
||||
with self.cached_session():
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(reverse_sequence, [x]))
|
||||
self.assertLess(err, 1e-8)
|
||||
|
||||
def testShapeFunctionEdgeCases(self):
|
||||
# Enter graph mode since we want to test partial shapes
|
||||
|
Loading…
Reference in New Issue
Block a user