[XLA:Python] Add NextAfter to Python client.
Add an explicit Cython source version to custom call test to suppress a build warning. PiperOrigin-RevId: 275882502 Change-Id: Ief34e3ab4fcf64b8e88a4de6ef7481fea9485bd8
This commit is contained in:
parent
a281b6c87f
commit
bfbce5aeae
tensorflow/compiler/xla/python
@ -1,3 +1,4 @@
|
||||
# cython: language_level=2
|
||||
# distutils: language = c++
|
||||
|
||||
# Test case for defining a XLA custom call target in Cython, and registering
|
||||
|
@ -654,6 +654,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
ops.def("Iota",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota));
|
||||
ops.def("Map", &Map);
|
||||
ops.def("NextAfter", &NextAfter);
|
||||
ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
|
||||
py::arg("token"), py::arg("shape_with_layout"),
|
||||
py::arg("outfeed_config") = "");
|
||||
|
@ -1678,6 +1678,7 @@ _BINARY_OPS = [
|
||||
'ShiftRightLogical',
|
||||
'Atan2',
|
||||
'Complex',
|
||||
'NextAfter',
|
||||
]
|
||||
|
||||
_OTHER_OPS = [
|
||||
|
@ -1421,6 +1421,15 @@ class SingleOpTest(ComputationTest):
|
||||
self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)),
|
||||
rtol=1e-4)
|
||||
|
||||
def testNextAfter(self):
|
||||
c = self._NewComputation()
|
||||
c.NextAfter(
|
||||
c.Constant(np.array([1, 2], dtype=np.float32)),
|
||||
c.Constant(np.array([2, 1], dtype=np.float32)))
|
||||
out = self._Execute(c, ())
|
||||
eps = np.finfo(np.float32).eps
|
||||
np.testing.assert_equal(np.array([eps + 1, 2 - eps], dtype=np.float32), out)
|
||||
|
||||
|
||||
class EmbeddedComputationsTest(ComputationTest):
|
||||
"""Tests for XLA graphs with embedded computations (such as maps)."""
|
||||
|
Loading…
Reference in New Issue
Block a user