diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index a67358497a6..774def3a72c 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -102,6 +102,7 @@ StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { case BF16: return DLDataType{kDLBfloat, 16, 1}; case PRED: + return DLDataType{kDLUInt, 8, 1}; case C64: case C128: default: diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index bca3ca5907f..197da05f94c 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -57,7 +57,7 @@ def TestFactory(xla_backend, cloud_tpu=False): float_dtypes = [np.float32] complex_dtypes = [np.complex64] standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] - dlpack_dtypes = int_dtypes + float_dtypes + dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] class ComputationTest(parameterized.TestCase): """Base class for running an XLA Computation through the local client.""" @@ -1949,14 +1949,18 @@ def TestFactory(xla_backend, cloud_tpu=False): for take_ownership in [False, True]) # pyformat: enable def testRoundTrip(self, dtype, shape, take_ownership): - x = np.array(np.random.rand(*shape) * 100, dtype=dtype) + if dtype == np.bool_: + x = np.random.randint(0, 2, size=shape).astype(np.bool_) + else: + x = np.array(np.random.rand(*shape) * 100, dtype=dtype) buffer = self.backend.buffer_from_pyval(x) dlt = xla_client._xla.buffer_to_dlpack_managed_tensor( buffer, take_ownership=take_ownership) del buffer # Free "buffer" to make sure dlt retains ownership. self.assertEqual(type(dlt).__name__, "PyCapsule") y = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend) - np.testing.assert_array_equal(x, y.to_py()) + np.testing.assert_array_equal( + x.astype(np.uint8) if x == np.bool_ else x, y.to_py()) def testTensorsCanBeConsumedOnceOnly(self): x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)