[XLA:Python] Fix bugs in DLPack implementation.

* When exporting XLA arrays to DLPack, we produced incorrect strides: they were measured in bytes rather than in numbers of elements (see the sketch after this list). Happily, a second bug meant that we always overwrote the strides with nullptr, which DLPack interprets as major-to-minor order; since that is the layout of all JAX CPU and GPU arrays, the two bugs cancelled out.
* When importing DLPack tensors into XLA, we did not handle size-1 dimensions correctly when converting explicit strides to a layout. This was masked by the previous bug when round-tripping tensors from XLA back to XLA via DLPack.
* PyTorch renames the DLPack capsule when it takes ownership of the tensor. Handle this case in the capsule destructor.
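
For the first point: NumPy reports strides in bytes, while DLPack strides count elements. A minimal NumPy sketch of the distinction:

import numpy as np

x = np.zeros((4, 1, 2), dtype=np.float32)  # row-major, i.e. major-to-minor
print(x.strides)                                  # (8, 8, 4): bytes
print(tuple(s // x.itemsize for s in x.strides))  # (2, 2, 1): elements, as DLPack expects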

PiperOrigin-RevId: 292218978
Change-Id: I07f8985ffc51cd3090cecb72fc491f4fc936458a
Peter Hawkins, 2020-01-29 14:28:27 -08:00; committed by TensorFlower Gardener
parent 314c578ddd, commit c094952fca
2 changed files with 32 additions and 8 deletions
tensorflow/compiler/xla/python/dlpack.cc


@@ -159,7 +159,7 @@ std::vector<int64> StridesForShape(const Shape& shape) {
   CHECK(shape.has_layout());
   strides.resize(shape.dimensions_size());
-  int64 stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
+  int64 stride = 1;
   for (int i : shape.layout().minor_to_major()) {
     strides.at(i) = stride;
     stride *= shape.dimensions(i);
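
A Python rendering of the fixed loop, using a hypothetical helper name; minor_to_major lists dimension indices from fastest- to slowest-varying:

def strides_for_shape(dimensions, minor_to_major):
    # Walk outward from the most-minor dimension, accumulating the number
    # of elements covered so far; results are in elements, not bytes.
    strides = [0] * len(dimensions)
    stride = 1
    for i in minor_to_major:
        strides[i] = stride
        stride *= dimensions[i]
    return strides

# Major-to-minor (row-major) layout of a (4, 1, 2) array:
assert strides_for_shape([4, 1, 2], [2, 1, 0]) == [2, 2, 1]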
@@ -172,8 +172,15 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
   CHECK_EQ(dims.size(), strides.size());
   std::vector<int64> minor_to_major(dims.size());
   std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
-  absl::c_sort(minor_to_major,
-               [&](int a, int b) { return strides[a] < strides[b]; });
+  absl::c_sort(minor_to_major, [&](int a, int b) {
+    if (strides[a] < strides[b]) {
+      return true;
+    }
+    if (strides[a] > strides[b]) {
+      return false;
+    }
+    return dims[a] == 1 && dims[b] != 1;
+  });
   int64 stride = 1;
   for (int64 d : minor_to_major) {
     if (strides[d] != stride) {
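
The tie-break handles size-1 dimensions, whose stride equals a neighbor's, so stride order alone is ambiguous. A sketch of the same ordering in Python (strides_to_minor_to_major is a hypothetical name):

def strides_to_minor_to_major(dims, strides):
    # Sort dimensions by increasing stride; among equal strides, place
    # size-1 dimensions first, i.e. treat them as more minor. Without the
    # tie-break, an unstable sort may put a stride-tied size-1 dimension in
    # the more-major slot, and the stride-consistency check below the sort
    # would then reject a perfectly valid layout.
    return sorted(range(len(dims)), key=lambda i: (strides[i], dims[i] != 1))

# Row-major (2, 1, 3): element strides are (3, 3, 1), so dims 0 and 1 tie.
assert strides_to_minor_to_major([2, 1, 3], [3, 3, 1]) == [2, 1, 0]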
@@ -267,13 +274,19 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer) {
   pack->strides = StridesForShape(buffer->on_host_shape());
   dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
   dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
-  dt.strides = nullptr;
   dt.byte_offset = 0;
 
   py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                       [](PyObject* obj) {
-                        DLPackTensorDeleter(static_cast<DLManagedTensor*>(
-                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName)));
+                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
+                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
+                        if (dlmt) {
+                          DLPackTensorDeleter(dlmt);
+                        } else {
+                          // The tensor has been deleted. Clear any error from
+                          // PyCapsule_GetPointer.
+                          PyErr_Clear();
+                        }
                       });
 
   TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
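
The destructor change covers DLPack's handoff convention: a consumer renames the capsule from "dltensor" to "used_dltensor" once it takes ownership. A sketch with PyTorch as the consumer; the xla_client import path and binding names follow this era's API but should be treated as assumptions:

import numpy as np
import torch.utils.dlpack
from tensorflow.compiler.xla.python import xla_client

backend = xla_client.get_local_backend()
buf = xla_client.Buffer.from_pyval(np.arange(6, dtype=np.float32), backend=backend)
capsule = xla_client._xla.BufferToDLPackManagedTensor(buf)  # capsule named "dltensor"
t = torch.utils.dlpack.from_dlpack(capsule)
# PyTorch has renamed the capsule to "used_dltensor" and now owns the tensor.
# When `capsule` is collected, PyCapsule_GetPointer(obj, "dltensor") fails;
# the destructor must clear that error and skip deletion, not double-free.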
@@ -302,7 +315,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
   TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                       DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));
   std::vector<int64> minor_to_major;
-  if (dlmt->dl_tensor.strides) {
+  if (dlmt->dl_tensor.strides && !absl::c_linear_search(dimensions, 0)) {
     absl::Span<int64 const> strides(
         reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
         dlmt->dl_tensor.ndim);
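
The added guard skips stride interpretation when any dimension is 0: an empty tensor has no elements, so its strides carry no layout information. NumPy illustrates why nothing is lost by falling back to the default major-to-minor layout:

import numpy as np

empty = np.zeros((0, 7), dtype=np.float32)
print(empty.strides)  # (28, 4): nominally row-major, but there is no data
# Any dimension ordering describes the same empty buffer, so no particular
# layout can, or needs to, be recovered from these strides.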

tensorflow/compiler/xla/python/xla_client_test.py

@@ -2065,7 +2065,18 @@ class DLPackTest(parameterized.TestCase):
           dtype,
       "shape":
           shape
-  } for dtype in dlpack_dtypes for shape in [(), (1,), (2, 3), (4, 1, 2)])
+  } for dtype in dlpack_dtypes for shape in [
+      (),
+      (1,),
+      (2, 3),
+      (2, 0),
+      (0, 7),
+      (4, 1, 2),
+      (2, 1, 3),
+      (2, 4, 1),
+      (3, 1),
+      (1, 3),
+  ])
   def testRoundTrip(self, dtype, shape):
     x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
     backend = xla_client.get_local_backend()