Merge pull request #31409 from zaccharieramzi:flexible_fftshift

PiperOrigin-RevId: 266869497
This commit is contained in:
TensorFlower Gardener 2019-09-03 00:39:22 -07:00
commit 21a5c3f864
2 changed files with 23 additions and 6 deletions

View File

@ -606,5 +606,22 @@ class FFTShiftTest(test.TestCase):
np.fft.ifftshift(shifted, axes=(0, 1)))
@test_util.run_deprecated_v1
def testPlaceholder(self):
x = array_ops.placeholder(shape=[None, None, None], dtype="float32")
axes_to_test = [None, 1, [1, 2]]
for axes in axes_to_test:
y_fftshift = fft_ops.fftshift(x, axes=axes)
y_ifftshift = fft_ops.ifftshift(x, axes=axes)
with self.session() as sess:
x_np = np.random.rand(16, 256, 256)
y_fftshift_res, y_ifftshift_res = sess.run(
[y_fftshift, y_ifftshift],
feed_dict={x: x_np},
)
self.assertAllClose(y_fftshift_res, np.fft.fftshift(x_np, axes=axes))
self.assertAllClose(y_ifftshift_res, np.fft.ifftshift(x_np, axes=axes))
if __name__ == "__main__":
test.main()

View File

@ -358,11 +358,11 @@ def fftshift(x, axes=None, name=None):
x = _ops.convert_to_tensor(x)
if axes is None:
axes = tuple(range(x.shape.ndims))
shift = [int(dim // 2) for dim in x.shape]
shift = _array_ops.shape(x) // 2
elif isinstance(axes, int):
shift = int(x.shape[axes] // 2)
shift = _array_ops.shape(x)[axes] // 2
else:
shift = [int((x.shape[ax]) // 2) for ax in axes]
shift = _array_ops.gather(_array_ops.shape(x), axes) // 2
return manip_ops.roll(x, shift, axes, name)
@ -399,11 +399,11 @@ def ifftshift(x, axes=None, name=None):
x = _ops.convert_to_tensor(x)
if axes is None:
axes = tuple(range(x.shape.ndims))
shift = [-int(dim // 2) for dim in x.shape]
shift = -(_array_ops.shape(x) // 2)
elif isinstance(axes, int):
shift = -int(x.shape[axes] // 2)
shift = -(_array_ops.shape(x)[axes] // 2)
else:
shift = [-int(x.shape[ax] // 2) for ax in axes]
shift = -(_array_ops.gather(_array_ops.shape(x), axes) // 2)
return manip_ops.roll(x, shift, axes, name)