[XLA:Python] Fix Numpy deprecation warning for use of np.object.
Will fix https://github.com/google/jax/issues/4424 when included in a jaxlib. PiperOrigin-RevId: 334611917 Change-Id: I9dc76f812d4744567b62d4b5cf385ded375c4986
This commit is contained in:
parent
fa75523767
commit
f61e1203bd
@ -193,8 +193,8 @@ XLA_ELEMENT_TYPE_TO_DTYPE = {
|
|||||||
PrimitiveType.F64: np.dtype('float64'),
|
PrimitiveType.F64: np.dtype('float64'),
|
||||||
PrimitiveType.C64: np.dtype('complex64'),
|
PrimitiveType.C64: np.dtype('complex64'),
|
||||||
PrimitiveType.C128: np.dtype('complex128'),
|
PrimitiveType.C128: np.dtype('complex128'),
|
||||||
PrimitiveType.TUPLE: np.dtype(np.object),
|
PrimitiveType.TUPLE: np.dtype(np.object_),
|
||||||
PrimitiveType.TOKEN: np.dtype(np.object),
|
PrimitiveType.TOKEN: np.dtype(np.object_),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
|
# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
|
||||||
|
Loading…
Reference in New Issue
Block a user