Fix fft_ops_test.py for CPU
This commit is contained in:
parent
88d648f3be
commit
50d80ddf92
@ -298,36 +298,37 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
||||
use_placeholder=True)
|
||||
|
||||
def testFftLength(self):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
inner_dim = size // 2 + 1
|
||||
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
||||
(size,) * dims)
|
||||
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
inner_dim = size // 2 + 1
|
||||
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
||||
(size,) * dims)
|
||||
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||
|
||||
# Test truncation (FFT size < dimensions).
|
||||
fft_length = (size - 2,) * rank
|
||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
|
||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
|
||||
# Test truncation (FFT size < dimensions).
|
||||
fft_length = (size - 2,) * rank
|
||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
|
||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
|
||||
|
||||
# Confirm it works with unknown shapes as well.
|
||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
|
||||
use_placeholder=True)
|
||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
|
||||
use_placeholder=True)
|
||||
# Confirm it works with unknown shapes as well.
|
||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
|
||||
use_placeholder=True)
|
||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
|
||||
use_placeholder=True)
|
||||
|
||||
# Test padding (FFT size > dimensions).
|
||||
fft_length = (size + 2,) * rank
|
||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
|
||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
|
||||
# Test padding (FFT size > dimensions).
|
||||
fft_length = (size + 2,) * rank
|
||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
|
||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
|
||||
|
||||
# Confirm it works with unknown shapes as well.
|
||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
|
||||
use_placeholder=True)
|
||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
|
||||
use_placeholder=True)
|
||||
# Confirm it works with unknown shapes as well.
|
||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
|
||||
use_placeholder=True)
|
||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
|
||||
use_placeholder=True)
|
||||
|
||||
def testRandom(self):
|
||||
np.random.seed(12345)
|
||||
|
Loading…
Reference in New Issue
Block a user