[XLA:Python] Add DLPack import/export support to the XLA Python client.

This allows JAX to exchange on-device arrays with other libraries, such as PyTorch and CuPy, without copying through host memory.
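As an illustration of the interop this enables, a minimal sketch (the xla_client calls below mirror the new tests in this change; torch.utils.dlpack is PyTorch's capsule-based DLPack API and is an assumption about the consumer's environment, not part of this change):

    import numpy as np
    import torch.utils.dlpack
    from tensorflow.compiler.xla.python import xla_client

    backend = xla_client.get_local_backend()
    buffer = xla_client.Buffer.from_pyval(
        np.ones((2, 3), np.float32), backend=backend)

    # XLA -> PyTorch. The capsule takes ownership of the device memory, so a
    # DLPack tensor may be consumed at most once.
    capsule = xla_client._xla.BufferToDLPackManagedTensor(buffer)
    tensor = torch.utils.dlpack.from_dlpack(capsule)

    # PyTorch -> XLA, again without copying the underlying array.
    capsule2 = torch.utils.dlpack.to_dlpack(tensor * 2)
    buffer2 = xla_client._xla.DLPackManagedTensorToBuffer(capsule2, backend.client)
    print(buffer2.to_py())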

PiperOrigin-RevId: 290845329
Change-Id: Idd99d81533159bc2ad0c5177b69ac7f30315cb1a
Peter Hawkins, 2020-01-21 16:10:41 -08:00 (committed by TensorFlower Gardener)
parent 470239ee94
commit fc1f6fdf94
12 changed files with 513 additions and 19 deletions

tensorflow/compiler/xla/python/BUILD

@@ -34,6 +34,7 @@ py_test(
         ":xla_client",
         ":xla_extension",
         "@absl_py//absl/testing:absltest",
+        "@absl_py//absl/testing:parameterized",
     ] + xla_py_test_deps(),
 )
@@ -248,6 +249,34 @@ py_test(
     ] + xla_py_test_deps(),
 )
 
+cc_library(
+    name = "dlpack",
+    srcs = ["dlpack.cc"],
+    hdrs = ["dlpack.h"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
+    deps = [
+        ":local_client",
+        ":shared_device_buffer",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/stream_executor:device_memory",
+        "//tensorflow/stream_executor:platform",
+        "//tensorflow/stream_executor/cuda:cuda_platform_id",
+        "//tensorflow/stream_executor/host:host_platform_id",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@dlpack",
+        "@pybind11",
+    ],
+)
+
 config_setting(
     name = "enable_gpu",
     values = {"define": "xla_python_enable_gpu=true"},

@@ -266,6 +295,7 @@ pybind_extension(
     module_name = "xla_extension",
     deps = [
         ":bfloat16",
+        ":dlpack",
         ":local_client",
         ":shared_device_buffer",
         ":python_ref_manager",

tensorflow/compiler/xla/python/dlpack.cc (new file)

@@ -0,0 +1,339 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/dlpack.h"
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h" // TF:dlpack
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/platform.h"
namespace py = pybind11;
namespace xla {
namespace {
const char* const kDlTensorCapsuleName = "dltensor";
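// Owns the storage exported through a DLPack tensor: `buffer` keeps the XLA
// device memory alive, and `shape`/`strides` back the pointers stored in
// `tensor.dl_tensor`. `tensor.manager_ctx` points back at this struct so that
// DLPackTensorDeleter can free everything once the consumer is done.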
struct DLPackTensor {
std::shared_ptr<SharedDeviceBuffer> buffer;
std::vector<int64> shape;
std::vector<int64> strides;
DLManagedTensor tensor;
};
void DLPackTensorDeleter(DLManagedTensor* t) {
if (t) {
delete static_cast<DLPackTensor*>(t->manager_ctx);
}
}
StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
switch (type) {
case PRED:
return DLDataType{kDLInt, 1, 1};
case S8:
return DLDataType{kDLInt, 8, 1};
case S16:
return DLDataType{kDLInt, 16, 1};
case S32:
return DLDataType{kDLInt, 32, 1};
case S64:
return DLDataType{kDLInt, 64, 1};
case U8:
return DLDataType{kDLUInt, 8, 1};
case U16:
return DLDataType{kDLUInt, 16, 1};
case U32:
return DLDataType{kDLUInt, 32, 1};
case U64:
return DLDataType{kDLUInt, 64, 1};
case F16:
return DLDataType{kDLFloat, 16, 1};
case F32:
return DLDataType{kDLFloat, 32, 1};
case F64:
return DLDataType{kDLFloat, 64, 1};
case BF16:
return DLDataType{kDLBfloat, 16, 1};
case C64:
case C128:
default:
return Unimplemented("XLA type %s has no DLPack equivalent",
PrimitiveType_Name(type));
}
}
StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
if (type.lanes != 1) {
return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
type.lanes);
}
switch (type.code) {
case kDLInt:
switch (type.bits) {
case 1:
return PRED;
case 8:
return S8;
case 16:
return S16;
case 32:
return S32;
case 64:
return S64;
default:
return Unimplemented(
"Invalid or unsupported DLPack integer width: %d bits",
type.bits);
}
case kDLUInt:
switch (type.bits) {
case 1:
return PRED;
case 8:
return U8;
case 16:
return U16;
case 32:
return U32;
case 64:
return U64;
default:
return Unimplemented(
"Invalid or unsupported DLPack unsigned integer width: %d bits",
type.bits);
}
case kDLFloat:
switch (type.bits) {
case 16:
return F16;
case 32:
return F32;
case 64:
return F64;
default:
return Unimplemented(
"Invalid or unsupported DLPack float width: %d bits", type.bits);
}
case kDLBfloat:
switch (type.bits) {
case 16:
return BF16;
default:
return Unimplemented(
"Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
}
default:
return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
}
}
// Returns the strides for `shape`. Note that DLPack expresses strides in
// numbers of elements, not bytes.
std::vector<int64> StridesForShape(const Shape& shape) {
std::vector<int64> strides;
CHECK(shape.IsArray());
CHECK(shape.has_layout());
strides.resize(shape.dimensions_size());
int64 stride = 1;
for (int i : shape.layout().minor_to_major()) {
strides.at(i) = stride;
stride *= shape.dimensions(i);
}
return strides;
}
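// For example, an f32[2,3] array with row-major layout (minor_to_major =
// {1, 0}) yields strides = {3, 1}: the minor dimension is contiguous and the
// major dimension advances by a full row of 3 elements.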
StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
absl::Span<int64 const> strides) {
CHECK_EQ(dims.size(), strides.size());
std::vector<int64> minor_to_major(dims.size());
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
absl::c_sort(minor_to_major,
[&](int a, int b) { return strides[a] < strides[b]; });
int64 stride = 1;
for (int64 d : minor_to_major) {
if (strides[d] != stride) {
return Unimplemented(
"Only DLPack tensors with trivial (compact) striding are supported; "
"i.e., tensors whose striding represents a transposition of the "
"underlying buffer but not broadcasting. Dimensions were: [%s], "
"strides were [%s].",
absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
}
stride *= dims[d];
}
return minor_to_major;
}
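// For example, dims = {2, 3} with strides = {3, 1} sorts to minor_to_major =
// {1, 0} and passes the compactness check, whereas strides = {6, 2} (a strided
// slice of a larger buffer) would be rejected as non-compact.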
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
const se::Platform* platform =
device.local_device_state()->executor()->platform();
if (platform->id() == se::host::kHostPlatformId) {
return kDLCPU;
} else if (platform->id() == se::cuda::kCudaPlatformId) {
return kDLGPU;
}
return InvalidArgument("Device %s cannot be used as a DLPack device.",
device.DebugString());
}
StatusOr<DLContext> DLContextForDevice(const Device& device) {
DLContext context;
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
context.device_id = device.local_device_state()->device_ordinal();
return context;
}
StatusOr<std::shared_ptr<Device>> DeviceForDLContext(
const PyLocalClient& client, const DLContext& context) {
se::Platform::Id platform_id;
switch (context.device_type) {
case kDLCPU:
platform_id = se::host::kHostPlatformId;
break;
case kDLGPU:
platform_id = se::cuda::kCudaPlatformId;
break;
default:
return InvalidArgument("Unknown/unsupported DLPack device type %d",
context.device_type);
}
auto it = absl::c_find_if(
client.local_devices(), [&](const std::shared_ptr<Device>& device) {
return device->local_device_state()->executor()->platform()->id() ==
platform_id &&
device->local_device_state()->device_ordinal() ==
context.device_id;
});
if (it == client.local_devices().end()) {
return InvalidArgument(
"No matching device found for DLPack device_type %d device_id %d",
context.device_type, context.device_id);
}
return *it;
}
} // namespace
StatusOr<py::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer) {
auto pack = absl::make_unique<DLPackTensor>();
pack->buffer = buffer->DeviceBuffer();
if (!pack->buffer) {
return InvalidArgument(
"Cannot convert deleted/invalid buffer to DLPack tensor.");
}
pack->tensor.manager_ctx = pack.get();
pack->tensor.deleter = DLPackTensorDeleter;
DLTensor& dt = pack->tensor.dl_tensor;
if (buffer->on_device_shape().IsTuple()) {
return Unimplemented(
"BufferToDLPackManagedTensor is not implemented for tuple buffers.");
}
TF_RET_CHECK(pack->buffer->device_memory().size() == 1);
dt.data = pack->buffer->device_memory().front().opaque();
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->device()));
dt.ctx.device_id = buffer->device()->local_device_state()->device_ordinal();
dt.ndim = buffer->on_host_shape().dimensions_size();
TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType(
buffer->on_host_shape().element_type()));
pack->shape = std::vector<int64>(buffer->on_host_shape().dimensions().begin(),
buffer->on_host_shape().dimensions().end());
pack->strides = StridesForShape(buffer->on_host_shape());
dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
dt.byte_offset = 0;
py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
[](PyObject* obj) {
DLPackTensorDeleter(static_cast<DLManagedTensor*>(
PyCapsule_GetPointer(obj, kDlTensorCapsuleName)));
});
TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
return capsule;
}
StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
const pybind11::capsule& tensor, std::shared_ptr<PyLocalClient> client) {
if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
return InvalidArgument(
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
"Note that a DLPack tensor may be consumed at most once.",
absl::string_view(tensor.name()));
}
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
if (dlmt->dl_tensor.ndim < 0) {
return InvalidArgument(
"Number of dimensions in DLManagedTensor must be nonnegative, got %d",
dlmt->dl_tensor.ndim);
}
TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
absl::Span<int64 const> dimensions(
reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));
std::vector<int64> minor_to_major;
if (dlmt->dl_tensor.strides) {
absl::Span<int64 const> strides(
reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
dlmt->dl_tensor.ndim);
TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
} else {
minor_to_major.resize(dlmt->dl_tensor.ndim);
std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
}
Shape shape =
ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);
se::DeviceMemoryBase buffer(
static_cast<char*>(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset,
ShapeUtil::ByteSizeOf(shape));
std::function<void()> on_delete_callback;
if (dlmt->deleter) {
on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
}
auto device_buffer = std::make_shared<SharedDeviceBuffer>(
/*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id,
std::initializer_list<se::DeviceMemoryBase>{buffer},
/*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
/*definition_event=*/nullptr, std::move(on_delete_callback));
// We have taken ownership of the array inside the capsule; make sure the
// capsule cannot be used again.
PyCapsule_SetName(tensor.ptr(), "used_dltensor");
PyCapsule_SetDestructor(tensor.ptr(), nullptr);
return absl::make_unique<PyLocalBuffer>(shape, shape,
std::move(device_buffer),
std::move(client), std::move(device));
}
} // namespace xla

tensorflow/compiler/xla/python/dlpack.h (new file)

@@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
#include "include/pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/local_client.h"
namespace xla {
StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer);
StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
const pybind11::capsule& tensor, std::shared_ptr<PyLocalClient> client);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_

tensorflow/compiler/xla/python/local_client.h

@@ -141,8 +141,10 @@ class PyLocalClient {
   int device_count() const { return devices_.size(); }
   int local_device_count() const { return local_devices_.size(); }
-  const std::vector<std::shared_ptr<Device>>& devices() { return devices_; }
-  const std::vector<std::shared_ptr<Device>>& local_devices() {
+  const std::vector<std::shared_ptr<Device>>& devices() const {
+    return devices_;
+  }
+  const std::vector<std::shared_ptr<Device>>& local_devices() const {
     return local_devices_;
   }
   const std::map<int, std::shared_ptr<Device>>& id_to_device() const {

tensorflow/compiler/xla/python/local_device_state.h

@@ -44,6 +44,7 @@ class LocalDeviceState {
                    bool asynchronous, bool allow_event_reuse);
   virtual ~LocalDeviceState();
 
+  se::StreamExecutor* executor() const { return executor_; }
 
   // StreamExecutor (local) device ordinal.
   int device_ordinal() const { return executor_->device_ordinal(); }

tensorflow/compiler/xla/python/shared_device_buffer.cc

@@ -122,7 +122,8 @@ SharedDeviceBuffer::MakeTuple(
   return std::make_shared<SharedDeviceBuffer>(
       allocator, device_ordinal,
       std::initializer_list<se::DeviceMemoryBase>{device_memory.Release()},
-      std::move(children), std::move(definition_event));
+      std::move(children), std::move(definition_event),
+      /*on_delete_callback=*/nullptr);
 }
 
 /* static */ StatusOr<std::shared_ptr<SharedDeviceBuffer>>

@@ -179,12 +180,14 @@ SharedDeviceBuffer::SharedDeviceBuffer(
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
     absl::Span<se::DeviceMemoryBase const> device_memory,
     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-    std::shared_ptr<BufferDefinitionEvent> definition_event)
+    std::shared_ptr<BufferDefinitionEvent> definition_event,
+    std::function<void()> on_delete_callback)
     : allocator_(allocator),
      device_ordinal_(device_ordinal),
      device_memory_(device_memory.begin(), device_memory.end()),
      children_(std::move(children)),
-      definition_event_(std::move(definition_event)) {}
+      definition_event_(std::move(definition_event)),
+      on_delete_callback_(std::move(on_delete_callback)) {}
 
 SharedDeviceBuffer::SharedDeviceBuffer(
     absl::Span<se::OwningDeviceMemory> device_memory,

@@ -211,6 +214,9 @@ SharedDeviceBuffer::~SharedDeviceBuffer() {
       }
     }
   }
+  if (on_delete_callback_) {
+    on_delete_callback_();
+  }
 }
 
 void GetDeviceBufferDefinitionEvents(

tensorflow/compiler/xla/python/shared_device_buffer.h

@@ -120,6 +120,9 @@ class SharedDeviceBuffer {
   }
   se::DeviceMemoryAllocator* allocator() const { return allocator_; }
   int device_ordinal() const { return device_ordinal_; }
+  absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() {
+    return device_memory_;
+  }
   const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const {
     return device_memory_;
   }

@@ -131,7 +134,8 @@ class SharedDeviceBuffer {
   SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal,
                      absl::Span<se::DeviceMemoryBase const> device_memory,
                      std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-                     std::shared_ptr<BufferDefinitionEvent> definition_event);
+                     std::shared_ptr<BufferDefinitionEvent> definition_event,
+                     std::function<void()> on_delete_callback);
   SharedDeviceBuffer(absl::Span<se::OwningDeviceMemory> device_memory,
                      std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
                      std::shared_ptr<BufferDefinitionEvent> definition_event);

@@ -152,6 +156,9 @@ class SharedDeviceBuffer {
   // single-stream execution case where events are not necessary for buffer
   // event sequencing.
   std::shared_ptr<BufferDefinitionEvent> definition_event_;
+
+  // A callback to call when the SharedDeviceBuffer is about to be destroyed.
+  std::function<void()> on_delete_callback_;
 };
 
 // Populates 'events' with the set of buffer definition events for all buffers

tensorflow/compiler/xla/python/xla.cc

@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/python/bfloat16.h"
+#include "tensorflow/compiler/xla/python/dlpack.h"
 #include "tensorflow/compiler/xla/python/local_client.h"
 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
 #include "tensorflow/compiler/xla/python/types.h"

@@ -652,6 +653,9 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("SetSharding", &XlaBuilder::SetSharding)
       .def("ClearSharding", &XlaBuilder::ClearSharding);
 
+  m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor);
+  m.def("DLPackManagedTensorToBuffer", DLPackManagedTensorToBuffer);
+
   // ops submodule, containing free functions that add operators to an
   // XlaBuilder.
   py::module ops = m.def_submodule("ops", "XLA operations");

tensorflow/compiler/xla/python/xla_client_test.py

@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -23,12 +24,12 @@ import itertools
 import threading
 
 from absl.testing import absltest
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.xla.python import custom_call_for_test
 from tensorflow.compiler.xla.python import xla_client
 
 bfloat16 = xla_client.bfloat16
@@ -1420,24 +1421,24 @@ class SingleOpTest(ComputationTest):
     # FFT
     c = self._NewComputation()
     c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:])
-    self._ExecuteAndCompareClose(c, expected=np.fft.fftn(a, axes=(1, 2, 3)),
-                                 rtol=1e-4)
+    self._ExecuteAndCompareClose(
+        c, expected=np.fft.fftn(a, axes=(1, 2, 3)), rtol=1e-4)
     # IFFT
     c = self._NewComputation()
     c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:])
-    self._ExecuteAndCompareClose(c, expected=np.fft.ifftn(a, axes=(1, 2, 3)),
-                                 rtol=1e-4)
+    self._ExecuteAndCompareClose(
+        c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), rtol=1e-4)
     # RFFT
     b = rng.randn(*shape).astype(np.float32)
     c = self._NewComputation()
     c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:])
-    self._ExecuteAndCompareClose(c, expected=np.fft.rfftn(b, axes=(1, 2, 3)),
-                                 rtol=1e-4)
+    self._ExecuteAndCompareClose(
+        c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), rtol=1e-4)
     # IRFFT
     c = self._NewComputation()
     c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8])
-    self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)),
-                                 rtol=1e-4)
+    self._ExecuteAndCompareClose(
+        c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), rtol=1e-4)
 
   def testNextAfter(self):
     c = self._NewComputation()
@@ -1454,8 +1455,8 @@ class SingleOpTest(ComputationTest):
     b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677])
     c = self._NewComputation()
     c.RegularizedIncompleteBeta(c.Constant(a), c.Constant(b), c.Constant(x))
-    expected = np.array([0.98923271, 0.48575411, 0.57952568, 0.12579775,
-                         0.96989155])
+    expected = np.array(
+        [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155])
     self._ExecuteAndCompareClose(c, expected=expected, rtol=1e-4)
@@ -1974,7 +1975,7 @@ class ErrorTest(ComputationTest):
     def TestFun():
       return c.Build().Compile(compile_options=options)
 
-    self.assertRaisesRegexp(
+    self.assertRaisesRegex(
        RuntimeError, r".*Invalid argument shape.*"
        r"expected s32\[\], got f32\[\].*", TestFun)

@@ -1988,7 +1989,7 @@ class ErrorTest(ComputationTest):
       return xla_client.execute_with_python_values(c.Build().Compile(),
                                                    [self.f32_scalar_2])
 
-    self.assertRaisesRegexp(
+    self.assertRaisesRegex(
        RuntimeError, r"Invalid argument: Argument does not match.*"
        r"want s32\[\], got f32\[\].*", TestFun)
@@ -2031,5 +2032,47 @@ class SetShardingTest(ComputationTest):
     np.testing.assert_allclose(ans, 4.14)
 
 
+dlpack_dtypes = [
+    np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
+    np.uint64, np.float16, np.float32, np.float64, bfloat16
+]
+
+
+class DLPackTest(parameterized.TestCase):
+
+  # pylint: disable=g-complex-comprehension
+  @parameterized.named_parameters({
+      "testcase_name":
+          "_{}[{}]".format(dtype.__name__, ",".join(map(str, shape))),
+      "dtype":
+          dtype,
+      "shape":
+          shape
+  } for dtype in dlpack_dtypes for shape in [(), (1,), (2, 3), (4, 1, 2)])
+  def testRoundTrip(self, dtype, shape):
+    x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
+    backend = xla_client.get_local_backend()
+    buffer = xla_client.Buffer.from_pyval(x, backend=backend)
+    dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
+    del buffer  # Free "buffer" to make sure dlt retains ownership.
+    self.assertEqual(type(dlt).__name__, "PyCapsule")
+    y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client)
+    np.testing.assert_array_equal(x, y.to_py())
+
+  def testTensorsCanBeConsumedOnceOnly(self):
+    x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
+    backend = xla_client.get_local_backend()
+    buffer = xla_client.Buffer.from_pyval(x, backend=backend)
+    dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
+
+    def ConsumeDLPackTensor():
+      _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client)
+
+    ConsumeDLPackTensor()
+    self.assertRaisesRegex(RuntimeError,
+                           ".*a DLPack tensor may be consumed at most once.*",
+                           ConsumeDLPackTensor)
+
 
 if __name__ == "__main__":
   absltest.main()

tensorflow/workspace.bzl

@@ -26,6 +26,7 @@ load("//third_party/FXdiv:workspace.bzl", FXdiv = "repo")
 load("//third_party/aws:workspace.bzl", aws = "repo")
 load("//third_party/clog:workspace.bzl", clog = "repo")
 load("//third_party/cpuinfo:workspace.bzl", cpuinfo = "repo")
+load("//third_party/dlpack:workspace.bzl", dlpack = "repo")
 load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
 load("//third_party/hexagon:workspace.bzl", hexagon_nn = "repo")
 load("//third_party/highwayhash:workspace.bzl", highwayhash = "repo")

@@ -48,6 +49,7 @@ def initialize_third_party():
     aws()
     clog()
     cpuinfo()
+    dlpack()
     flatbuffers()
     hexagon_nn()
     highwayhash()

third_party/dlpack/BUILD.bazel (vendored, new file)

@@ -0,0 +1,14 @@
# Description:
# DLPack is a protocol for sharing arrays between deep learning frameworks.
licenses(["notice"]) # Apache 2
exports_files(["LICENSE"])
cc_library(
    name = "dlpack",
    hdrs = [
        "include/dlpack/dlpack.h",
    ],
    visibility = ["//visibility:public"],
)

third_party/dlpack/workspace.bzl (vendored, new file)

@@ -0,0 +1,15 @@
"""DLPack is a protocol for sharing arrays between deep learning frameworks."""
load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
    third_party_http_archive(
        name = "dlpack",
        strip_prefix = "dlpack-3efc489b55385936531a06ff83425b719387ec63",
        sha256 = "b59586ce69bcf3efdbf3cf4803fadfeaae4948044e2b8d89cf912194cf28f233",
        urls = [
            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/dmlc/dlpack/archive/3efc489b55385936531a06ff83425b719387ec63.tar.gz",
            "https://github.com/dmlc/dlpack/archive/3efc489b55385936531a06ff83425b719387ec63.tar.gz",
        ],
        build_file = "//third_party/dlpack:BUILD.bazel",
    )