[XLA:Python] Fix bug where bfloat16 dtype extension's hash value was not initialized correctly.
We were (implicitly) initializing the hash value to 0, when -1 is the correct initialization for a PyArray_Descr's hash value. Will fix https://github.com/google/jax/issues/4651 when incorporated into a jaxlib release. PiperOrigin-RevId: 338088352 Change-Id: Ia9d72cbf88b8301b11c7a3bacbdba0848883d0ce
This commit is contained in:
parent
7e5dc369b8
commit
51c47283e7
@ -396,25 +396,30 @@ PyTypeObject PyBfloat16_Type = {
|
|||||||
PyArray_ArrFuncs NPyBfloat16_ArrFuncs;
|
PyArray_ArrFuncs NPyBfloat16_ArrFuncs;
|
||||||
|
|
||||||
PyArray_Descr NPyBfloat16_Descr = {
|
PyArray_Descr NPyBfloat16_Descr = {
|
||||||
PyObject_HEAD_INIT(nullptr) & PyBfloat16_Type, // typeobj
|
PyObject_HEAD_INIT(nullptr) //
|
||||||
|
/*typeobj=*/
|
||||||
|
(&PyBfloat16_Type),
|
||||||
// We must register bfloat16 with a kind other than "f", because numpy
|
// We must register bfloat16 with a kind other than "f", because numpy
|
||||||
// considers two types with the same kind and size to be equal, but
|
// considers two types with the same kind and size to be equal, but
|
||||||
// float16 != bfloat16.
|
// float16 != bfloat16.
|
||||||
// The downside of this is that NumPy scalar promotion does not work with
|
// The downside of this is that NumPy scalar promotion does not work with
|
||||||
// bfloat16 values.
|
// bfloat16 values.
|
||||||
'V', // kind
|
/*kind=*/'V',
|
||||||
// TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
|
// TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
|
||||||
// character is unique.
|
// character is unique.
|
||||||
'E', // type
|
/*type=*/'E',
|
||||||
'=', // byteorder
|
/*byteorder=*/'=',
|
||||||
NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM, // hasobject
|
/*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
|
||||||
0, // type_num
|
/*type_num=*/0,
|
||||||
sizeof(bfloat16), // elsize
|
/*elsize=*/sizeof(bfloat16),
|
||||||
alignof(bfloat16), // alignment
|
/*alignment=*/alignof(bfloat16),
|
||||||
nullptr, // subarray
|
/*subarray=*/nullptr,
|
||||||
nullptr, // fields
|
/*fields=*/nullptr,
|
||||||
nullptr, // names
|
/*names=*/nullptr,
|
||||||
&NPyBfloat16_ArrFuncs, // f
|
/*f=*/&NPyBfloat16_ArrFuncs,
|
||||||
|
/*metadata=*/nullptr,
|
||||||
|
/*c_metadata=*/nullptr,
|
||||||
|
/*hash=*/-1, // -1 means "not computed yet".
|
||||||
};
|
};
|
||||||
|
|
||||||
// Implementations of NumPy array methods.
|
// Implementations of NumPy array methods.
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
|
|
||||||
@ -254,6 +255,15 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
|||||||
def testDtype(self):
|
def testDtype(self):
|
||||||
self.assertEqual(bfloat16, np.dtype(bfloat16))
|
self.assertEqual(bfloat16, np.dtype(bfloat16))
|
||||||
|
|
||||||
|
def testDeepCopyDoesNotAlterHash(self):
|
||||||
|
# For context, see https://github.com/google/jax/issues/4651. If the hash
|
||||||
|
# value of the type descriptor is not initialized correctly, a deep copy
|
||||||
|
# can change the type hash.
|
||||||
|
dtype = np.dtype(bfloat16)
|
||||||
|
h = hash(dtype)
|
||||||
|
_ = copy.deepcopy(dtype)
|
||||||
|
self.assertEqual(h, hash(dtype))
|
||||||
|
|
||||||
def testArray(self):
|
def testArray(self):
|
||||||
x = np.array([[1, 2, 3]], dtype=bfloat16)
|
x = np.array([[1, 2, 3]], dtype=bfloat16)
|
||||||
self.assertEqual(bfloat16, x.dtype)
|
self.assertEqual(bfloat16, x.dtype)
|
||||||
|
@ -387,6 +387,9 @@ PyArray_Descr NPyBfloat16_Descr = {
|
|||||||
nullptr, // fields
|
nullptr, // fields
|
||||||
nullptr, // names
|
nullptr, // names
|
||||||
&NPyBfloat16_ArrFuncs, // f
|
&NPyBfloat16_ArrFuncs, // f
|
||||||
|
nullptr, // metadata
|
||||||
|
nullptr, // c_metadata
|
||||||
|
-1, // hash
|
||||||
};
|
};
|
||||||
|
|
||||||
// Registered numpy type ID. Global variable populated by the registration code.
|
// Registered numpy type ID. Global variable populated by the registration code.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user