Make tests eager compatabile by removing eval()
PiperOrigin-RevId: 236194612
This commit is contained in:
parent
62f1d7fd04
commit
b71a828843
@ -26,9 +26,8 @@ class AddOneTest(tf.test.TestCase):
|
||||
|
||||
def test(self):
|
||||
if tf.test.is_built_with_cuda():
|
||||
with self.cached_session():
|
||||
result = cuda_op.add_one([5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(result.eval(), [6, 5, 4, 3, 2])
|
||||
result = cuda_op.add_one([5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(result, [6, 5, 4, 3, 2])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -110,13 +110,13 @@ class FFTOpsTest(BaseFFTOpsTest):
|
||||
|
||||
def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
|
||||
# fft_length unused for complex FFTs.
|
||||
with self.cached_session(use_gpu=True):
|
||||
return self._tfFFTForRank(rank)(x).eval(feed_dict=feed_dict)
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
return sess.run(self._tfFFTForRank(rank)(x), feed_dict=feed_dict)
|
||||
|
||||
def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
|
||||
# fft_length unused for complex FFTs.
|
||||
with self.cached_session(use_gpu=True):
|
||||
return self._tfIFFTForRank(rank)(x).eval(feed_dict=feed_dict)
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
return sess.run(self._tfIFFTForRank(rank)(x), feed_dict=feed_dict)
|
||||
|
||||
def _npFFT(self, x, rank, fft_length=None):
|
||||
if rank == 1:
|
||||
@ -292,12 +292,14 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
||||
use_placeholder)
|
||||
|
||||
def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
|
||||
with self.cached_session(use_gpu=True):
|
||||
return self._tfFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict)
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
return sess.run(
|
||||
self._tfFFTForRank(rank)(x, fft_length), feed_dict=feed_dict)
|
||||
|
||||
def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
|
||||
with self.cached_session(use_gpu=True):
|
||||
return self._tfIFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict)
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
return sess.run(
|
||||
self._tfIFFTForRank(rank)(x, fft_length), feed_dict=feed_dict)
|
||||
|
||||
def _npFFT(self, x, rank, fft_length=None):
|
||||
if rank == 1:
|
||||
@ -512,7 +514,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
||||
x = np.zeros((5,) * rank).astype(np.float32)
|
||||
fft_length = [6] * rank
|
||||
with self.cached_session():
|
||||
rfft_fn(x, fft_length).eval()
|
||||
self.evaluate(rfft_fn(x, fft_length))
|
||||
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors.InvalidArgumentError,
|
||||
@ -520,7 +522,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
||||
x = np.zeros((3,) * rank).astype(np.complex64)
|
||||
fft_length = [6] * rank
|
||||
with self.cached_session():
|
||||
irfft_fn(x, fft_length).eval()
|
||||
self.evaluate(irfft_fn(x, fft_length))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGrad_Simple(self):
|
||||
|
Loading…
Reference in New Issue
Block a user