Make tests eager compatabile by removing eval()

PiperOrigin-RevId: 236194612
This commit is contained in:
Gaurav Jain 2019-02-28 14:09:49 -08:00 committed by TensorFlower Gardener
parent 62f1d7fd04
commit b71a828843
2 changed files with 14 additions and 13 deletions

View File

@ -26,9 +26,8 @@ class AddOneTest(tf.test.TestCase):
def test(self):
if tf.test.is_built_with_cuda():
with self.cached_session():
result = cuda_op.add_one([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [6, 5, 4, 3, 2])
result = cuda_op.add_one([5, 4, 3, 2, 1])
self.assertAllEqual(result, [6, 5, 4, 3, 2])
if __name__ == '__main__':

View File

@ -110,13 +110,13 @@ class FFTOpsTest(BaseFFTOpsTest):
def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
# fft_length unused for complex FFTs.
with self.cached_session(use_gpu=True):
return self._tfFFTForRank(rank)(x).eval(feed_dict=feed_dict)
with self.cached_session(use_gpu=True) as sess:
return sess.run(self._tfFFTForRank(rank)(x), feed_dict=feed_dict)
def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
# fft_length unused for complex FFTs.
with self.cached_session(use_gpu=True):
return self._tfIFFTForRank(rank)(x).eval(feed_dict=feed_dict)
with self.cached_session(use_gpu=True) as sess:
return sess.run(self._tfIFFTForRank(rank)(x), feed_dict=feed_dict)
def _npFFT(self, x, rank, fft_length=None):
if rank == 1:
@ -292,12 +292,14 @@ class RFFTOpsTest(BaseFFTOpsTest):
use_placeholder)
def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
with self.cached_session(use_gpu=True):
return self._tfFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict)
with self.cached_session(use_gpu=True) as sess:
return sess.run(
self._tfFFTForRank(rank)(x, fft_length), feed_dict=feed_dict)
def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
with self.cached_session(use_gpu=True):
return self._tfIFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict)
with self.cached_session(use_gpu=True) as sess:
return sess.run(
self._tfIFFTForRank(rank)(x, fft_length), feed_dict=feed_dict)
def _npFFT(self, x, rank, fft_length=None):
if rank == 1:
@ -512,7 +514,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
x = np.zeros((5,) * rank).astype(np.float32)
fft_length = [6] * rank
with self.cached_session():
rfft_fn(x, fft_length).eval()
self.evaluate(rfft_fn(x, fft_length))
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
@ -520,7 +522,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
x = np.zeros((3,) * rank).astype(np.complex64)
fft_length = [6] * rank
with self.cached_session():
irfft_fn(x, fft_length).eval()
self.evaluate(irfft_fn(x, fft_length))
@test_util.run_deprecated_v1
def testGrad_Simple(self):