[XLA:Python] Fix bug where bfloat16 dtype extension's hash value was not initialized correctly.

We were (implicitly) initializing the hash value to 0, when -1 is the correct initialization for a PyArray_Descr's hash value.

Will fix https://github.com/google/jax/issues/4651 when incorporated into a jaxlib release.

PiperOrigin-RevId: 338088352
Change-Id: Ia9d72cbf88b8301b11c7a3bacbdba0848883d0ce
This commit is contained in:
Peter Hawkins 2020-10-20 10:38:12 -07:00 committed by TensorFlower Gardener
parent 7e5dc369b8
commit 51c47283e7
3 changed files with 30 additions and 12 deletions

View File

@ -396,25 +396,30 @@ PyTypeObject PyBfloat16_Type = {
PyArray_ArrFuncs NPyBfloat16_ArrFuncs;
PyArray_Descr NPyBfloat16_Descr = {
PyObject_HEAD_INIT(nullptr) & PyBfloat16_Type, // typeobj
PyObject_HEAD_INIT(nullptr) //
/*typeobj=*/
(&PyBfloat16_Type),
// We must register bfloat16 with a kind other than "f", because numpy
// considers two types with the same kind and size to be equal, but
// float16 != bfloat16.
// The downside of this is that NumPy scalar promotion does not work with
// bfloat16 values.
'V', // kind
/*kind=*/'V',
// TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
// character is unique.
'E', // type
'=', // byteorder
NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM, // hasobject
0, // type_num
sizeof(bfloat16), // elsize
alignof(bfloat16), // alignment
nullptr, // subarray
nullptr, // fields
nullptr, // names
&NPyBfloat16_ArrFuncs, // f
/*type=*/'E',
/*byteorder=*/'=',
/*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
/*type_num=*/0,
/*elsize=*/sizeof(bfloat16),
/*alignment=*/alignof(bfloat16),
/*subarray=*/nullptr,
/*fields=*/nullptr,
/*names=*/nullptr,
/*f=*/&NPyBfloat16_ArrFuncs,
/*metadata=*/nullptr,
/*c_metadata=*/nullptr,
/*hash=*/-1, // -1 means "not computed yet".
};
// Implementations of NumPy array methods.

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
import copy
import itertools
import math
@ -254,6 +255,15 @@ class Bfloat16NumPyTest(parameterized.TestCase):
def testDtype(self):
self.assertEqual(bfloat16, np.dtype(bfloat16))
def testDeepCopyDoesNotAlterHash(self):
# For context, see https://github.com/google/jax/issues/4651. If the hash
# value of the type descriptor is not initialized correctly, a deep copy
# can change the type hash.
dtype = np.dtype(bfloat16)
h = hash(dtype)
_ = copy.deepcopy(dtype)
self.assertEqual(h, hash(dtype))
def testArray(self):
x = np.array([[1, 2, 3]], dtype=bfloat16)
self.assertEqual(bfloat16, x.dtype)

View File

@ -387,6 +387,9 @@ PyArray_Descr NPyBfloat16_Descr = {
nullptr, // fields
nullptr, // names
&NPyBfloat16_ArrFuncs, // f
nullptr, // metadata
nullptr, // c_metadata
-1, // hash
};
// Registered numpy type ID. Global variable populated by the registration code.