Merge pull request #38209 from yongtang:38172-fftshift-negative-axis

PiperOrigin-RevId: 305264576
Change-Id: Ie0e53717cd50fda5cb2875836be53547867b2d13
This commit is contained in:
TensorFlower Gardener 2020-04-07 08:42:14 -07:00
commit aee65d8f26
2 changed files with 18 additions and 0 deletions

View File

@ -662,6 +662,18 @@ class FFTShiftTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(y_fftshift_res, np.fft.fftshift(x_np, axes=axes))
self.assertAllClose(y_ifftshift_res, np.fft.ifftshift(x_np, axes=axes))
def test_negative_axes(self):
with self.session():
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, -1)), shifted)
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, -1)), freqs)
self.assertAllEqual(
fft_ops.fftshift(freqs, axes=-1), fft_ops.fftshift(freqs, axes=(1,)))
self.assertAllEqual(
fft_ops.ifftshift(shifted, axes=-1),
fft_ops.ifftshift(shifted, axes=(1,)))
if __name__ == "__main__":
test.main()

View File

@ -398,6 +398,9 @@ def fftshift(x, axes=None, name=None):
elif isinstance(axes, int):
shift = _array_ops.shape(x)[axes] // 2
else:
rank = _array_ops.rank(x)
# allows negative axis
axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
shift = _array_ops.gather(_array_ops.shape(x), axes) // 2
return manip_ops.roll(x, shift, axes, name)
@ -439,6 +442,9 @@ def ifftshift(x, axes=None, name=None):
elif isinstance(axes, int):
shift = -(_array_ops.shape(x)[axes] // 2)
else:
rank = _array_ops.rank(x)
# allows negative axis
axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
shift = -(_array_ops.gather(_array_ops.shape(x), axes) // 2)
return manip_ops.roll(x, shift, axes, name)