Merge pull request #31409 from zaccharieramzi:flexible_fftshift
PiperOrigin-RevId: 266869497
This commit is contained in:
commit
21a5c3f864
@ -606,5 +606,22 @@ class FFTShiftTest(test.TestCase):
|
||||
np.fft.ifftshift(shifted, axes=(0, 1)))
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testPlaceholder(self):
|
||||
x = array_ops.placeholder(shape=[None, None, None], dtype="float32")
|
||||
axes_to_test = [None, 1, [1, 2]]
|
||||
for axes in axes_to_test:
|
||||
y_fftshift = fft_ops.fftshift(x, axes=axes)
|
||||
y_ifftshift = fft_ops.ifftshift(x, axes=axes)
|
||||
with self.session() as sess:
|
||||
x_np = np.random.rand(16, 256, 256)
|
||||
y_fftshift_res, y_ifftshift_res = sess.run(
|
||||
[y_fftshift, y_ifftshift],
|
||||
feed_dict={x: x_np},
|
||||
)
|
||||
self.assertAllClose(y_fftshift_res, np.fft.fftshift(x_np, axes=axes))
|
||||
self.assertAllClose(y_ifftshift_res, np.fft.ifftshift(x_np, axes=axes))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -358,11 +358,11 @@ def fftshift(x, axes=None, name=None):
|
||||
x = _ops.convert_to_tensor(x)
|
||||
if axes is None:
|
||||
axes = tuple(range(x.shape.ndims))
|
||||
shift = [int(dim // 2) for dim in x.shape]
|
||||
shift = _array_ops.shape(x) // 2
|
||||
elif isinstance(axes, int):
|
||||
shift = int(x.shape[axes] // 2)
|
||||
shift = _array_ops.shape(x)[axes] // 2
|
||||
else:
|
||||
shift = [int((x.shape[ax]) // 2) for ax in axes]
|
||||
shift = _array_ops.gather(_array_ops.shape(x), axes) // 2
|
||||
|
||||
return manip_ops.roll(x, shift, axes, name)
|
||||
|
||||
@ -399,11 +399,11 @@ def ifftshift(x, axes=None, name=None):
|
||||
x = _ops.convert_to_tensor(x)
|
||||
if axes is None:
|
||||
axes = tuple(range(x.shape.ndims))
|
||||
shift = [-int(dim // 2) for dim in x.shape]
|
||||
shift = -(_array_ops.shape(x) // 2)
|
||||
elif isinstance(axes, int):
|
||||
shift = -int(x.shape[axes] // 2)
|
||||
shift = -(_array_ops.shape(x)[axes] // 2)
|
||||
else:
|
||||
shift = [-int(x.shape[ax] // 2) for ax in axes]
|
||||
shift = -(_array_ops.gather(_array_ops.shape(x), axes) // 2)
|
||||
|
||||
return manip_ops.roll(x, shift, axes, name)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user