Fix fft_ops_test.py for CPU

This commit is contained in:
Jonathan Hseu 2017-06-08 11:58:21 -07:00
parent 88d648f3be
commit 50d80ddf92

View File

@ -298,36 +298,37 @@ class RFFTOpsTest(BaseFFTOpsTest):
use_placeholder=True)
def testFftLength(self):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
if test.is_gpu_available(cuda_only=True):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
# Test truncation (FFT size < dimensions).
fft_length = (size - 2,) * rank
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
# Test truncation (FFT size < dimensions).
fft_length = (size - 2,) * rank
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
# Confirm it works with unknown shapes as well.
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
use_placeholder=True)
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
use_placeholder=True)
# Confirm it works with unknown shapes as well.
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
use_placeholder=True)
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
use_placeholder=True)
# Test padding (FFT size > dimensions).
fft_length = (size + 2,) * rank
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
# Test padding (FFT size > dimensions).
fft_length = (size + 2,) * rank
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
# Confirm it works with unknown shapes as well.
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
use_placeholder=True)
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
use_placeholder=True)
# Confirm it works with unknown shapes as well.
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
use_placeholder=True)
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
use_placeholder=True)
def testRandom(self):
np.random.seed(12345)