diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc index e48475b7a85..9e38769168d 100644 --- a/tensorflow/compiler/xla/python/bfloat16.cc +++ b/tensorflow/compiler/xla/python/bfloat16.cc @@ -227,9 +227,9 @@ PyNumberMethods PyBfloat16_AsNumber = { nullptr, // nb_and nullptr, // nb_xor nullptr, // nb_or - PyBfloat16_Int, // nb_int - nullptr, // reserved - PyBfloat16_Float, // nb_float + PyBfloat16_Int, // nb_int + nullptr, // reserved + PyBfloat16_Float, // nb_float nullptr, // nb_inplace_add nullptr, // nb_inplace_subtract @@ -1213,7 +1213,44 @@ struct LogicalXor { } }; -// TODO(phawkins): implement nextafter, spacing +struct NextAfter { + bfloat16 operator()(bfloat16 from, bfloat16 to) { + uint16_t from_as_int, to_as_int; + const uint16_t sign_mask = 1 << 15; + float from_as_float(from), to_as_float(to); + memcpy(&from_as_int, &from, sizeof(bfloat16)); + memcpy(&to_as_int, &to, sizeof(bfloat16)); + if (std::isnan(from_as_float) || std::isnan(to_as_float)) { + return bfloat16(std::numeric_limits::quiet_NaN()); + } + if (from_as_int == to_as_int) { + return to; + } + if (from_as_float == 0) { + if (to_as_float == 0) { + return to; + } else { + // Smallest subnormal signed like `to`. + uint16_t out_int = (to_as_int & sign_mask) | 1; + bfloat16 out; + memcpy(&out, &out_int, sizeof(bfloat16)); + return out; + } + } + uint16_t from_sign = from_as_int & sign_mask; + uint16_t to_sign = to_as_int & sign_mask; + uint16_t from_abs = from_as_int & ~sign_mask; + uint16_t to_abs = to_as_int & ~sign_mask; + uint16_t magnitude_adjustment = + (from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001; + uint16_t out_int = from_as_int + magnitude_adjustment; + bfloat16 out; + memcpy(&out, &out_int, sizeof(bfloat16)); + return out; + } +}; + +// TODO(phawkins): implement spacing } // namespace ufuncs @@ -1467,7 +1504,9 @@ bool Initialize() { RegisterUFunc>(numpy.get(), "ceil") && RegisterUFunc>(numpy.get(), - "trunc"); + "trunc") && + RegisterUFunc>( + numpy.get(), "nextafter"); return ok; } diff --git a/tensorflow/compiler/xla/python/bfloat16_test.py b/tensorflow/compiler/xla/python/bfloat16_test.py index 51421a3655e..4c4f8c28d3f 100644 --- a/tensorflow/compiler/xla/python/bfloat16_test.py +++ b/tensorflow/compiler/xla/python/bfloat16_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import itertools import math from absl.testing import absltest @@ -398,6 +399,26 @@ class Bfloat16NumPyTest(parameterized.TestCase): np.testing.assert_equal(exp1, exp2) numpy_assert_allclose(mant1, mant2, rtol=1e-2) + def testNextAfter(self): + one = np.array(1., dtype=bfloat16) + two = np.array(2., dtype=bfloat16) + zero = np.array(0., dtype=bfloat16) + nan = np.array(np.nan, dtype=bfloat16) + np.testing.assert_equal(np.nextafter(one, two) - one, epsilon) + np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2) + np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True) + np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True) + np.testing.assert_equal(np.nextafter(one, one), one) + smallest_denormal = float.fromhex("1.0p-133") + np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal) + np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal) + for a, b in itertools.permutations([0., -0., nan], 2): + np.testing.assert_equal( + np.nextafter( + np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)), + np.nextafter( + np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16))) + if __name__ == "__main__": absltest.main()