[XLA:Python] Fix bug where bfloat16 dtype extension's hash value was not initialized correctly.
We were (implicitly) initializing the hash value to 0, when -1 is the correct initialization for a PyArray_Descr's hash value. Will fix https://github.com/google/jax/issues/4651 when incorporated into a jaxlib release. PiperOrigin-RevId: 338088352 Change-Id: Ia9d72cbf88b8301b11c7a3bacbdba0848883d0ce
This commit is contained in:
parent
7e5dc369b8
commit
51c47283e7
@ -396,25 +396,30 @@ PyTypeObject PyBfloat16_Type = {
|
||||
PyArray_ArrFuncs NPyBfloat16_ArrFuncs;
|
||||
|
||||
PyArray_Descr NPyBfloat16_Descr = {
|
||||
PyObject_HEAD_INIT(nullptr) & PyBfloat16_Type, // typeobj
|
||||
PyObject_HEAD_INIT(nullptr) //
|
||||
/*typeobj=*/
|
||||
(&PyBfloat16_Type),
|
||||
// We must register bfloat16 with a kind other than "f", because numpy
|
||||
// considers two types with the same kind and size to be equal, but
|
||||
// float16 != bfloat16.
|
||||
// The downside of this is that NumPy scalar promotion does not work with
|
||||
// bfloat16 values.
|
||||
'V', // kind
|
||||
/*kind=*/'V',
|
||||
// TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
|
||||
// character is unique.
|
||||
'E', // type
|
||||
'=', // byteorder
|
||||
NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM, // hasobject
|
||||
0, // type_num
|
||||
sizeof(bfloat16), // elsize
|
||||
alignof(bfloat16), // alignment
|
||||
nullptr, // subarray
|
||||
nullptr, // fields
|
||||
nullptr, // names
|
||||
&NPyBfloat16_ArrFuncs, // f
|
||||
/*type=*/'E',
|
||||
/*byteorder=*/'=',
|
||||
/*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
|
||||
/*type_num=*/0,
|
||||
/*elsize=*/sizeof(bfloat16),
|
||||
/*alignment=*/alignof(bfloat16),
|
||||
/*subarray=*/nullptr,
|
||||
/*fields=*/nullptr,
|
||||
/*names=*/nullptr,
|
||||
/*f=*/&NPyBfloat16_ArrFuncs,
|
||||
/*metadata=*/nullptr,
|
||||
/*c_metadata=*/nullptr,
|
||||
/*hash=*/-1, // -1 means "not computed yet".
|
||||
};
|
||||
|
||||
// Implementations of NumPy array methods.
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import itertools
|
||||
import math
|
||||
|
||||
@ -254,6 +255,15 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
||||
def testDtype(self):
|
||||
self.assertEqual(bfloat16, np.dtype(bfloat16))
|
||||
|
||||
def testDeepCopyDoesNotAlterHash(self):
|
||||
# For context, see https://github.com/google/jax/issues/4651. If the hash
|
||||
# value of the type descriptor is not initialized correctly, a deep copy
|
||||
# can change the type hash.
|
||||
dtype = np.dtype(bfloat16)
|
||||
h = hash(dtype)
|
||||
_ = copy.deepcopy(dtype)
|
||||
self.assertEqual(h, hash(dtype))
|
||||
|
||||
def testArray(self):
|
||||
x = np.array([[1, 2, 3]], dtype=bfloat16)
|
||||
self.assertEqual(bfloat16, x.dtype)
|
||||
|
@ -387,6 +387,9 @@ PyArray_Descr NPyBfloat16_Descr = {
|
||||
nullptr, // fields
|
||||
nullptr, // names
|
||||
&NPyBfloat16_ArrFuncs, // f
|
||||
nullptr, // metadata
|
||||
nullptr, // c_metadata
|
||||
-1, // hash
|
||||
};
|
||||
|
||||
// Registered numpy type ID. Global variable populated by the registration code.
|
||||
|
Loading…
Reference in New Issue
Block a user