[XLA:Python] Fix bugs in DLPack implementation.
* when exporting XLA arrays to DLPack, we were producing incorrect strides in bytes rather than number of elements. Happily we had another bug which meant that we always overwrote the strides with nullptr, which means "major-to-minor", which is the case for all JAX CPU and GPU arrays. * when importing DLPack tensors to XLA, we didn't handle sized-1 dimensions correctly when converting an explicit stride to a layout. This was masked by the previous bug when round-tripping tensors from XLA to itself via DLPack. * PyTorch changes the DLPack capsule name when it takes ownership. Handle this case in the capsule destructor. PiperOrigin-RevId: 292218978 Change-Id: I07f8985ffc51cd3090cecb72fc491f4fc936458a
This commit is contained in:
parent
314c578ddd
commit
c094952fca
tensorflow/compiler/xla/python
@ -159,7 +159,7 @@ std::vector<int64> StridesForShape(const Shape& shape) {
|
||||
CHECK(shape.has_layout());
|
||||
|
||||
strides.resize(shape.dimensions_size());
|
||||
int64 stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
|
||||
int64 stride = 1;
|
||||
for (int i : shape.layout().minor_to_major()) {
|
||||
strides.at(i) = stride;
|
||||
stride *= shape.dimensions(i);
|
||||
@ -172,8 +172,15 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
|
||||
CHECK_EQ(dims.size(), strides.size());
|
||||
std::vector<int64> minor_to_major(dims.size());
|
||||
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
|
||||
absl::c_sort(minor_to_major,
|
||||
[&](int a, int b) { return strides[a] < strides[b]; });
|
||||
absl::c_sort(minor_to_major, [&](int a, int b) {
|
||||
if (strides[a] < strides[b]) {
|
||||
return true;
|
||||
}
|
||||
if (strides[a] > strides[b]) {
|
||||
return false;
|
||||
}
|
||||
return dims[a] == 1 && dims[b] != 1;
|
||||
});
|
||||
int64 stride = 1;
|
||||
for (int64 d : minor_to_major) {
|
||||
if (strides[d] != stride) {
|
||||
@ -267,13 +274,19 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer) {
|
||||
pack->strides = StridesForShape(buffer->on_host_shape());
|
||||
dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
|
||||
dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
|
||||
dt.strides = nullptr;
|
||||
dt.byte_offset = 0;
|
||||
|
||||
py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
|
||||
[](PyObject* obj) {
|
||||
DLPackTensorDeleter(static_cast<DLManagedTensor*>(
|
||||
PyCapsule_GetPointer(obj, kDlTensorCapsuleName)));
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
|
||||
PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
|
||||
if (dlmt) {
|
||||
DLPackTensorDeleter(dlmt);
|
||||
} else {
|
||||
// The tensor has been deleted. Clear any error from
|
||||
// PyCapsule_GetPointer.
|
||||
PyErr_Clear();
|
||||
}
|
||||
});
|
||||
|
||||
TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
|
||||
@ -302,7 +315,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
|
||||
DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));
|
||||
|
||||
std::vector<int64> minor_to_major;
|
||||
if (dlmt->dl_tensor.strides) {
|
||||
if (dlmt->dl_tensor.strides && !absl::c_find(dimensions, 0)) {
|
||||
absl::Span<int64 const> strides(
|
||||
reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
|
||||
dlmt->dl_tensor.ndim);
|
||||
|
@ -2065,7 +2065,18 @@ class DLPackTest(parameterized.TestCase):
|
||||
dtype,
|
||||
"shape":
|
||||
shape
|
||||
} for dtype in dlpack_dtypes for shape in [(), (1,), (2, 3), (4, 1, 2)])
|
||||
} for dtype in dlpack_dtypes for shape in [
|
||||
(),
|
||||
(1,),
|
||||
(2, 3),
|
||||
(2, 0),
|
||||
(0, 7),
|
||||
(4, 1, 2),
|
||||
(2, 1, 3),
|
||||
(2, 4, 1),
|
||||
(3, 1),
|
||||
(1, 3),
|
||||
])
|
||||
def testRoundTrip(self, dtype, shape):
|
||||
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
|
||||
backend = xla_client.get_local_backend()
|
||||
|
Loading…
Reference in New Issue
Block a user