diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 12a325ea29b..d84c08af599 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -339,6 +339,14 @@ PYBIND11_MODULE(xla_extension, m) { ops.def("DynamicUpdateSlice", static_cast)>( &DynamicUpdateSlice)); + + ops.def("Fft", &Fft); + py::enum_(m, "FftType") + .value("FFT", FftType::FFT) + .value("IFFT", FftType::IFFT) + .value("RFFT", FftType::RFFT) + .value("IRFFT", FftType::IRFFT); + ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), py::arg("dimension_numbers"), py::arg("slice_sizes")); ops.def("GetTupleElement", &GetTupleElement); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 298cc430902..dbb07ac6761 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1650,6 +1650,13 @@ class ComputationBuilder(object): return ops.Scatter(a, scatter_indices, updates, update_computation.computation, dimension_numbers) + def Fft(self, operand, fft_type, fft_lengths): + """Enqueues a FFT operation onto the computation.""" + return ops.Fft(operand, fft_type, fft_lengths) + + +FftType = _xla.FftType + _UNARY_OPS = [ 'Not', diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 35518f29c2a..0f268f037f0 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -1280,6 +1280,33 @@ class SingleOpTest(ComputationTest): expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) np.testing.assert_allclose(g, expected, rtol=1e-4) + def testFft(self): + shape = [2, 3, 4, 5] + rng = np.random.RandomState(0) + a = rng.randn(*shape) + 1.0j * rng.randn(*shape) + a = a.astype(np.complex64) + # FFT + c = self._NewComputation() + c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:]) + self._ExecuteAndCompareClose(c, expected=np.fft.fftn(a, axes=(1, 2, 3)), + rtol=1e-4) + # IFFT + c = self._NewComputation() + c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:]) + self._ExecuteAndCompareClose(c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), + rtol=1e-4) + # RFFT + b = rng.randn(*shape).astype(np.float32) + c = self._NewComputation() + c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:]) + self._ExecuteAndCompareClose(c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), + rtol=1e-4) + # IRFFT + c = self._NewComputation() + c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8]) + self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), + rtol=1e-4) + class EmbeddedComputationsTest(ComputationTest): """Tests for XLA graphs with embedded computations (such as maps)."""