Fix fft_ops_test.py for CPU
This commit is contained in:
parent
88d648f3be
commit
50d80ddf92
@ -298,36 +298,37 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
|||||||
use_placeholder=True)
|
use_placeholder=True)
|
||||||
|
|
||||||
def testFftLength(self):
|
def testFftLength(self):
|
||||||
for rank in VALID_FFT_RANKS:
|
if test.is_gpu_available(cuda_only=True):
|
||||||
for dims in xrange(rank, rank + 3):
|
for rank in VALID_FFT_RANKS:
|
||||||
for size in (5, 6):
|
for dims in xrange(rank, rank + 3):
|
||||||
inner_dim = size // 2 + 1
|
for size in (5, 6):
|
||||||
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
inner_dim = size // 2 + 1
|
||||||
(size,) * dims)
|
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
||||||
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
(size,) * dims)
|
||||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
||||||
|
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||||
|
|
||||||
# Test truncation (FFT size < dimensions).
|
# Test truncation (FFT size < dimensions).
|
||||||
fft_length = (size - 2,) * rank
|
fft_length = (size - 2,) * rank
|
||||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
|
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
|
||||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
|
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
|
||||||
|
|
||||||
# Confirm it works with unknown shapes as well.
|
# Confirm it works with unknown shapes as well.
|
||||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
|
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
|
||||||
use_placeholder=True)
|
use_placeholder=True)
|
||||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
|
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
|
||||||
use_placeholder=True)
|
use_placeholder=True)
|
||||||
|
|
||||||
# Test padding (FFT size > dimensions).
|
# Test padding (FFT size > dimensions).
|
||||||
fft_length = (size + 2,) * rank
|
fft_length = (size + 2,) * rank
|
||||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
|
self._CompareForward(r2c.astype(np.float32), rank, fft_length)
|
||||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
|
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
|
||||||
|
|
||||||
# Confirm it works with unknown shapes as well.
|
# Confirm it works with unknown shapes as well.
|
||||||
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
|
self._CompareForward(r2c.astype(np.float32), rank, fft_length,
|
||||||
use_placeholder=True)
|
use_placeholder=True)
|
||||||
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
|
self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
|
||||||
use_placeholder=True)
|
use_placeholder=True)
|
||||||
|
|
||||||
def testRandom(self):
|
def testRandom(self):
|
||||||
np.random.seed(12345)
|
np.random.seed(12345)
|
||||||
|
Loading…
Reference in New Issue
Block a user