From 82cda38a73f87c08d8f680a7226df0f1f1dabd4e Mon Sep 17 00:00:00 2001 From: gurpreet singh Date: Sat, 30 Mar 2019 10:52:40 +0530 Subject: [PATCH] added assertAllEqual --- .../kernel_tests/signal/fft_ops_test.py | 42 +++++++++---------- tensorflow/python/ops/signal/signal.py | 18 ++++---- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/tensorflow/python/kernel_tests/signal/fft_ops_test.py b/tensorflow/python/kernel_tests/signal/fft_ops_test.py index f5f3fe9dc62..3683a558de3 100644 --- a/tensorflow/python/kernel_tests/signal/fft_ops_test.py +++ b/tensorflow/python/kernel_tests/signal/fft_ops_test.py @@ -559,42 +559,40 @@ class FFTShiftTest(test.TestCase): def testDefinition(self): x = [0, 1, 2, 3, 4, -4, -3, -2, -1] y = [-4, -3, -2, -1, 0, 1, 2, 3, 4] - self.assertTrue((fft_ops.fftshift(x).numpy() == y).all()) - self.assertTrue((fft_ops.ifftshift(y).numpy() == x).all()) + self.assertAllEqual(fft_ops.fftshift(x).numpy(), y) + self.assertAllEqual(fft_ops.ifftshift(y).numpy(), x) x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1] y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] - self.assertTrue((fft_ops.fftshift(x).numpy() == y).all()) - self.assertTrue((fft_ops.ifftshift(y).numpy() == x).all()) + self.assertAllEqual(fft_ops.fftshift(x).numpy(), y) + self.assertAllEqual(fft_ops.ifftshift(y).numpy(), x) def testAxesKeyword(self): freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]] shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]] - self.assertTrue((fft_ops.fftshift(freqs, axes=(0, 1)).numpy() == \ - shifted).all()) - self.assertTrue((fft_ops.fftshift(freqs, axes=0).numpy() == \ - fft_ops.fftshift(freqs, axes=(0,)).numpy()).all()) - self.assertTrue((fft_ops.ifftshift(shifted, axes=(0, 1)).numpy() == \ - freqs).all()) - self.assertTrue((fft_ops.ifftshift(shifted, axes=0).numpy() == \ - fft_ops.ifftshift(shifted, axes=(0,)).numpy()).all()) - self.assertTrue((fft_ops.fftshift(freqs).numpy() == shifted).all()) - self.assertTrue((fft_ops.ifftshift(shifted).numpy() == freqs).all()) + self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, 1)).numpy(), shifted) + self.assertAllEqual(fft_ops.fftshift(freqs, axes=0).numpy(), + fft_ops.fftshift(freqs, axes=(0,)).numpy()) + self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, 1)).numpy(), freqs) + self.assertAllEqual(fft_ops.ifftshift(shifted, axes=0).numpy(), + fft_ops.ifftshift(shifted, axes=(0,)).numpy()) + self.assertAllEqual(fft_ops.fftshift(freqs).numpy(), shifted) + self.assertAllEqual(fft_ops.ifftshift(shifted).numpy(), freqs) def testNumpyCompatibility(self): x = [0, 1, 2, 3, 4, -4, -3, -2, -1] y = [-4, -3, -2, -1, 0, 1, 2, 3, 4] - self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all()) - self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all()) + self.assertAllEqual(fft_ops.fftshift(x).numpy(), np.fft.fftshift(x)) + self.assertAllEqual(fft_ops.ifftshift(y).numpy(), np.fft.ifftshift(y)) x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1] y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] - self.assertTrue((fft_ops.fftshift(x).numpy() == np.fft.fftshift(x)).all()) - self.assertTrue((fft_ops.ifftshift(y).numpy() == np.fft.ifftshift(y)).all()) + self.assertAllEqual(fft_ops.fftshift(x).numpy(), np.fft.fftshift(x)) + self.assertAllEqual(fft_ops.ifftshift(y).numpy(), np.fft.ifftshift(y)) freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]] shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]] - self.assertTrue((fft_ops.fftshift(freqs, axes=(0, 1)).numpy() == \ - np.fft.fftshift(freqs, axes=(0, 1))).all()) - self.assertTrue((fft_ops.ifftshift(shifted, axes=(0, 1)).numpy() == \ - np.fft.ifftshift(shifted, axes=(0, 1))).all()) + self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, 1)).numpy(), + np.fft.fftshift(freqs, axes=(0, 1))) + self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, 1)).numpy(), + np.fft.ifftshift(shifted, axes=(0, 1))) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/signal/signal.py b/tensorflow/python/ops/signal/signal.py index 8b1c8381f95..f17a9fc0b3e 100644 --- a/tensorflow/python/ops/signal/signal.py +++ b/tensorflow/python/ops/signal/signal.py @@ -40,19 +40,21 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.python.ops.signal.dct_ops import dct -from tensorflow.python.ops.signal.dct_ops import idct from tensorflow.python.ops.signal.fft_ops import fft from tensorflow.python.ops.signal.fft_ops import fft2d from tensorflow.python.ops.signal.fft_ops import fft3d -from tensorflow.python.ops.signal.fft_ops import ifft -from tensorflow.python.ops.signal.fft_ops import ifft2d -from tensorflow.python.ops.signal.fft_ops import ifft3d -from tensorflow.python.ops.signal.fft_ops import irfft -from tensorflow.python.ops.signal.fft_ops import irfft2d -from tensorflow.python.ops.signal.fft_ops import irfft3d +from tensorflow.python.ops.signal.fft_ops import fftshift from tensorflow.python.ops.signal.fft_ops import rfft from tensorflow.python.ops.signal.fft_ops import rfft2d from tensorflow.python.ops.signal.fft_ops import rfft3d +from tensorflow.python.ops.signal.dct_ops import idct +from tensorflow.python.ops.signal.fft_ops import ifft +from tensorflow.python.ops.signal.fft_ops import ifft2d +from tensorflow.python.ops.signal.fft_ops import ifft3d +from tensorflow.python.ops.signal.fft_ops import ifftshift +from tensorflow.python.ops.signal.fft_ops import irfft +from tensorflow.python.ops.signal.fft_ops import irfft2d +from tensorflow.python.ops.signal.fft_ops import irfft3d from tensorflow.python.ops.signal.mel_ops import linear_to_mel_weight_matrix from tensorflow.python.ops.signal.mfcc_ops import mfccs_from_log_mel_spectrograms from tensorflow.python.ops.signal.reconstruction_ops import overlap_and_add @@ -62,6 +64,4 @@ from tensorflow.python.ops.signal.spectral_ops import inverse_stft_window_fn from tensorflow.python.ops.signal.spectral_ops import stft from tensorflow.python.ops.signal.window_ops import hamming_window from tensorflow.python.ops.signal.window_ops import hann_window -from tensorflow.python.ops.signal.fft_ops import fftshift -from tensorflow.python.ops.signal.fft_ops import ifftshift # pylint: enable=unused-import