Register {bool,int8,uint8}->bfloat16 conversions as safe for the NumPy bfloat16 extension.
Fixes some breakage when using the bfloat16 extension under NumPy 1.20.0. PiperOrigin-RevId: 355396647 Change-Id: I2f8b8b5b4d679d01516861192778f5ccadaeb958
This commit is contained in:
parent
794b7af16b
commit
e7ff5a82de
@ -697,22 +697,17 @@ void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
|
// Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
|
||||||
// type corresponding to 'T'. If 'cast_is_safe', registers that bfloat16 can be
|
// type corresponding to 'T'.
|
||||||
// safely coerced to T.
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
|
bool RegisterBfloat16Cast(int numpy_type) {
|
||||||
if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16,
|
PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
|
||||||
NPyCast<T, bfloat16>) < 0) {
|
if (PyArray_RegisterCastFunc(descr, npy_bfloat16, NPyCast<T, bfloat16>) < 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
|
if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
|
||||||
NPyCast<bfloat16, T>) < 0) {
|
NPyCast<bfloat16, T>) < 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (cast_is_safe && PyArray_RegisterCanCast(&NPyBfloat16_Descr, numpy_type,
|
|
||||||
NPY_NOSCALAR) < 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1366,63 +1361,90 @@ bool Initialize() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register casts
|
// Register casts
|
||||||
if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<float>(NPY_FLOAT, /*cast_is_safe=*/true)) {
|
|
||||||
|
if (!RegisterBfloat16Cast<float>(NPY_FLOAT)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<double>(NPY_DOUBLE, /*cast_is_safe=*/true)) {
|
if (!RegisterBfloat16Cast<double>(NPY_DOUBLE)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<bool>(NPY_BOOL, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<bool>(NPY_BOOL)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<uint8>(NPY_UINT8, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<uint8>(NPY_UINT8)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<uint16>(NPY_UINT16)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG, // NOLINT
|
if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG)) { // NOLINT
|
||||||
/*cast_is_safe=*/false)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<unsigned long long>( // NOLINT
|
if (!RegisterBfloat16Cast<unsigned long long>(NPY_ULONGLONG)) { // NOLINT
|
||||||
NPY_ULONGLONG, /*cast_is_safe=*/false)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<uint64>(NPY_UINT64)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<int8>(NPY_INT8, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<int8>(NPY_INT8)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<int16>(NPY_INT16)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) {
|
if (!RegisterBfloat16Cast<int>(NPY_INT)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<long>(NPY_LONG, // NOLINT
|
if (!RegisterBfloat16Cast<long>(NPY_LONG)) { // NOLINT
|
||||||
/*cast_is_safe=*/false)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<long long>( // NOLINT
|
if (!RegisterBfloat16Cast<long long>(NPY_LONGLONG)) { // NOLINT
|
||||||
NPY_LONGLONG, /*cast_is_safe=*/false)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Following the numpy convention. imag part is dropped when converting to
|
// Following the numpy convention. imag part is dropped when converting to
|
||||||
// float.
|
// float.
|
||||||
if (!RegisterBfloat16Cast<std::complex<float>>(NPY_COMPLEX64,
|
if (!RegisterBfloat16Cast<std::complex<float>>(NPY_COMPLEX64)) {
|
||||||
/*cast_is_safe=*/true)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!RegisterBfloat16Cast<std::complex<double>>(NPY_COMPLEX128,
|
if (!RegisterBfloat16Cast<std::complex<double>>(NPY_COMPLEX128)) {
|
||||||
/*cast_is_safe=*/true)) {
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Safe casts from bfloat16 to other types
|
||||||
|
if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_FLOAT, NPY_NOSCALAR) <
|
||||||
|
0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_DOUBLE, NPY_NOSCALAR) <
|
||||||
|
0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX64, NPY_NOSCALAR) <
|
||||||
|
0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX128,
|
||||||
|
NPY_NOSCALAR) < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Safe casts to bfloat16 from other types
|
||||||
|
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL), npy_bfloat16,
|
||||||
|
NPY_NOSCALAR) < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UINT8), npy_bfloat16,
|
||||||
|
NPY_NOSCALAR) < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_INT8), npy_bfloat16,
|
||||||
|
NPY_NOSCALAR) < 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,6 +292,25 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
|||||||
b = np.array([82432], bfloat16)
|
b = np.array([82432], bfloat16)
|
||||||
self.assertFalse(a.__eq__(b))
|
self.assertFalse(a.__eq__(b))
|
||||||
|
|
||||||
|
def testCanCast(self):
|
||||||
|
allowed_casts = [
|
||||||
|
(np.bool_, bfloat16),
|
||||||
|
(np.int8, bfloat16),
|
||||||
|
(np.uint8, bfloat16),
|
||||||
|
(bfloat16, np.float32),
|
||||||
|
(bfloat16, np.float64),
|
||||||
|
(bfloat16, np.complex64),
|
||||||
|
(bfloat16, np.complex128),
|
||||||
|
]
|
||||||
|
all_dtypes = [
|
||||||
|
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
|
||||||
|
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
|
||||||
|
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
|
||||||
|
]
|
||||||
|
for d in all_dtypes:
|
||||||
|
self.assertEqual((bfloat16, d) in allowed_casts, np.can_cast(bfloat16, d))
|
||||||
|
self.assertEqual((d, bfloat16) in allowed_casts, np.can_cast(d, bfloat16))
|
||||||
|
|
||||||
def testCasts(self):
|
def testCasts(self):
|
||||||
for dtype in [
|
for dtype in [
|
||||||
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
|
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
|
||||||
|
Loading…
Reference in New Issue
Block a user