added numpy compatibility test
This commit is contained in:
parent
0aaf67a29b
commit
03be2ce9b4
@ -580,5 +580,21 @@ class FFTShiftTest(test.TestCase):
|
||||
self.assertTrue((fft_ops.fftshift(freqs).numpy() == shifted).all())
|
||||
self.assertTrue((fft_ops.ifftshift(shifted).numpy() == freqs).all())
|
||||
|
||||
def testNumpyCompatibility(self):
|
||||
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
|
||||
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
|
||||
self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all())
|
||||
self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all())
|
||||
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
|
||||
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
|
||||
self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all())
|
||||
self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all())
|
||||
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
|
||||
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
|
||||
self.assertTrue((fft_ops.fftshift(freqs, axes=(0, 1)).numpy() == \
|
||||
np.fft.fftshift(freqs, axes=(0, 1))).all())
|
||||
self.assertTrue((fft_ops.ifftshift(shifted, axes=(0, 1)).numpy() == \
|
||||
np.fft.ifftshift(shifted, axes=(0, 1))).all())
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user