Merge pull request #38209 from yongtang:38172-fftshift-negative-axis
PiperOrigin-RevId: 305264576 Change-Id: Ie0e53717cd50fda5cb2875836be53547867b2d13
This commit is contained in:
commit
aee65d8f26
@ -662,6 +662,18 @@ class FFTShiftTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose(y_fftshift_res, np.fft.fftshift(x_np, axes=axes))
|
||||
self.assertAllClose(y_ifftshift_res, np.fft.ifftshift(x_np, axes=axes))
|
||||
|
||||
def test_negative_axes(self):
|
||||
with self.session():
|
||||
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
|
||||
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
|
||||
self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, -1)), shifted)
|
||||
self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, -1)), freqs)
|
||||
self.assertAllEqual(
|
||||
fft_ops.fftshift(freqs, axes=-1), fft_ops.fftshift(freqs, axes=(1,)))
|
||||
self.assertAllEqual(
|
||||
fft_ops.ifftshift(shifted, axes=-1),
|
||||
fft_ops.ifftshift(shifted, axes=(1,)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -398,6 +398,9 @@ def fftshift(x, axes=None, name=None):
|
||||
elif isinstance(axes, int):
|
||||
shift = _array_ops.shape(x)[axes] // 2
|
||||
else:
|
||||
rank = _array_ops.rank(x)
|
||||
# allows negative axis
|
||||
axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
|
||||
shift = _array_ops.gather(_array_ops.shape(x), axes) // 2
|
||||
|
||||
return manip_ops.roll(x, shift, axes, name)
|
||||
@ -439,6 +442,9 @@ def ifftshift(x, axes=None, name=None):
|
||||
elif isinstance(axes, int):
|
||||
shift = -(_array_ops.shape(x)[axes] // 2)
|
||||
else:
|
||||
rank = _array_ops.rank(x)
|
||||
# allows negative axis
|
||||
axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
|
||||
shift = -(_array_ops.gather(_array_ops.shape(x), axes) // 2)
|
||||
|
||||
return manip_ops.roll(x, shift, axes, name)
|
||||
|
Loading…
Reference in New Issue
Block a user