[XLA:Python] Implement np.nextafter for bfloat16 extension.

Should fix test case failure in https://github.com/google/jax/pull/3309 after a jaxlib release.

The implementation is a port of the implementation in xla/client/lib.math.cc.

PiperOrigin-RevId: 315303126
Change-Id: I0bdccbb224e74d45663b41581c67de53ee2b77b3
This commit is contained in:
Peter Hawkins 2020-06-08 10:30:43 -07:00 committed by TensorFlower Gardener
parent 60d63428b1
commit cb8342c4eb
2 changed files with 65 additions and 5 deletions

View File

@ -227,9 +227,9 @@ PyNumberMethods PyBfloat16_AsNumber = {
nullptr, // nb_and
nullptr, // nb_xor
nullptr, // nb_or
PyBfloat16_Int, // nb_int
nullptr, // reserved
PyBfloat16_Float, // nb_float
PyBfloat16_Int, // nb_int
nullptr, // reserved
PyBfloat16_Float, // nb_float
nullptr, // nb_inplace_add
nullptr, // nb_inplace_subtract
@ -1213,7 +1213,44 @@ struct LogicalXor {
}
};
// TODO(phawkins): implement nextafter, spacing
struct NextAfter {
bfloat16 operator()(bfloat16 from, bfloat16 to) {
uint16_t from_as_int, to_as_int;
const uint16_t sign_mask = 1 << 15;
float from_as_float(from), to_as_float(to);
memcpy(&from_as_int, &from, sizeof(bfloat16));
memcpy(&to_as_int, &to, sizeof(bfloat16));
if (std::isnan(from_as_float) || std::isnan(to_as_float)) {
return bfloat16(std::numeric_limits<float>::quiet_NaN());
}
if (from_as_int == to_as_int) {
return to;
}
if (from_as_float == 0) {
if (to_as_float == 0) {
return to;
} else {
// Smallest subnormal signed like `to`.
uint16_t out_int = (to_as_int & sign_mask) | 1;
bfloat16 out;
memcpy(&out, &out_int, sizeof(bfloat16));
return out;
}
}
uint16_t from_sign = from_as_int & sign_mask;
uint16_t to_sign = to_as_int & sign_mask;
uint16_t from_abs = from_as_int & ~sign_mask;
uint16_t to_abs = to_as_int & ~sign_mask;
uint16_t magnitude_adjustment =
(from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
uint16_t out_int = from_as_int + magnitude_adjustment;
bfloat16 out;
memcpy(&out, &out_int, sizeof(bfloat16));
return out;
}
};
// TODO(phawkins): implement spacing
} // namespace ufuncs
@ -1467,7 +1504,9 @@ bool Initialize() {
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
"ceil") &&
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
"trunc");
"trunc") &&
RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
numpy.get(), "nextafter");
return ok;
}

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
import itertools
import math
from absl.testing import absltest
@ -398,6 +399,26 @@ class Bfloat16NumPyTest(parameterized.TestCase):
np.testing.assert_equal(exp1, exp2)
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
def testNextAfter(self):
one = np.array(1., dtype=bfloat16)
two = np.array(2., dtype=bfloat16)
zero = np.array(0., dtype=bfloat16)
nan = np.array(np.nan, dtype=bfloat16)
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
np.testing.assert_equal(np.nextafter(one, one), one)
smallest_denormal = float.fromhex("1.0p-133")
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
for a, b in itertools.permutations([0., -0., nan], 2):
np.testing.assert_equal(
np.nextafter(
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
np.nextafter(
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
if __name__ == "__main__":
absltest.main()