[XLA:Python] Implement np.nextafter for bfloat16 extension.
Should fix test case failure in https://github.com/google/jax/pull/3309 after a jaxlib release. The implementation is a port of the implementation in xla/client/lib.math.cc. PiperOrigin-RevId: 315303126 Change-Id: I0bdccbb224e74d45663b41581c67de53ee2b77b3
This commit is contained in:
parent
60d63428b1
commit
cb8342c4eb
@ -227,9 +227,9 @@ PyNumberMethods PyBfloat16_AsNumber = {
|
|||||||
nullptr, // nb_and
|
nullptr, // nb_and
|
||||||
nullptr, // nb_xor
|
nullptr, // nb_xor
|
||||||
nullptr, // nb_or
|
nullptr, // nb_or
|
||||||
PyBfloat16_Int, // nb_int
|
PyBfloat16_Int, // nb_int
|
||||||
nullptr, // reserved
|
nullptr, // reserved
|
||||||
PyBfloat16_Float, // nb_float
|
PyBfloat16_Float, // nb_float
|
||||||
|
|
||||||
nullptr, // nb_inplace_add
|
nullptr, // nb_inplace_add
|
||||||
nullptr, // nb_inplace_subtract
|
nullptr, // nb_inplace_subtract
|
||||||
@ -1213,7 +1213,44 @@ struct LogicalXor {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO(phawkins): implement nextafter, spacing
|
struct NextAfter {
|
||||||
|
bfloat16 operator()(bfloat16 from, bfloat16 to) {
|
||||||
|
uint16_t from_as_int, to_as_int;
|
||||||
|
const uint16_t sign_mask = 1 << 15;
|
||||||
|
float from_as_float(from), to_as_float(to);
|
||||||
|
memcpy(&from_as_int, &from, sizeof(bfloat16));
|
||||||
|
memcpy(&to_as_int, &to, sizeof(bfloat16));
|
||||||
|
if (std::isnan(from_as_float) || std::isnan(to_as_float)) {
|
||||||
|
return bfloat16(std::numeric_limits<float>::quiet_NaN());
|
||||||
|
}
|
||||||
|
if (from_as_int == to_as_int) {
|
||||||
|
return to;
|
||||||
|
}
|
||||||
|
if (from_as_float == 0) {
|
||||||
|
if (to_as_float == 0) {
|
||||||
|
return to;
|
||||||
|
} else {
|
||||||
|
// Smallest subnormal signed like `to`.
|
||||||
|
uint16_t out_int = (to_as_int & sign_mask) | 1;
|
||||||
|
bfloat16 out;
|
||||||
|
memcpy(&out, &out_int, sizeof(bfloat16));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
uint16_t from_sign = from_as_int & sign_mask;
|
||||||
|
uint16_t to_sign = to_as_int & sign_mask;
|
||||||
|
uint16_t from_abs = from_as_int & ~sign_mask;
|
||||||
|
uint16_t to_abs = to_as_int & ~sign_mask;
|
||||||
|
uint16_t magnitude_adjustment =
|
||||||
|
(from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
|
||||||
|
uint16_t out_int = from_as_int + magnitude_adjustment;
|
||||||
|
bfloat16 out;
|
||||||
|
memcpy(&out, &out_int, sizeof(bfloat16));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(phawkins): implement spacing
|
||||||
|
|
||||||
} // namespace ufuncs
|
} // namespace ufuncs
|
||||||
|
|
||||||
@ -1467,7 +1504,9 @@ bool Initialize() {
|
|||||||
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
|
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
|
||||||
"ceil") &&
|
"ceil") &&
|
||||||
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
|
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
|
||||||
"trunc");
|
"trunc") &&
|
||||||
|
RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
|
||||||
|
numpy.get(), "nextafter");
|
||||||
|
|
||||||
return ok;
|
return ok;
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import itertools
|
||||||
import math
|
import math
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
@ -398,6 +399,26 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
|||||||
np.testing.assert_equal(exp1, exp2)
|
np.testing.assert_equal(exp1, exp2)
|
||||||
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
|
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
|
||||||
|
|
||||||
|
def testNextAfter(self):
|
||||||
|
one = np.array(1., dtype=bfloat16)
|
||||||
|
two = np.array(2., dtype=bfloat16)
|
||||||
|
zero = np.array(0., dtype=bfloat16)
|
||||||
|
nan = np.array(np.nan, dtype=bfloat16)
|
||||||
|
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
|
||||||
|
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
|
||||||
|
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
|
||||||
|
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
|
||||||
|
np.testing.assert_equal(np.nextafter(one, one), one)
|
||||||
|
smallest_denormal = float.fromhex("1.0p-133")
|
||||||
|
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
|
||||||
|
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
|
||||||
|
for a, b in itertools.permutations([0., -0., nan], 2):
|
||||||
|
np.testing.assert_equal(
|
||||||
|
np.nextafter(
|
||||||
|
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
|
||||||
|
np.nextafter(
|
||||||
|
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user