[XLA:Python] Expose Fft in XLA Python client.

PiperOrigin-RevId: 246326482
This commit is contained in:
Skye Wanderman-Milne 2019-05-02 08:52:30 -07:00 committed by TensorFlower Gardener
parent 64c4f34e98
commit ca3d787e90
3 changed files with 42 additions and 0 deletions

View File

@ -339,6 +339,14 @@ PYBIND11_MODULE(xla_extension, m) {
ops.def("DynamicUpdateSlice",
static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
&DynamicUpdateSlice));
ops.def("Fft", &Fft);
py::enum_<FftType>(m, "FftType")
.value("FFT", FftType::FFT)
.value("IFFT", FftType::IFFT)
.value("RFFT", FftType::RFFT)
.value("IRFFT", FftType::IRFFT);
ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
py::arg("dimension_numbers"), py::arg("slice_sizes"));
ops.def("GetTupleElement", &GetTupleElement);

View File

@ -1650,6 +1650,13 @@ class ComputationBuilder(object):
return ops.Scatter(a, scatter_indices, updates,
update_computation.computation, dimension_numbers)
def Fft(self, operand, fft_type, fft_lengths):
"""Enqueues a FFT operation onto the computation."""
return ops.Fft(operand, fft_type, fft_lengths)
FftType = _xla.FftType
_UNARY_OPS = [
'Not',

View File

@ -1280,6 +1280,33 @@ class SingleOpTest(ComputationTest):
expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
np.testing.assert_allclose(g, expected, rtol=1e-4)
def testFft(self):
shape = [2, 3, 4, 5]
rng = np.random.RandomState(0)
a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
a = a.astype(np.complex64)
# FFT
c = self._NewComputation()
c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:])
self._ExecuteAndCompareClose(c, expected=np.fft.fftn(a, axes=(1, 2, 3)),
rtol=1e-4)
# IFFT
c = self._NewComputation()
c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:])
self._ExecuteAndCompareClose(c, expected=np.fft.ifftn(a, axes=(1, 2, 3)),
rtol=1e-4)
# RFFT
b = rng.randn(*shape).astype(np.float32)
c = self._NewComputation()
c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:])
self._ExecuteAndCompareClose(c, expected=np.fft.rfftn(b, axes=(1, 2, 3)),
rtol=1e-4)
# IRFFT
c = self._NewComputation()
c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8])
self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)),
rtol=1e-4)
class EmbeddedComputationsTest(ComputationTest):
"""Tests for XLA graphs with embedded computations (such as maps)."""