[XLA:Python] Fix assertion failure in debug mode.
PiperOrigin-RevId: 351819163 Change-Id: I4b32ee9d5abba389b5ee1ef3718caac6510cc0ec
This commit is contained in:
parent
5df17e42e5
commit
5efecaf1a1
@ -1984,7 +1984,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
self.assertEqual(type(dlt).__name__, "PyCapsule")
|
||||
y = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)
|
||||
np.testing.assert_array_equal(
|
||||
x.astype(np.uint8) if x == np.bool_ else x, y.to_py())
|
||||
x.astype(np.uint8) if dtype == np.bool_ else x, y.to_py())
|
||||
|
||||
def testTensorsCanBeConsumedOnceOnly(self):
|
||||
x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user