Register {bool,int8,uint8}->bfloat16 conversions as safe for the NumPy bfloat16 extension.

Fixes some breakage when using the bfloat16 extension under NumPy 1.20.0.

PiperOrigin-RevId: 355396647
Change-Id: I2f8b8b5b4d679d01516861192778f5ccadaeb958
This commit is contained in:
Peter Hawkins 2021-02-03 07:55:36 -08:00 committed by TensorFlower Gardener
parent 794b7af16b
commit e7ff5a82de
2 changed files with 73 additions and 32 deletions

View File

@ -697,22 +697,17 @@ void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
} }
// Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy // Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
// type corresponding to 'T'. If 'cast_is_safe', registers that bfloat16 can be // type corresponding to 'T'.
// safely coerced to T.
template <typename T> template <typename T>
bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) { bool RegisterBfloat16Cast(int numpy_type) {
if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16, PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
NPyCast<T, bfloat16>) < 0) { if (PyArray_RegisterCastFunc(descr, npy_bfloat16, NPyCast<T, bfloat16>) < 0) {
return false; return false;
} }
if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type, if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
NPyCast<bfloat16, T>) < 0) { NPyCast<bfloat16, T>) < 0) {
return false; return false;
} }
if (cast_is_safe && PyArray_RegisterCanCast(&NPyBfloat16_Descr, numpy_type,
NPY_NOSCALAR) < 0) {
return false;
}
return true; return true;
} }
@ -1366,63 +1361,90 @@ bool Initialize() {
} }
// Register casts // Register casts
if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<float>(NPY_FLOAT, /*cast_is_safe=*/true)) {
if (!RegisterBfloat16Cast<float>(NPY_FLOAT)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<double>(NPY_DOUBLE, /*cast_is_safe=*/true)) { if (!RegisterBfloat16Cast<double>(NPY_DOUBLE)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<bool>(NPY_BOOL, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<bool>(NPY_BOOL)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<uint8>(NPY_UINT8, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<uint8>(NPY_UINT8)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<uint16>(NPY_UINT16)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG, // NOLINT if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG)) { // NOLINT
/*cast_is_safe=*/false)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<unsigned long long>( // NOLINT if (!RegisterBfloat16Cast<unsigned long long>(NPY_ULONGLONG)) { // NOLINT
NPY_ULONGLONG, /*cast_is_safe=*/false)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<uint64>(NPY_UINT64)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<int8>(NPY_INT8, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<int8>(NPY_INT8)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<int16>(NPY_INT16)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) { if (!RegisterBfloat16Cast<int>(NPY_INT)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<long>(NPY_LONG, // NOLINT if (!RegisterBfloat16Cast<long>(NPY_LONG)) { // NOLINT
/*cast_is_safe=*/false)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<long long>( // NOLINT if (!RegisterBfloat16Cast<long long>(NPY_LONGLONG)) { // NOLINT
NPY_LONGLONG, /*cast_is_safe=*/false)) {
return false; return false;
} }
// Following the numpy convention. imag part is dropped when converting to // Following the numpy convention. imag part is dropped when converting to
// float. // float.
if (!RegisterBfloat16Cast<std::complex<float>>(NPY_COMPLEX64, if (!RegisterBfloat16Cast<std::complex<float>>(NPY_COMPLEX64)) {
/*cast_is_safe=*/true)) {
return false; return false;
} }
if (!RegisterBfloat16Cast<std::complex<double>>(NPY_COMPLEX128, if (!RegisterBfloat16Cast<std::complex<double>>(NPY_COMPLEX128)) {
/*cast_is_safe=*/true)) { return false;
}
// Safe casts from bfloat16 to other types
if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_FLOAT, NPY_NOSCALAR) <
0) {
return false;
}
if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_DOUBLE, NPY_NOSCALAR) <
0) {
return false;
}
if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX64, NPY_NOSCALAR) <
0) {
return false;
}
if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX128,
NPY_NOSCALAR) < 0) {
return false;
}
// Safe casts to bfloat16 from other types
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL), npy_bfloat16,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UINT8), npy_bfloat16,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_INT8), npy_bfloat16,
NPY_NOSCALAR) < 0) {
return false; return false;
} }

View File

@ -292,6 +292,25 @@ class Bfloat16NumPyTest(parameterized.TestCase):
b = np.array([82432], bfloat16) b = np.array([82432], bfloat16)
self.assertFalse(a.__eq__(b)) self.assertFalse(a.__eq__(b))
def testCanCast(self):
allowed_casts = [
(np.bool_, bfloat16),
(np.int8, bfloat16),
(np.uint8, bfloat16),
(bfloat16, np.float32),
(bfloat16, np.float64),
(bfloat16, np.complex64),
(bfloat16, np.complex128),
]
all_dtypes = [
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
]
for d in all_dtypes:
self.assertEqual((bfloat16, d) in allowed_casts, np.can_cast(bfloat16, d))
self.assertEqual((d, bfloat16) in allowed_casts, np.can_cast(d, bfloat16))
def testCasts(self): def testCasts(self):
for dtype in [ for dtype in [
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,