[XLA:Python] Add DLPack import/export support to the XLA Python client.
This allows JAX to communicate on-device arrays with other libraries, such as PyTorch and CuPy.

PiperOrigin-RevId: 290845329
Change-Id: Idd99d81533159bc2ad0c5177b69ac7f30315cb1a
This commit is contained in:
parent 470239ee94
commit fc1f6fdf94
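As context for the change: a minimal usage sketch of the new bindings, based on the DLPackTest added below (assuming a build of xla_client that includes this change; the calls mirror the test code verbatim):

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    backend = xla_client.get_local_backend()
    x = np.arange(12, dtype=np.float32).reshape(3, 4)
    buf = xla_client.Buffer.from_pyval(x, backend=backend)

    # Export: wrap the on-device buffer in a "dltensor" PyCapsule (no copy).
    capsule = xla_client._xla.BufferToDLPackManagedTensor(buf)

    # Import: consume the capsule (it can be used at most once) and get back
    # a buffer aliasing the same device memory.
    buf2 = xla_client._xla.DLPackManagedTensorToBuffer(capsule, backend.client)
    np.testing.assert_array_equal(x, buf2.to_py())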
@@ -34,6 +34,7 @@ py_test(
        ":xla_client",
        ":xla_extension",
        "@absl_py//absl/testing:absltest",
        "@absl_py//absl/testing:parameterized",
    ] + xla_py_test_deps(),
)

@@ -248,6 +249,34 @@ py_test(
    ] + xla_py_test_deps(),
)

cc_library(
    name = "dlpack",
    srcs = ["dlpack.cc"],
    hdrs = ["dlpack.h"],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":local_client",
        ":shared_device_buffer",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/stream_executor:device_memory",
        "//tensorflow/stream_executor:platform",
        "//tensorflow/stream_executor/cuda:cuda_platform_id",
        "//tensorflow/stream_executor/host:host_platform_id",
        "//third_party/python_runtime:headers",  # buildcleaner: keep
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@dlpack",
        "@pybind11",
    ],
)

config_setting(
    name = "enable_gpu",
    values = {"define": "xla_python_enable_gpu=true"},

@@ -266,6 +295,7 @@ pybind_extension(
    module_name = "xla_extension",
    deps = [
        ":bfloat16",
        ":dlpack",
        ":local_client",
        ":shared_device_buffer",
        ":python_ref_manager",
tensorflow/compiler/xla/python/dlpack.cc (new file, 339 lines)
@@ -0,0 +1,339 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // TF:dlpack
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/platform.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

struct DLPackTensor {
  std::shared_ptr<SharedDeviceBuffer> buffer;
  std::vector<int64> shape;
  std::vector<int64> strides;
  DLManagedTensor tensor;
};

void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}

StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case PRED:
      return DLDataType{kDLInt, 1, 1};
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case C64:
    case C128:
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 1:
          return PRED;
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 1:
          return PRED;
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the strides for `shape`.
std::vector<int64> StridesForShape(const Shape& shape) {
  std::vector<int64> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64 stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
                                             absl::Span<int64 const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major,
               [&](int a, int b) { return strides[a] < strides[b]; });
  int64 stride = 1;
  for (int64 d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}

StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
  const se::Platform* platform =
      device.local_device_state()->executor()->platform();
  if (platform->id() == se::host::kHostPlatformId) {
    return kDLCPU;
  } else if (platform->id() == se::cuda::kCudaPlatformId) {
    return kDLGPU;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

StatusOr<DLContext> DLContextForDevice(const Device& device) {
  DLContext context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_device_state()->device_ordinal();
  return context;
}

StatusOr<std::shared_ptr<Device>> DeviceForDLContext(
    const PyLocalClient& client, const DLContext& context) {
  se::Platform::Id platform_id;
  switch (context.device_type) {
    case kDLCPU:
      platform_id = se::host::kHostPlatformId;
      break;
    case kDLGPU:
      platform_id = se::cuda::kCudaPlatformId;
      break;
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
  auto it = absl::c_find_if(
      client.local_devices(), [&](const std::shared_ptr<Device>& device) {
        return device->local_device_state()->executor()->platform()->id() ==
                   platform_id &&
               device->local_device_state()->device_ordinal() ==
                   context.device_id;
      });
  if (it == client.local_devices().end()) {
    return InvalidArgument(
        "No matching device found for DLPack device_type %d device_id %d",
        context.device_type, context.device_id);
  }
  return *it;
}

}  // namespace

StatusOr<py::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer) {
  auto pack = absl::make_unique<DLPackTensor>();
  pack->buffer = buffer->DeviceBuffer();
  if (!pack->buffer) {
    return InvalidArgument(
        "Cannot convert deleted/invalid buffer to DLPack tensor.");
  }
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  DLTensor& dt = pack->tensor.dl_tensor;
  if (buffer->on_device_shape().IsTuple()) {
    return Unimplemented(
        "unsafe_buffer_pointer is not implemented for tuple "
        "buffers.");
  }
  TF_RET_CHECK(pack->buffer->device_memory().size() == 1);
  dt.data = pack->buffer->device_memory().front().opaque();
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->device()));
  dt.ctx.device_id = buffer->device()->local_device_state()->device_ordinal();
  dt.ndim = buffer->on_host_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType(
                                    buffer->on_host_shape().element_type()));

  pack->shape = std::vector<int64>(buffer->on_host_shape().dimensions().begin(),
                                   buffer->on_host_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->on_host_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.strides = nullptr;
  dt.byte_offset = 0;

  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLPackTensorDeleter(static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName)));
                      });

  TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
  return capsule;
}

StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyLocalClient> client) {
  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
                      DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
  absl::Span<int64 const> dimensions(
      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

  std::vector<int64> minor_to_major;
  if (dlmt->dl_tensor.strides) {
    absl::Span<int64 const> strides(
        reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);
  se::DeviceMemoryBase buffer(
      static_cast<char*>(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset,
      ShapeUtil::ByteSizeOf(shape));

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  auto device_buffer = std::make_shared<SharedDeviceBuffer>(
      /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id,
      std::initializer_list<se::DeviceMemoryBase>{buffer},
      /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
      /*definition_event=*/nullptr, std::move(on_delete_callback));

  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  return absl::make_unique<PyLocalBuffer>(shape, shape,
                                          std::move(device_buffer),
                                          std::move(client), std::move(device));
}

}  // namespace xla
tensorflow/compiler/xla/python/dlpack.h (new file, 31 lines)
@@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_

#include "include/pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/local_client.h"

namespace xla {

StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer);

StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyLocalClient> client);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
@@ -141,8 +141,10 @@ class PyLocalClient {

  int device_count() const { return devices_.size(); }
  int local_device_count() const { return local_devices_.size(); }
  const std::vector<std::shared_ptr<Device>>& devices() { return devices_; }
  const std::vector<std::shared_ptr<Device>>& local_devices() {
  const std::vector<std::shared_ptr<Device>>& devices() const {
    return devices_;
  }
  const std::vector<std::shared_ptr<Device>>& local_devices() const {
    return local_devices_;
  }
  const std::map<int, std::shared_ptr<Device>>& id_to_device() const {
@@ -44,6 +44,7 @@ class LocalDeviceState {
                   bool asynchronous, bool allow_event_reuse);
  virtual ~LocalDeviceState();

  se::StreamExecutor* executor() const { return executor_; }
  // StreamExecutor (local) device ordinal.
  int device_ordinal() const { return executor_->device_ordinal(); }

@@ -122,7 +122,8 @@ SharedDeviceBuffer::MakeTuple(
  return std::make_shared<SharedDeviceBuffer>(
      allocator, device_ordinal,
      std::initializer_list<se::DeviceMemoryBase>{device_memory.Release()},
      std::move(children), std::move(definition_event));
      std::move(children), std::move(definition_event),
      /*on_delete_callback=*/nullptr);
}

/* static */ StatusOr<std::shared_ptr<SharedDeviceBuffer>>
@@ -179,12 +180,14 @@ SharedDeviceBuffer::SharedDeviceBuffer(
    se::DeviceMemoryAllocator* allocator, int device_ordinal,
    absl::Span<se::DeviceMemoryBase const> device_memory,
    std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
    std::shared_ptr<BufferDefinitionEvent> definition_event)
    std::shared_ptr<BufferDefinitionEvent> definition_event,
    std::function<void()> on_delete_callback)
    : allocator_(allocator),
      device_ordinal_(device_ordinal),
      device_memory_(device_memory.begin(), device_memory.end()),
      children_(std::move(children)),
      definition_event_(std::move(definition_event)) {}
      definition_event_(std::move(definition_event)),
      on_delete_callback_(std::move(on_delete_callback)) {}

SharedDeviceBuffer::SharedDeviceBuffer(
    absl::Span<se::OwningDeviceMemory> device_memory,

@@ -211,6 +214,9 @@ SharedDeviceBuffer::~SharedDeviceBuffer() {
      }
    }
  }
  if (on_delete_callback_) {
    on_delete_callback_();
  }
}

void GetDeviceBufferDefinitionEvents(

@@ -120,6 +120,9 @@ class SharedDeviceBuffer {
  }
  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
  int device_ordinal() const { return device_ordinal_; }
  absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() {
    return device_memory_;
  }
  const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const {
    return device_memory_;
  }

@@ -131,7 +134,8 @@ class SharedDeviceBuffer {
  SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal,
                     absl::Span<se::DeviceMemoryBase const> device_memory,
                     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
                     std::shared_ptr<BufferDefinitionEvent> definition_event);
                     std::shared_ptr<BufferDefinitionEvent> definition_event,
                     std::function<void()> on_delete_callback);
  SharedDeviceBuffer(absl::Span<se::OwningDeviceMemory> device_memory,
                     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
                     std::shared_ptr<BufferDefinitionEvent> definition_event);

@@ -152,6 +156,9 @@ class SharedDeviceBuffer {
  // single-stream execution case where events are not necessary for buffer
  // event sequencing.
  std::shared_ptr<BufferDefinitionEvent> definition_event_;

  // A callback to call when the SharedDeviceBuffer is about to be destroyed.
  std::function<void()> on_delete_callback_;
};

// Populates 'events' with the set of buffer definition events for all buffers

@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"

@@ -652,6 +653,9 @@ PYBIND11_MODULE(xla_extension, m) {
      .def("SetSharding", &XlaBuilder::SetSharding)
      .def("ClearSharding", &XlaBuilder::ClearSharding);

  m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor);
  m.def("DLPackManagedTensorToBuffer", DLPackManagedTensorToBuffer);

  // ops submodule, containing free functions that add operators to an
  // XlaBuilder.
  py::module ops = m.def_submodule("ops", "XLA operations");

@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

@@ -23,12 +24,12 @@ import itertools
import threading

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.xla.python import custom_call_for_test
from tensorflow.compiler.xla.python import xla_client


bfloat16 = xla_client.bfloat16


@@ -1420,24 +1421,24 @@ class SingleOpTest(ComputationTest):
    # FFT
    c = self._NewComputation()
    c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:])
    self._ExecuteAndCompareClose(c, expected=np.fft.fftn(a, axes=(1, 2, 3)),
                                 rtol=1e-4)
    self._ExecuteAndCompareClose(
        c, expected=np.fft.fftn(a, axes=(1, 2, 3)), rtol=1e-4)
    # IFFT
    c = self._NewComputation()
    c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:])
    self._ExecuteAndCompareClose(c, expected=np.fft.ifftn(a, axes=(1, 2, 3)),
                                 rtol=1e-4)
    self._ExecuteAndCompareClose(
        c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), rtol=1e-4)
    # RFFT
    b = rng.randn(*shape).astype(np.float32)
    c = self._NewComputation()
    c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:])
    self._ExecuteAndCompareClose(c, expected=np.fft.rfftn(b, axes=(1, 2, 3)),
                                 rtol=1e-4)
    self._ExecuteAndCompareClose(
        c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), rtol=1e-4)
    # IRFFT
    c = self._NewComputation()
    c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8])
    self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)),
                                 rtol=1e-4)
    self._ExecuteAndCompareClose(
        c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), rtol=1e-4)

  def testNextAfter(self):
    c = self._NewComputation()

@@ -1454,8 +1455,8 @@ class SingleOpTest(ComputationTest):
    b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677])
    c = self._NewComputation()
    c.RegularizedIncompleteBeta(c.Constant(a), c.Constant(b), c.Constant(x))
    expected = np.array([0.98923271, 0.48575411, 0.57952568, 0.12579775,
                         0.96989155])
    expected = np.array(
        [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155])
    self._ExecuteAndCompareClose(c, expected=expected, rtol=1e-4)


@@ -1974,7 +1975,7 @@ class ErrorTest(ComputationTest):
    def TestFun():
      return c.Build().Compile(compile_options=options)

    self.assertRaisesRegexp(
    self.assertRaisesRegex(
        RuntimeError, r".*Invalid argument shape.*"
        r"expected s32\[\], got f32\[\].*", TestFun)

@@ -1988,7 +1989,7 @@ class ErrorTest(ComputationTest):
      return xla_client.execute_with_python_values(c.Build().Compile(),
                                                   [self.f32_scalar_2])

    self.assertRaisesRegexp(
    self.assertRaisesRegex(
        RuntimeError, r"Invalid argument: Argument does not match.*"
        r"want s32\[\], got f32\[\].*", TestFun)

@@ -2031,5 +2032,47 @@ class SetShardingTest(ComputationTest):
    np.testing.assert_allclose(ans, 4.14)


dlpack_dtypes = [
    np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
    np.uint64, np.float16, np.float32, np.float64, bfloat16
]


class DLPackTest(parameterized.TestCase):

  # pylint: disable=g-complex-comprehension
  @parameterized.named_parameters({
      "testcase_name":
          "_{}[{}]".format(dtype.__name__, ",".join(map(str, shape))),
      "dtype":
          dtype,
      "shape":
          shape
  } for dtype in dlpack_dtypes for shape in [(), (1,), (2, 3), (4, 1, 2)])
  def testRoundTrip(self, dtype, shape):
    x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
    backend = xla_client.get_local_backend()
    buffer = xla_client.Buffer.from_pyval(x, backend=backend)
    dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
    del buffer  # Free "buffer" to make sure dlt retains ownership.
    self.assertEqual(type(dlt).__name__, "PyCapsule")
    y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client)
    np.testing.assert_array_equal(x, y.to_py())

  def testTensorsCanBeConsumedOnceOnly(self):
    x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
    backend = xla_client.get_local_backend()
    buffer = xla_client.Buffer.from_pyval(x, backend=backend)
    dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)

    def ConsumeDLPackTensor():
      _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client)

    ConsumeDLPackTensor()
    self.assertRaisesRegex(RuntimeError,
                           ".*a DLPack tensor may be consumed at most once.*",
                           ConsumeDLPackTensor)


if __name__ == "__main__":
  absltest.main()
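Since the commit message mentions interoperability with PyTorch and CuPy, here is a hedged sketch (not part of this change) of handing the exported capsule to another framework; it assumes PyTorch is installed and that both libraries see a compatible device (CPU, or the same GPU), and reuses the backend/xla_client setup from the sketch above:

    import torch.utils.dlpack

    capsule = xla_client._xla.BufferToDLPackManagedTensor(
        xla_client.Buffer.from_pyval(np.ones((2, 3), np.float32), backend=backend))
    # torch.utils.dlpack.from_dlpack consumes a "dltensor" capsule and returns
    # a torch.Tensor sharing the same memory; the capsule cannot be reused.
    t = torch.utils.dlpack.from_dlpack(capsule)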
@@ -26,6 +26,7 @@ load("//third_party/FXdiv:workspace.bzl", FXdiv = "repo")
load("//third_party/aws:workspace.bzl", aws = "repo")
load("//third_party/clog:workspace.bzl", clog = "repo")
load("//third_party/cpuinfo:workspace.bzl", cpuinfo = "repo")
load("//third_party/dlpack:workspace.bzl", dlpack = "repo")
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
load("//third_party/hexagon:workspace.bzl", hexagon_nn = "repo")
load("//third_party/highwayhash:workspace.bzl", highwayhash = "repo")

@@ -48,6 +49,7 @@ def initialize_third_party():
    aws()
    clog()
    cpuinfo()
    dlpack()
    flatbuffers()
    hexagon_nn()
    highwayhash()

third_party/dlpack/BUILD.bazel (vendored, new file, 14 lines)
@@ -0,0 +1,14 @@
# Description:
#   DLPack is a protocol for sharing arrays between deep learning frameworks.

licenses(["notice"])  # Apache 2

exports_files(["LICENSE"])

cc_library(
    name = "dlpack",
    hdrs = [
        "include/dlpack/dlpack.h",
    ],
    visibility = ["//visibility:public"],
)
third_party/dlpack/workspace.bzl (vendored, new file, 15 lines)
@@ -0,0 +1,15 @@
"""DLPack is a protocol for sharing arrays between deep learning frameworks."""
|
||||
|
||||
load("//third_party:repo.bzl", "third_party_http_archive")
|
||||
|
||||
def repo():
|
||||
third_party_http_archive(
|
||||
name = "dlpack",
|
||||
strip_prefix = "dlpack-3efc489b55385936531a06ff83425b719387ec63",
|
||||
sha256 = "b59586ce69bcf3efdbf3cf4803fadfeaae4948044e2b8d89cf912194cf28f233",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/dmlc/dlpack/archive/3efc489b55385936531a06ff83425b719387ec63.tar.gz",
|
||||
"https://github.com/dmlc/dlpack/archive/3efc489b55385936531a06ff83425b719387ec63.tar.gz",
|
||||
],
|
||||
build_file = "//third_party/dlpack:BUILD.bazel",
|
||||
)
|