added numpy compatibility test

This commit is contained in:
gurpreet singh 2019-03-27 16:30:29 +05:30
parent 0aaf67a29b
commit 03be2ce9b4

View File

@ -580,5 +580,21 @@ class FFTShiftTest(test.TestCase):
self.assertTrue((fft_ops.fftshift(freqs).numpy() == shifted).all())
self.assertTrue((fft_ops.ifftshift(shifted).numpy() == freqs).all())
def testNumpyCompatibility(self):
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all())
self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all())
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all())
self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all())
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
self.assertTrue((fft_ops.fftshift(freqs, axes=(0, 1)).numpy() == \
np.fft.fftshift(freqs, axes=(0, 1))).all())
self.assertTrue((fft_ops.ifftshift(shifted, axes=(0, 1)).numpy() == \
np.fft.ifftshift(shifted, axes=(0, 1))).all())
if __name__ == "__main__":
test.main()