[XLA:Python] Expose Fft in XLA Python client.
PiperOrigin-RevId: 246326482
This commit is contained in:
parent
64c4f34e98
commit
ca3d787e90
@ -339,6 +339,14 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
ops.def("DynamicUpdateSlice",
|
||||
static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
|
||||
&DynamicUpdateSlice));
|
||||
|
||||
ops.def("Fft", &Fft);
|
||||
py::enum_<FftType>(m, "FftType")
|
||||
.value("FFT", FftType::FFT)
|
||||
.value("IFFT", FftType::IFFT)
|
||||
.value("RFFT", FftType::RFFT)
|
||||
.value("IRFFT", FftType::IRFFT);
|
||||
|
||||
ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
|
||||
py::arg("dimension_numbers"), py::arg("slice_sizes"));
|
||||
ops.def("GetTupleElement", &GetTupleElement);
|
||||
|
@ -1650,6 +1650,13 @@ class ComputationBuilder(object):
|
||||
return ops.Scatter(a, scatter_indices, updates,
|
||||
update_computation.computation, dimension_numbers)
|
||||
|
||||
def Fft(self, operand, fft_type, fft_lengths):
|
||||
"""Enqueues a FFT operation onto the computation."""
|
||||
return ops.Fft(operand, fft_type, fft_lengths)
|
||||
|
||||
|
||||
FftType = _xla.FftType
|
||||
|
||||
|
||||
_UNARY_OPS = [
|
||||
'Not',
|
||||
|
@ -1280,6 +1280,33 @@ class SingleOpTest(ComputationTest):
|
||||
expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
|
||||
np.testing.assert_allclose(g, expected, rtol=1e-4)
|
||||
|
||||
def testFft(self):
|
||||
shape = [2, 3, 4, 5]
|
||||
rng = np.random.RandomState(0)
|
||||
a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
|
||||
a = a.astype(np.complex64)
|
||||
# FFT
|
||||
c = self._NewComputation()
|
||||
c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:])
|
||||
self._ExecuteAndCompareClose(c, expected=np.fft.fftn(a, axes=(1, 2, 3)),
|
||||
rtol=1e-4)
|
||||
# IFFT
|
||||
c = self._NewComputation()
|
||||
c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:])
|
||||
self._ExecuteAndCompareClose(c, expected=np.fft.ifftn(a, axes=(1, 2, 3)),
|
||||
rtol=1e-4)
|
||||
# RFFT
|
||||
b = rng.randn(*shape).astype(np.float32)
|
||||
c = self._NewComputation()
|
||||
c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:])
|
||||
self._ExecuteAndCompareClose(c, expected=np.fft.rfftn(b, axes=(1, 2, 3)),
|
||||
rtol=1e-4)
|
||||
# IRFFT
|
||||
c = self._NewComputation()
|
||||
c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8])
|
||||
self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)),
|
||||
rtol=1e-4)
|
||||
|
||||
|
||||
class EmbeddedComputationsTest(ComputationTest):
|
||||
"""Tests for XLA graphs with embedded computations (such as maps)."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user