[XLA:Python] Add NextAfter to Python client.

Add an explicit Cython source version to custom call test to suppress a build warning.

PiperOrigin-RevId: 275882502
Change-Id: Ief34e3ab4fcf64b8e88a4de6ef7481fea9485bd8
This commit is contained in:
Peter Hawkins 2019-10-21 11:11:26 -07:00 committed by TensorFlower Gardener
parent a281b6c87f
commit bfbce5aeae
4 changed files with 12 additions and 0 deletions

View File

@ -1,3 +1,4 @@
# cython: language_level=2
# distutils: language = c++
# Test case for defining a XLA custom call target in Cython, and registering

View File

@ -654,6 +654,7 @@ PYBIND11_MODULE(xla_extension, m) {
ops.def("Iota",
static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota));
ops.def("Map", &Map);
ops.def("NextAfter", &NextAfter);
ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
py::arg("token"), py::arg("shape_with_layout"),
py::arg("outfeed_config") = "");

View File

@ -1678,6 +1678,7 @@ _BINARY_OPS = [
'ShiftRightLogical',
'Atan2',
'Complex',
'NextAfter',
]
_OTHER_OPS = [

View File

@ -1421,6 +1421,15 @@ class SingleOpTest(ComputationTest):
self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)),
rtol=1e-4)
def testNextAfter(self):
c = self._NewComputation()
c.NextAfter(
c.Constant(np.array([1, 2], dtype=np.float32)),
c.Constant(np.array([2, 1], dtype=np.float32)))
out = self._Execute(c, ())
eps = np.finfo(np.float32).eps
np.testing.assert_equal(np.array([eps + 1, 2 - eps], dtype=np.float32), out)
class EmbeddedComputationsTest(ComputationTest):
"""Tests for XLA graphs with embedded computations (such as maps)."""