From e7ff5a82de7e2f8cc173ac1ae67a01b65c11bf6f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 3 Feb 2021 07:55:36 -0800 Subject: [PATCH] Register {bool,int8,uint8}->bfloat16 conversions as safe for the NumPy bfloat16 extension. Fixes some breakage when using the bfloat16 extension under NumPy 1.20.0. PiperOrigin-RevId: 355396647 Change-Id: I2f8b8b5b4d679d01516861192778f5ccadaeb958 --- tensorflow/python/lib/core/bfloat16.cc | 86 +++++++++++++-------- tensorflow/python/lib/core/bfloat16_test.py | 19 +++++ 2 files changed, 73 insertions(+), 32 deletions(-) diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 8d3518619ae..c86f2844926 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -697,22 +697,17 @@ void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr, } // Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy -// type corresponding to 'T'. If 'cast_is_safe', registers that bfloat16 can be -// safely coerced to T. +// type corresponding to 'T'. template -bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) { - if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16, - NPyCast) < 0) { +bool RegisterBfloat16Cast(int numpy_type) { + PyArray_Descr* descr = PyArray_DescrFromType(numpy_type); + if (PyArray_RegisterCastFunc(descr, npy_bfloat16, NPyCast) < 0) { return false; } if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type, NPyCast) < 0) { return false; } - if (cast_is_safe && PyArray_RegisterCanCast(&NPyBfloat16_Descr, numpy_type, - NPY_NOSCALAR) < 0) { - return false; - } return true; } @@ -1366,63 +1361,90 @@ bool Initialize() { } // Register casts - if (!RegisterBfloat16Cast(NPY_HALF, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_HALF)) { return false; } - if (!RegisterBfloat16Cast(NPY_FLOAT, /*cast_is_safe=*/true)) { + + if (!RegisterBfloat16Cast(NPY_FLOAT)) { return false; } - if (!RegisterBfloat16Cast(NPY_DOUBLE, /*cast_is_safe=*/true)) { + if (!RegisterBfloat16Cast(NPY_DOUBLE)) { return false; } - if (!RegisterBfloat16Cast(NPY_BOOL, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_BOOL)) { return false; } - if (!RegisterBfloat16Cast(NPY_UINT8, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_UINT8)) { return false; } - if (!RegisterBfloat16Cast(NPY_UINT16, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_UINT16)) { return false; } - if (!RegisterBfloat16Cast(NPY_UINT, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_UINT)) { return false; } - if (!RegisterBfloat16Cast(NPY_ULONG, // NOLINT - /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_ULONG)) { // NOLINT return false; } - if (!RegisterBfloat16Cast( // NOLINT - NPY_ULONGLONG, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_ULONGLONG)) { // NOLINT return false; } - if (!RegisterBfloat16Cast(NPY_UINT64, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_UINT64)) { return false; } - if (!RegisterBfloat16Cast(NPY_INT8, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_INT8)) { return false; } - if (!RegisterBfloat16Cast(NPY_INT16, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_INT16)) { return false; } - if (!RegisterBfloat16Cast(NPY_INT, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_INT)) { return false; } - if (!RegisterBfloat16Cast(NPY_LONG, // NOLINT - /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_LONG)) { // NOLINT return false; } - if (!RegisterBfloat16Cast( // NOLINT - NPY_LONGLONG, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast(NPY_LONGLONG)) { // NOLINT return false; } // Following the numpy convention. imag part is dropped when converting to // float. - if (!RegisterBfloat16Cast>(NPY_COMPLEX64, - /*cast_is_safe=*/true)) { + if (!RegisterBfloat16Cast>(NPY_COMPLEX64)) { return false; } - if (!RegisterBfloat16Cast>(NPY_COMPLEX128, - /*cast_is_safe=*/true)) { + if (!RegisterBfloat16Cast>(NPY_COMPLEX128)) { + return false; + } + + // Safe casts from bfloat16 to other types + if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_FLOAT, NPY_NOSCALAR) < + 0) { + return false; + } + if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_DOUBLE, NPY_NOSCALAR) < + 0) { + return false; + } + if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX64, NPY_NOSCALAR) < + 0) { + return false; + } + if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX128, + NPY_NOSCALAR) < 0) { + return false; + } + + // Safe casts to bfloat16 from other types + if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL), npy_bfloat16, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UINT8), npy_bfloat16, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_INT8), npy_bfloat16, + NPY_NOSCALAR) < 0) { return false; } diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index 0bd5f0cfadf..eb3dedaf04e 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -292,6 +292,25 @@ class Bfloat16NumPyTest(parameterized.TestCase): b = np.array([82432], bfloat16) self.assertFalse(a.__eq__(b)) + def testCanCast(self): + allowed_casts = [ + (np.bool_, bfloat16), + (np.int8, bfloat16), + (np.uint8, bfloat16), + (bfloat16, np.float32), + (bfloat16, np.float64), + (bfloat16, np.complex64), + (bfloat16, np.complex128), + ] + all_dtypes = [ + np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, + np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32, + np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong + ] + for d in all_dtypes: + self.assertEqual((bfloat16, d) in allowed_casts, np.can_cast(bfloat16, d)) + self.assertEqual((d, bfloat16) in allowed_casts, np.can_cast(d, bfloat16)) + def testCasts(self): for dtype in [ np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,