[TF:XLA] Remove superfluous calls to .value in fft_test
This allows running the test in TF 2.0 mode PiperOrigin-RevId: 254443675
This commit is contained in:
parent
f68778cb77
commit
92d160f610
@ -131,15 +131,19 @@ class FFTTest(xla_test.XLATestCase):
|
||||
signal.ifft3d)
|
||||
|
||||
def testRFFT(self):
|
||||
self._VerifyFftMethod(
|
||||
INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
|
||||
lambda x: signal.rfft(x, fft_length=[x.shape[-1].value]))
|
||||
|
||||
def _to_expected(x):
|
||||
return np.fft.rfft(x, n=x.shape[-1])
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.rfft(x, fft_length=[x.shape[-1]])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_1D, np.real, _to_expected, _tf_fn)
|
||||
|
||||
def testRFFT2D(self):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.rfft2d(
|
||||
x, fft_length=[x.shape[-2].value, x.shape[-1].value])
|
||||
return signal.rfft2d(x, fft_length=[x.shape[-2], x.shape[-1]])
|
||||
|
||||
self._VerifyFftMethod(
|
||||
INNER_DIMS_2D, np.real,
|
||||
@ -153,8 +157,7 @@ class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.rfft3d(
|
||||
x,
|
||||
fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])
|
||||
x, fft_length=[x.shape[-3], x.shape[-2], x.shape[-1]])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
|
||||
|
||||
@ -168,17 +171,14 @@ class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.rfft3d(
|
||||
x,
|
||||
fft_length=[
|
||||
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
|
||||
])
|
||||
x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
|
||||
|
||||
def testIRFFT(self):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
|
||||
return signal.irfft(x, fft_length=[2 * (x.shape[-1] - 1)])
|
||||
|
||||
self._VerifyFftMethod(
|
||||
INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
|
||||
@ -187,8 +187,7 @@ class FFTTest(xla_test.XLATestCase):
|
||||
def testIRFFT2D(self):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.irfft2d(
|
||||
x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])
|
||||
return signal.irfft2d(x, fft_length=[x.shape[-2], 2 * (x.shape[-1] - 1)])
|
||||
|
||||
self._VerifyFftMethod(
|
||||
INNER_DIMS_2D,
|
||||
@ -212,10 +211,7 @@ class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.irfft3d(
|
||||
x,
|
||||
fft_length=[
|
||||
x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
|
||||
])
|
||||
x, fft_length=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
|
||||
|
||||
@ -235,10 +231,7 @@ class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.irfft3d(
|
||||
x,
|
||||
fft_length=[
|
||||
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
|
||||
])
|
||||
x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user