[XLA:Python] Allow exporting PRED arrays as uint8 using DLPack.

There's no corresponding import functionality, because DLPack doesn't have a representation for booleans.

Fixes https://github.com/google/jax/issues/4719

PiperOrigin-RevId: 351617946
Change-Id: Ib6244be6f72c272a02d44e2e30f44d76e16bd7a7
This commit is contained in:
Peter Hawkins 2021-01-13 10:42:04 -08:00 committed by TensorFlower Gardener
parent 066e3dfa19
commit 78fdd635ab
2 changed files with 8 additions and 3 deletions

View File

@ -102,6 +102,7 @@ StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
case BF16:
return DLDataType{kDLBfloat, 16, 1};
case PRED:
return DLDataType{kDLUInt, 8, 1};
case C64:
case C128:
default:

View File

@ -57,7 +57,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
float_dtypes = [np.float32]
complex_dtypes = [np.complex64]
standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
dlpack_dtypes = int_dtypes + float_dtypes
dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_]
class ComputationTest(parameterized.TestCase):
"""Base class for running an XLA Computation through the local client."""
@ -1949,14 +1949,18 @@ def TestFactory(xla_backend, cloud_tpu=False):
for take_ownership in [False, True])
# pyformat: enable
def testRoundTrip(self, dtype, shape, take_ownership):
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
if dtype == np.bool_:
x = np.random.randint(0, 2, size=shape).astype(np.bool_)
else:
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
buffer = self.backend.buffer_from_pyval(x)
dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
buffer, take_ownership=take_ownership)
del buffer # Free "buffer" to make sure dlt retains ownership.
self.assertEqual(type(dlt).__name__, "PyCapsule")
y = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)
np.testing.assert_array_equal(x, y.to_py())
np.testing.assert_array_equal(
x.astype(np.uint8) if x == np.bool_ else x, y.to_py())
def testTensorsCanBeConsumedOnceOnly(self):
x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)