added assertAllEqual

This commit is contained in:
gurpreet singh 2019-03-30 10:52:40 +05:30
parent 03be2ce9b4
commit 82cda38a73
2 changed files with 29 additions and 31 deletions

View File

@ -559,42 +559,40 @@ class FFTShiftTest(test.TestCase):
def testDefinition(self):
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertTrue((fft_ops.fftshift(x).numpy() == y).all())
self.assertTrue((fft_ops.ifftshift(y).numpy() == x).all())
self.assertAllEqual(fft_ops.fftshift(x).numpy(), y)
self.assertAllEqual(fft_ops.ifftshift(y).numpy(), x)
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertTrue((fft_ops.fftshift(x).numpy() == y).all())
self.assertTrue((fft_ops.ifftshift(y).numpy() == x).all())
self.assertAllEqual(fft_ops.fftshift(x).numpy(), y)
self.assertAllEqual(fft_ops.ifftshift(y).numpy(), x)
def testAxesKeyword(self):
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
self.assertTrue((fft_ops.fftshift(freqs, axes=(0, 1)).numpy() == \
shifted).all())
self.assertTrue((fft_ops.fftshift(freqs, axes=0).numpy() == \
fft_ops.fftshift(freqs, axes=(0,)).numpy()).all())
self.assertTrue((fft_ops.ifftshift(shifted, axes=(0, 1)).numpy() == \
freqs).all())
self.assertTrue((fft_ops.ifftshift(shifted, axes=0).numpy() == \
fft_ops.ifftshift(shifted, axes=(0,)).numpy()).all())
self.assertTrue((fft_ops.fftshift(freqs).numpy() == shifted).all())
self.assertTrue((fft_ops.ifftshift(shifted).numpy() == freqs).all())
self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, 1)).numpy(), shifted)
self.assertAllEqual(fft_ops.fftshift(freqs, axes=0).numpy(),
fft_ops.fftshift(freqs, axes=(0,)).numpy())
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, 1)).numpy(), freqs)
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=0).numpy(),
fft_ops.ifftshift(shifted, axes=(0,)).numpy())
self.assertAllEqual(fft_ops.fftshift(freqs).numpy(), shifted)
self.assertAllEqual(fft_ops.ifftshift(shifted).numpy(), freqs)
def testNumpyCompatibility(self):
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all())
self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all())
self.assertAllEqual(fft_ops.fftshift(x).numpy(), np.fft.fftshift(x))
self.assertAllEqual(fft_ops.ifftshift(y).numpy(), np.fft.ifftshift(y))
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all())
self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all())
self.assertAllEqual(fft_ops.fftshift(x).numpy(), np.fft.fftshift(x))
self.assertAllEqual(fft_ops.ifftshift(y).numpy(), np.fft.ifftshift(y))
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
self.assertTrue((fft_ops.fftshift(freqs, axes=(0, 1)).numpy() == \
np.fft.fftshift(freqs, axes=(0, 1))).all())
self.assertTrue((fft_ops.ifftshift(shifted, axes=(0, 1)).numpy() == \
np.fft.ifftshift(shifted, axes=(0, 1))).all())
self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, 1)).numpy(),
np.fft.fftshift(freqs, axes=(0, 1)))
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, 1)).numpy(),
np.fft.ifftshift(shifted, axes=(0, 1)))
if __name__ == "__main__":
test.main()

View File

@ -40,19 +40,21 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.ops.signal.dct_ops import dct
from tensorflow.python.ops.signal.dct_ops import idct
from tensorflow.python.ops.signal.fft_ops import fft
from tensorflow.python.ops.signal.fft_ops import fft2d
from tensorflow.python.ops.signal.fft_ops import fft3d
from tensorflow.python.ops.signal.fft_ops import ifft
from tensorflow.python.ops.signal.fft_ops import ifft2d
from tensorflow.python.ops.signal.fft_ops import ifft3d
from tensorflow.python.ops.signal.fft_ops import irfft
from tensorflow.python.ops.signal.fft_ops import irfft2d
from tensorflow.python.ops.signal.fft_ops import irfft3d
from tensorflow.python.ops.signal.fft_ops import fftshift
from tensorflow.python.ops.signal.fft_ops import rfft
from tensorflow.python.ops.signal.fft_ops import rfft2d
from tensorflow.python.ops.signal.fft_ops import rfft3d
from tensorflow.python.ops.signal.dct_ops import idct
from tensorflow.python.ops.signal.fft_ops import ifft
from tensorflow.python.ops.signal.fft_ops import ifft2d
from tensorflow.python.ops.signal.fft_ops import ifft3d
from tensorflow.python.ops.signal.fft_ops import ifftshift
from tensorflow.python.ops.signal.fft_ops import irfft
from tensorflow.python.ops.signal.fft_ops import irfft2d
from tensorflow.python.ops.signal.fft_ops import irfft3d
from tensorflow.python.ops.signal.mel_ops import linear_to_mel_weight_matrix
from tensorflow.python.ops.signal.mfcc_ops import mfccs_from_log_mel_spectrograms
from tensorflow.python.ops.signal.reconstruction_ops import overlap_and_add
@ -62,6 +64,4 @@ from tensorflow.python.ops.signal.spectral_ops import inverse_stft_window_fn
from tensorflow.python.ops.signal.spectral_ops import stft
from tensorflow.python.ops.signal.window_ops import hamming_window
from tensorflow.python.ops.signal.window_ops import hann_window
from tensorflow.python.ops.signal.fft_ops import fftshift
from tensorflow.python.ops.signal.fft_ops import ifftshift
# pylint: enable=unused-import