[XLA:Python] Allow exporting PRED arrays as uint8 using DLPack.
There's no corresponding import functionality, because DLPack doesn't have a representation for booleans. Fixes https://github.com/google/jax/issues/4719 PiperOrigin-RevId: 351617946 Change-Id: Ib6244be6f72c272a02d44e2e30f44d76e16bd7a7
This commit is contained in:
parent
066e3dfa19
commit
78fdd635ab
@ -102,6 +102,7 @@ StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
|
||||
case BF16:
|
||||
return DLDataType{kDLBfloat, 16, 1};
|
||||
case PRED:
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
case C64:
|
||||
case C128:
|
||||
default:
|
||||
|
@ -57,7 +57,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
float_dtypes = [np.float32]
|
||||
complex_dtypes = [np.complex64]
|
||||
standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
|
||||
dlpack_dtypes = int_dtypes + float_dtypes
|
||||
dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_]
|
||||
|
||||
class ComputationTest(parameterized.TestCase):
|
||||
"""Base class for running an XLA Computation through the local client."""
|
||||
@ -1949,14 +1949,18 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
for take_ownership in [False, True])
|
||||
# pyformat: enable
|
||||
def testRoundTrip(self, dtype, shape, take_ownership):
|
||||
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
|
||||
if dtype == np.bool_:
|
||||
x = np.random.randint(0, 2, size=shape).astype(np.bool_)
|
||||
else:
|
||||
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
|
||||
buffer = self.backend.buffer_from_pyval(x)
|
||||
dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
|
||||
buffer, take_ownership=take_ownership)
|
||||
del buffer # Free "buffer" to make sure dlt retains ownership.
|
||||
self.assertEqual(type(dlt).__name__, "PyCapsule")
|
||||
y = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)
|
||||
np.testing.assert_array_equal(x, y.to_py())
|
||||
np.testing.assert_array_equal(
|
||||
x.astype(np.uint8) if x == np.bool_ else x, y.to_py())
|
||||
|
||||
def testTensorsCanBeConsumedOnceOnly(self):
|
||||
x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user