[XLA:Python] Fix assertion failure in debug mode.

PiperOrigin-RevId: 351819163
Change-Id: I4b32ee9d5abba389b5ee1ef3718caac6510cc0ec
This commit is contained in:
Peter Hawkins 2021-01-14 09:48:06 -08:00 committed by TensorFlower Gardener
parent 5df17e42e5
commit 5efecaf1a1

View File

@ -1984,7 +1984,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
self.assertEqual(type(dlt).__name__, "PyCapsule")
y = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)
np.testing.assert_array_equal(
x.astype(np.uint8) if x == np.bool_ else x, y.to_py())
x.astype(np.uint8) if dtype == np.bool_ else x, y.to_py())
def testTensorsCanBeConsumedOnceOnly(self):
x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)