[TF:XLA] Remove superfluous calls to .value in fft_test

This allows running the test in TF 2.0 mode

PiperOrigin-RevId: 254443675
This commit is contained in:
Benjamin Kramer 2019-06-21 12:05:02 -07:00 committed by TensorFlower Gardener
parent f68778cb77
commit 92d160f610

View File

@ -131,15 +131,19 @@ class FFTTest(xla_test.XLATestCase):
signal.ifft3d)
def testRFFT(self):
self._VerifyFftMethod(
INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
lambda x: signal.rfft(x, fft_length=[x.shape[-1].value]))
def _to_expected(x):
return np.fft.rfft(x, n=x.shape[-1])
def _tf_fn(x):
return signal.rfft(x, fft_length=[x.shape[-1]])
self._VerifyFftMethod(INNER_DIMS_1D, np.real, _to_expected, _tf_fn)
def testRFFT2D(self):
def _tf_fn(x):
return signal.rfft2d(
x, fft_length=[x.shape[-2].value, x.shape[-1].value])
return signal.rfft2d(x, fft_length=[x.shape[-2], x.shape[-1]])
self._VerifyFftMethod(
INNER_DIMS_2D, np.real,
@ -153,8 +157,7 @@ class FFTTest(xla_test.XLATestCase):
def _tf_fn(x):
return signal.rfft3d(
x,
fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])
x, fft_length=[x.shape[-3], x.shape[-2], x.shape[-1]])
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
@ -168,17 +171,14 @@ class FFTTest(xla_test.XLATestCase):
def _tf_fn(x):
return signal.rfft3d(
x,
fft_length=[
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
])
x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
def testIRFFT(self):
def _tf_fn(x):
return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
return signal.irfft(x, fft_length=[2 * (x.shape[-1] - 1)])
self._VerifyFftMethod(
INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
@ -187,8 +187,7 @@ class FFTTest(xla_test.XLATestCase):
def testIRFFT2D(self):
def _tf_fn(x):
return signal.irfft2d(
x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])
return signal.irfft2d(x, fft_length=[x.shape[-2], 2 * (x.shape[-1] - 1)])
self._VerifyFftMethod(
INNER_DIMS_2D,
@ -212,10 +211,7 @@ class FFTTest(xla_test.XLATestCase):
def _tf_fn(x):
return signal.irfft3d(
x,
fft_length=[
x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
])
x, fft_length=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
@ -235,10 +231,7 @@ class FFTTest(xla_test.XLATestCase):
def _tf_fn(x):
return signal.irfft3d(
x,
fft_length=[
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
])
x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)