From bfbce5aeaebc660b272397e51c14c67c0d9ba360 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 21 Oct 2019 11:11:26 -0700 Subject: [PATCH] [XLA:Python] Add NextAfter to Python client. Add an explicit Cython source version to custom call test to suppress a build warning. PiperOrigin-RevId: 275882502 Change-Id: Ief34e3ab4fcf64b8e88a4de6ef7481fea9485bd8 --- tensorflow/compiler/xla/python/custom_call_for_test.pyx | 1 + tensorflow/compiler/xla/python/xla.cc | 1 + tensorflow/compiler/xla/python/xla_client.py | 1 + tensorflow/compiler/xla/python/xla_client_test.py | 9 +++++++++ 4 files changed, 12 insertions(+) diff --git a/tensorflow/compiler/xla/python/custom_call_for_test.pyx b/tensorflow/compiler/xla/python/custom_call_for_test.pyx index 4f7c4c3e5a8..cc8b9f0bc64 100644 --- a/tensorflow/compiler/xla/python/custom_call_for_test.pyx +++ b/tensorflow/compiler/xla/python/custom_call_for_test.pyx @@ -1,3 +1,4 @@ +# cython: language_level=2 # distutils: language = c++ # Test case for defining a XLA custom call target in Cython, and registering diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index cff68f5ca3a..01ff619a462 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -654,6 +654,7 @@ PYBIND11_MODULE(xla_extension, m) { ops.def("Iota", static_cast(&Iota)); ops.def("Map", &Map); + ops.def("NextAfter", &NextAfter); ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), py::arg("token"), py::arg("shape_with_layout"), py::arg("outfeed_config") = ""); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 29b6278ca9e..9c9554cf3f9 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1678,6 +1678,7 @@ _BINARY_OPS = [ 'ShiftRightLogical', 'Atan2', 'Complex', + 'NextAfter', ] _OTHER_OPS = [ diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 0c248d57815..05f7d876507 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -1421,6 +1421,15 @@ class SingleOpTest(ComputationTest): self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), rtol=1e-4) + def testNextAfter(self): + c = self._NewComputation() + c.NextAfter( + c.Constant(np.array([1, 2], dtype=np.float32)), + c.Constant(np.array([2, 1], dtype=np.float32))) + out = self._Execute(c, ()) + eps = np.finfo(np.float32).eps + np.testing.assert_equal(np.array([eps + 1, 2 - eps], dtype=np.float32), out) + class EmbeddedComputationsTest(ComputationTest): """Tests for XLA graphs with embedded computations (such as maps)."""