added assertAllEqual
This commit is contained in:
parent
03be2ce9b4
commit
82cda38a73
@ -559,42 +559,40 @@ class FFTShiftTest(test.TestCase):
|
||||
def testDefinition(self):
|
||||
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
|
||||
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
|
||||
self.assertTrue((fft_ops.fftshift(x).numpy() == y).all())
|
||||
self.assertTrue((fft_ops.ifftshift(y).numpy() == x).all())
|
||||
self.assertAllEqual(fft_ops.fftshift(x).numpy(), y)
|
||||
self.assertAllEqual(fft_ops.ifftshift(y).numpy(), x)
|
||||
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
|
||||
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
|
||||
self.assertTrue((fft_ops.fftshift(x).numpy() == y).all())
|
||||
self.assertTrue((fft_ops.ifftshift(y).numpy() == x).all())
|
||||
self.assertAllEqual(fft_ops.fftshift(x).numpy(), y)
|
||||
self.assertAllEqual(fft_ops.ifftshift(y).numpy(), x)
|
||||
|
||||
def testAxesKeyword(self):
|
||||
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
|
||||
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
|
||||
self.assertTrue((fft_ops.fftshift(freqs, axes=(0, 1)).numpy() == \
|
||||
shifted).all())
|
||||
self.assertTrue((fft_ops.fftshift(freqs, axes=0).numpy() == \
|
||||
fft_ops.fftshift(freqs, axes=(0,)).numpy()).all())
|
||||
self.assertTrue((fft_ops.ifftshift(shifted, axes=(0, 1)).numpy() == \
|
||||
freqs).all())
|
||||
self.assertTrue((fft_ops.ifftshift(shifted, axes=0).numpy() == \
|
||||
fft_ops.ifftshift(shifted, axes=(0,)).numpy()).all())
|
||||
self.assertTrue((fft_ops.fftshift(freqs).numpy() == shifted).all())
|
||||
self.assertTrue((fft_ops.ifftshift(shifted).numpy() == freqs).all())
|
||||
self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, 1)).numpy(), shifted)
|
||||
self.assertAllEqual(fft_ops.fftshift(freqs, axes=0).numpy(),
|
||||
fft_ops.fftshift(freqs, axes=(0,)).numpy())
|
||||
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, 1)).numpy(), freqs)
|
||||
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=0).numpy(),
|
||||
fft_ops.ifftshift(shifted, axes=(0,)).numpy())
|
||||
self.assertAllEqual(fft_ops.fftshift(freqs).numpy(), shifted)
|
||||
self.assertAllEqual(fft_ops.ifftshift(shifted).numpy(), freqs)
|
||||
|
||||
def testNumpyCompatibility(self):
|
||||
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
|
||||
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
|
||||
self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all())
|
||||
self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all())
|
||||
self.assertAllEqual(fft_ops.fftshift(x).numpy(), np.fft.fftshift(x))
|
||||
self.assertAllEqual(fft_ops.ifftshift(y).numpy(), np.fft.ifftshift(y))
|
||||
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
|
||||
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
|
||||
self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all())
|
||||
self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all())
|
||||
self.assertAllEqual(fft_ops.fftshift(x).numpy(), np.fft.fftshift(x))
|
||||
self.assertAllEqual(fft_ops.ifftshift(y).numpy(), np.fft.ifftshift(y))
|
||||
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
|
||||
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
|
||||
self.assertTrue((fft_ops.fftshift(freqs, axes=(0, 1)).numpy() == \
|
||||
np.fft.fftshift(freqs, axes=(0, 1))).all())
|
||||
self.assertTrue((fft_ops.ifftshift(shifted, axes=(0, 1)).numpy() == \
|
||||
np.fft.ifftshift(shifted, axes=(0, 1))).all())
|
||||
self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, 1)).numpy(),
|
||||
np.fft.fftshift(freqs, axes=(0, 1)))
|
||||
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, 1)).numpy(),
|
||||
np.fft.ifftshift(shifted, axes=(0, 1)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -40,19 +40,21 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.ops.signal.dct_ops import dct
|
||||
from tensorflow.python.ops.signal.dct_ops import idct
|
||||
from tensorflow.python.ops.signal.fft_ops import fft
|
||||
from tensorflow.python.ops.signal.fft_ops import fft2d
|
||||
from tensorflow.python.ops.signal.fft_ops import fft3d
|
||||
from tensorflow.python.ops.signal.fft_ops import ifft
|
||||
from tensorflow.python.ops.signal.fft_ops import ifft2d
|
||||
from tensorflow.python.ops.signal.fft_ops import ifft3d
|
||||
from tensorflow.python.ops.signal.fft_ops import irfft
|
||||
from tensorflow.python.ops.signal.fft_ops import irfft2d
|
||||
from tensorflow.python.ops.signal.fft_ops import irfft3d
|
||||
from tensorflow.python.ops.signal.fft_ops import fftshift
|
||||
from tensorflow.python.ops.signal.fft_ops import rfft
|
||||
from tensorflow.python.ops.signal.fft_ops import rfft2d
|
||||
from tensorflow.python.ops.signal.fft_ops import rfft3d
|
||||
from tensorflow.python.ops.signal.dct_ops import idct
|
||||
from tensorflow.python.ops.signal.fft_ops import ifft
|
||||
from tensorflow.python.ops.signal.fft_ops import ifft2d
|
||||
from tensorflow.python.ops.signal.fft_ops import ifft3d
|
||||
from tensorflow.python.ops.signal.fft_ops import ifftshift
|
||||
from tensorflow.python.ops.signal.fft_ops import irfft
|
||||
from tensorflow.python.ops.signal.fft_ops import irfft2d
|
||||
from tensorflow.python.ops.signal.fft_ops import irfft3d
|
||||
from tensorflow.python.ops.signal.mel_ops import linear_to_mel_weight_matrix
|
||||
from tensorflow.python.ops.signal.mfcc_ops import mfccs_from_log_mel_spectrograms
|
||||
from tensorflow.python.ops.signal.reconstruction_ops import overlap_and_add
|
||||
@ -62,6 +64,4 @@ from tensorflow.python.ops.signal.spectral_ops import inverse_stft_window_fn
|
||||
from tensorflow.python.ops.signal.spectral_ops import stft
|
||||
from tensorflow.python.ops.signal.window_ops import hamming_window
|
||||
from tensorflow.python.ops.signal.window_ops import hann_window
|
||||
from tensorflow.python.ops.signal.fft_ops import fftshift
|
||||
from tensorflow.python.ops.signal.fft_ops import ifftshift
|
||||
# pylint: enable=unused-import
|
||||
|
Loading…
x
Reference in New Issue
Block a user