[XLA:Python] Implement np.nextafter for bfloat16 extension.
Should fix test case failure in https://github.com/google/jax/pull/3309 after a jaxlib release. The implementation is a port of the implementation in xla/client/lib.math.cc. PiperOrigin-RevId: 315303126 Change-Id: I0bdccbb224e74d45663b41581c67de53ee2b77b3
This commit is contained in:
parent
60d63428b1
commit
cb8342c4eb
@ -227,9 +227,9 @@ PyNumberMethods PyBfloat16_AsNumber = {
|
||||
nullptr, // nb_and
|
||||
nullptr, // nb_xor
|
||||
nullptr, // nb_or
|
||||
PyBfloat16_Int, // nb_int
|
||||
nullptr, // reserved
|
||||
PyBfloat16_Float, // nb_float
|
||||
PyBfloat16_Int, // nb_int
|
||||
nullptr, // reserved
|
||||
PyBfloat16_Float, // nb_float
|
||||
|
||||
nullptr, // nb_inplace_add
|
||||
nullptr, // nb_inplace_subtract
|
||||
@ -1213,7 +1213,44 @@ struct LogicalXor {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(phawkins): implement nextafter, spacing
|
||||
struct NextAfter {
|
||||
bfloat16 operator()(bfloat16 from, bfloat16 to) {
|
||||
uint16_t from_as_int, to_as_int;
|
||||
const uint16_t sign_mask = 1 << 15;
|
||||
float from_as_float(from), to_as_float(to);
|
||||
memcpy(&from_as_int, &from, sizeof(bfloat16));
|
||||
memcpy(&to_as_int, &to, sizeof(bfloat16));
|
||||
if (std::isnan(from_as_float) || std::isnan(to_as_float)) {
|
||||
return bfloat16(std::numeric_limits<float>::quiet_NaN());
|
||||
}
|
||||
if (from_as_int == to_as_int) {
|
||||
return to;
|
||||
}
|
||||
if (from_as_float == 0) {
|
||||
if (to_as_float == 0) {
|
||||
return to;
|
||||
} else {
|
||||
// Smallest subnormal signed like `to`.
|
||||
uint16_t out_int = (to_as_int & sign_mask) | 1;
|
||||
bfloat16 out;
|
||||
memcpy(&out, &out_int, sizeof(bfloat16));
|
||||
return out;
|
||||
}
|
||||
}
|
||||
uint16_t from_sign = from_as_int & sign_mask;
|
||||
uint16_t to_sign = to_as_int & sign_mask;
|
||||
uint16_t from_abs = from_as_int & ~sign_mask;
|
||||
uint16_t to_abs = to_as_int & ~sign_mask;
|
||||
uint16_t magnitude_adjustment =
|
||||
(from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
|
||||
uint16_t out_int = from_as_int + magnitude_adjustment;
|
||||
bfloat16 out;
|
||||
memcpy(&out, &out_int, sizeof(bfloat16));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(phawkins): implement spacing
|
||||
|
||||
} // namespace ufuncs
|
||||
|
||||
@ -1467,7 +1504,9 @@ bool Initialize() {
|
||||
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
|
||||
"ceil") &&
|
||||
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
|
||||
"trunc");
|
||||
"trunc") &&
|
||||
RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
|
||||
numpy.get(), "nextafter");
|
||||
|
||||
return ok;
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -398,6 +399,26 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
||||
np.testing.assert_equal(exp1, exp2)
|
||||
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
|
||||
|
||||
def testNextAfter(self):
|
||||
one = np.array(1., dtype=bfloat16)
|
||||
two = np.array(2., dtype=bfloat16)
|
||||
zero = np.array(0., dtype=bfloat16)
|
||||
nan = np.array(np.nan, dtype=bfloat16)
|
||||
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
|
||||
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
|
||||
np.testing.assert_equal(np.nextafter(one, one), one)
|
||||
smallest_denormal = float.fromhex("1.0p-133")
|
||||
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
|
||||
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
|
||||
for a, b in itertools.permutations([0., -0., nan], 2):
|
||||
np.testing.assert_equal(
|
||||
np.nextafter(
|
||||
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
|
||||
np.nextafter(
|
||||
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user