diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index a596f68f937..5a0a516e930 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -34,6 +34,7 @@ py_test(
         ":xla_client",
         ":xla_extension",
         "@absl_py//absl/testing:absltest",
+        "@absl_py//absl/testing:parameterized",
     ] + xla_py_test_deps(),
 )
 
@@ -248,6 +249,34 @@ py_test(
     ] + xla_py_test_deps(),
 )
 
+cc_library(
+    name = "dlpack",
+    srcs = ["dlpack.cc"],
+    hdrs = ["dlpack.h"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
+    deps = [
+        ":local_client",
+        ":shared_device_buffer",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/stream_executor:device_memory",
+        "//tensorflow/stream_executor:platform",
+        "//tensorflow/stream_executor/cuda:cuda_platform_id",
+        "//tensorflow/stream_executor/host:host_platform_id",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@dlpack",
+        "@pybind11",
+    ],
+)
+
 config_setting(
     name = "enable_gpu",
     values = {"define": "xla_python_enable_gpu=true"},
@@ -266,6 +295,7 @@ pybind_extension(
     module_name = "xla_extension",
     deps = [
         ":bfloat16",
+        ":dlpack",
         ":local_client",
         ":shared_device_buffer",
         ":python_ref_manager",
diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc
new file mode 100644
index 00000000000..a7d4e9bf02a
--- /dev/null
+++ b/tensorflow/compiler/xla/python/dlpack.cc
@@ -0,0 +1,339 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/dlpack.h"
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
+#include "include/dlpack/dlpack.h"  // TF:dlpack
+#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/host/host_platform_id.h"
+#include "tensorflow/stream_executor/platform.h"
+
+namespace py = pybind11;
+
+namespace xla {
+namespace {
+
+const char* const kDlTensorCapsuleName = "dltensor";
+
+struct DLPackTensor {
+  std::shared_ptr<SharedDeviceBuffer> buffer;
+  std::vector<int64> shape;
+  std::vector<int64> strides;
+  DLManagedTensor tensor;
+};
+
+void DLPackTensorDeleter(DLManagedTensor* t) {
+  if (t) {
+    delete static_cast<DLPackTensor*>(t->manager_ctx);
+  }
+}
+
+StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
+  switch (type) {
+    case PRED:
+      return DLDataType{kDLInt, 1, 1};
+    case S8:
+      return DLDataType{kDLInt, 8, 1};
+    case S16:
+      return DLDataType{kDLInt, 16, 1};
+    case S32:
+      return DLDataType{kDLInt, 32, 1};
+    case S64:
+      return DLDataType{kDLInt, 64, 1};
+    case U8:
+      return DLDataType{kDLUInt, 8, 1};
+    case U16:
+      return DLDataType{kDLUInt, 16, 1};
+    case U32:
+      return DLDataType{kDLUInt, 32, 1};
+    case U64:
+      return DLDataType{kDLUInt, 64, 1};
+    case F16:
+      return DLDataType{kDLFloat, 16, 1};
+    case F32:
+      return DLDataType{kDLFloat, 32, 1};
+    case F64:
+      return DLDataType{kDLFloat, 64, 1};
+    case BF16:
+      return DLDataType{kDLBfloat, 16, 1};
+    case C64:
+    case C128:
+    default:
+      return Unimplemented("XLA type %s has no DLPack equivalent",
+                           PrimitiveType_Name(type));
+  }
+}
+
+StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
+  if (type.lanes != 1) {
+    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
+                         type.lanes);
+  }
+  switch (type.code) {
+    case kDLInt:
+      switch (type.bits) {
+        case 1:
+          return PRED;
+        case 8:
+          return S8;
+        case 16:
+          return S16;
+        case 32:
+          return S32;
+        case 64:
+          return S64;
+        default:
+          return Unimplemented(
+              "Invalid or unsupported DLPack integer width: %d bits",
+              type.bits);
+      }
+    case kDLUInt:
+      switch (type.bits) {
+        case 1:
+          return PRED;
+        case 8:
+          return U8;
+        case 16:
+          return U16;
+        case 32:
+          return U32;
+        case 64:
+          return U64;
+        default:
+          return Unimplemented(
+              "Invalid or unsupported DLPack unsigned integer width: %d bits",
+              type.bits);
+      }
+    case kDLFloat:
+      switch (type.bits) {
+        case 16:
+          return F16;
+        case 32:
+          return F32;
+        case 64:
+          return F64;
+        default:
+          return Unimplemented(
+              "Invalid or unsupported DLPack float width: %d bits", type.bits);
+      }
+    case kDLBfloat:
+      switch (type.bits) {
+        case 16:
+          return BF16;
+        default:
+          return Unimplemented(
+              "Invalid or unsupported DLPack bfloat width: %d bits", type.bits);
+      }
+    default:
+      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
+  }
+}
+
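+// For example, XLA's S32 maps to DLDataType{kDLInt, 32, 1} (signed integer,
+// 32 bits, one lane), and DLDataTypeToPrimitiveType inverts that mapping, so
+// every type supported above round-trips through the two functions.
+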
+// Returns the strides for `shape`. Note that DLPack expresses strides in
+// numbers of elements, not bytes.
+std::vector<int64> StridesForShape(const Shape& shape) {
+  std::vector<int64> strides;
+  CHECK(shape.IsArray());
+  CHECK(shape.has_layout());
+
+  strides.resize(shape.dimensions_size());
+  int64 stride = 1;
+  for (int i : shape.layout().minor_to_major()) {
+    strides.at(i) = stride;
+    stride *= shape.dimensions(i);
+  }
+  return strides;
+}
+
+StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
+                                             absl::Span<int64 const> strides) {
+  CHECK_EQ(dims.size(), strides.size());
+  std::vector<int64> minor_to_major(dims.size());
+  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
+  absl::c_sort(minor_to_major,
+               [&](int a, int b) { return strides[a] < strides[b]; });
+  int64 stride = 1;
+  for (int64 d : minor_to_major) {
+    if (strides[d] != stride) {
+      return Unimplemented(
+          "Only DLPack tensors with trivial (compact) striding are supported; "
+          "i.e., tensors whose striding represents a transposition of the "
+          "underlying buffer but not broadcasting. Dimensions were: [%s], "
+          "strides were [%s].",
+          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
+    }
+    stride *= dims[d];
+  }
+  return minor_to_major;
+}
+
+StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
+  const se::Platform* platform =
+      device.local_device_state()->executor()->platform();
+  if (platform->id() == se::host::kHostPlatformId) {
+    return kDLCPU;
+  } else if (platform->id() == se::cuda::kCudaPlatformId) {
+    return kDLGPU;
+  }
+  return InvalidArgument("Device %s cannot be used as a DLPack device.",
+                         device.DebugString());
+}
+
+StatusOr<DLContext> DLContextForDevice(const Device& device) {
+  DLContext context;
+  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
+  context.device_id = device.local_device_state()->device_ordinal();
+  return context;
+}
+
+StatusOr<std::shared_ptr<Device>> DeviceForDLContext(
+    const PyLocalClient& client, const DLContext& context) {
+  se::Platform::Id platform_id;
+  switch (context.device_type) {
+    case kDLCPU:
+      platform_id = se::host::kHostPlatformId;
+      break;
+    case kDLGPU:
+      platform_id = se::cuda::kCudaPlatformId;
+      break;
+    default:
+      return InvalidArgument("Unknown/unsupported DLPack device type %d",
+                             context.device_type);
+  }
+  auto it = absl::c_find_if(
+      client.local_devices(), [&](const std::shared_ptr<Device>& device) {
+        return device->local_device_state()->executor()->platform()->id() ==
+                   platform_id &&
+               device->local_device_state()->device_ordinal() ==
+                   context.device_id;
+      });
+  if (it == client.local_devices().end()) {
+    return InvalidArgument(
+        "No matching device found for DLPack device_type %d device_id %d",
+        context.device_type, context.device_id);
+  }
+  return *it;
+}
+
+}  // namespace
+
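+// Exports `buffer` as a DLPack capsule named "dltensor". The DLPackTensor
+// allocated here holds a reference to the underlying SharedDeviceBuffer in
+// its manager_ctx, so the device memory stays alive until the consumer runs
+// the DLManagedTensor's deleter.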
+StatusOr<py::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer) {
+  auto pack = absl::make_unique<DLPackTensor>();
+  pack->buffer = buffer->DeviceBuffer();
+  if (!pack->buffer) {
+    return InvalidArgument(
+        "Cannot convert deleted/invalid buffer to DLPack tensor.");
+  }
+  pack->tensor.manager_ctx = pack.get();
+  pack->tensor.deleter = DLPackTensorDeleter;
+  DLTensor& dt = pack->tensor.dl_tensor;
+  if (buffer->on_device_shape().IsTuple()) {
+    return Unimplemented(
+        "BufferToDLPackManagedTensor is not implemented for tuple buffers.");
+  }
+  TF_RET_CHECK(pack->buffer->device_memory().size() == 1);
+  dt.data = pack->buffer->device_memory().front().opaque();
+  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->device()));
+  dt.ctx.device_id = buffer->device()->local_device_state()->device_ordinal();
+  dt.ndim = buffer->on_host_shape().dimensions_size();
+  TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType(
+                                    buffer->on_host_shape().element_type()));
+
+  pack->shape =
+      std::vector<int64>(buffer->on_host_shape().dimensions().begin(),
+                         buffer->on_host_shape().dimensions().end());
+  pack->strides = StridesForShape(buffer->on_host_shape());
+  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
+  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
+  dt.byte_offset = 0;
+
+  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
+                      [](PyObject* obj) {
+                        DLPackTensorDeleter(static_cast<DLManagedTensor*>(
+                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName)));
+                      });
+
+  TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
+  return capsule;
+}
+
+StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
+    const pybind11::capsule& tensor, std::shared_ptr<PyLocalClient> client) {
+  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
+    return InvalidArgument(
+        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
+        "Note that a DLPack tensor may be consumed at most once.",
+        absl::string_view(tensor.name()));
+  }
+  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
+  if (dlmt->dl_tensor.ndim < 0) {
+    return InvalidArgument(
+        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
+        dlmt->dl_tensor.ndim);
+  }
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
+                      DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
+  absl::Span<int64 const> dimensions(
+      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
+  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
+                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));
+
+  std::vector<int64> minor_to_major;
+  if (dlmt->dl_tensor.strides) {
+    absl::Span<int64 const> strides(
+        reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
+        dlmt->dl_tensor.ndim);
+    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
+  } else {
+    // Absent strides denote a compact, row-major layout.
+    minor_to_major.resize(dlmt->dl_tensor.ndim);
+    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
+  }
+  Shape shape =
+      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);
+  se::DeviceMemoryBase buffer(
+      static_cast<char*>(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset,
+      ShapeUtil::ByteSizeOf(shape));
+
+  std::function<void()> on_delete_callback;
+  if (dlmt->deleter) {
+    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
+  }
+  auto device_buffer = std::make_shared<SharedDeviceBuffer>(
+      /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id,
+      std::initializer_list<se::DeviceMemoryBase>{buffer},
+      /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
+      /*definition_event=*/nullptr, std::move(on_delete_callback));
+
+  // We have taken ownership of the array inside the capsule; make sure the
+  // capsule cannot be used again.
+  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
+  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
+  return absl::make_unique<PyLocalBuffer>(shape, shape,
+                                          std::move(device_buffer),
+                                          std::move(client), std::move(device));
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/python/dlpack.h b/tensorflow/compiler/xla/python/dlpack.h
new file mode 100644
index 00000000000..92eba687225
--- /dev/null
+++ b/tensorflow/compiler/xla/python/dlpack.h
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/python/local_client.h"
+
+namespace xla {
+
+StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer);
+
+StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
+    const pybind11::capsule& tensor, std::shared_ptr<PyLocalClient> client);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h
index c9fe33799fa..001cf187bdd 100644
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -141,8 +141,10 @@ class PyLocalClient {
   int device_count() const { return devices_.size(); }
   int local_device_count() const { return local_devices_.size(); }
-  const std::vector<std::shared_ptr<Device>>& devices() { return devices_; }
-  const std::vector<std::shared_ptr<Device>>& local_devices() {
+  const std::vector<std::shared_ptr<Device>>& devices() const {
+    return devices_;
+  }
+  const std::vector<std::shared_ptr<Device>>& local_devices() const {
     return local_devices_;
   }
   const std::map<int, std::shared_ptr<Device>>& id_to_device() const {
diff --git a/tensorflow/compiler/xla/python/local_device_state.h b/tensorflow/compiler/xla/python/local_device_state.h
index 7348b9c59f0..6d228f4a2b6 100644
--- a/tensorflow/compiler/xla/python/local_device_state.h
+++ b/tensorflow/compiler/xla/python/local_device_state.h
@@ -44,6 +44,7 @@ class LocalDeviceState {
                    bool asynchronous, bool allow_event_reuse);
   virtual ~LocalDeviceState();
 
+  se::StreamExecutor* executor() const { return executor_; }
   // StreamExecutor (local) device ordinal.
   int device_ordinal() const { return executor_->device_ordinal(); }
diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.cc b/tensorflow/compiler/xla/python/shared_device_buffer.cc
index c788b364f55..e1f00432d37 100644
--- a/tensorflow/compiler/xla/python/shared_device_buffer.cc
+++ b/tensorflow/compiler/xla/python/shared_device_buffer.cc
@@ -122,7 +122,8 @@ SharedDeviceBuffer::MakeTuple(
   return std::make_shared<SharedDeviceBuffer>(
       allocator, device_ordinal,
       std::initializer_list<se::DeviceMemoryBase>{device_memory.Release()},
-      std::move(children), std::move(definition_event));
+      std::move(children), std::move(definition_event),
+      /*on_delete_callback=*/nullptr);
 }
 
 /* static */ StatusOr<std::shared_ptr<SharedDeviceBuffer>>
@@ -179,12 +180,14 @@ SharedDeviceBuffer::SharedDeviceBuffer(
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
     absl::Span<se::DeviceMemoryBase const> device_memory,
     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-    std::shared_ptr<BufferDefinitionEvent> definition_event)
+    std::shared_ptr<BufferDefinitionEvent> definition_event,
+    std::function<void()> on_delete_callback)
     : allocator_(allocator),
       device_ordinal_(device_ordinal),
       device_memory_(device_memory.begin(), device_memory.end()),
       children_(std::move(children)),
-      definition_event_(std::move(definition_event)) {}
+      definition_event_(std::move(definition_event)),
+      on_delete_callback_(std::move(on_delete_callback)) {}
 
 SharedDeviceBuffer::SharedDeviceBuffer(
     absl::Span<se::DeviceMemoryBase const> device_memory,
@@ -211,6 +214,9 @@ SharedDeviceBuffer::~SharedDeviceBuffer() {
       }
     }
   }
+  if (on_delete_callback_) {
+    on_delete_callback_();
+  }
 }
 
 void GetDeviceBufferDefinitionEvents(
diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.h b/tensorflow/compiler/xla/python/shared_device_buffer.h
index 65d1518f46c..8d9d8278d33 100644
--- a/tensorflow/compiler/xla/python/shared_device_buffer.h
+++ b/tensorflow/compiler/xla/python/shared_device_buffer.h
@@ -120,6 +120,9 @@ class SharedDeviceBuffer {
   }
   se::DeviceMemoryAllocator* allocator() const { return allocator_; }
   int device_ordinal() const { return device_ordinal_; }
+  absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() {
+    return device_memory_;
+  }
   const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const {
     return device_memory_;
   }
@@ -131,7 +134,8 @@ class SharedDeviceBuffer {
   SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal,
                      absl::Span<se::DeviceMemoryBase const> device_memory,
                      std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-                     std::shared_ptr<BufferDefinitionEvent> definition_event);
+                     std::shared_ptr<BufferDefinitionEvent> definition_event,
+                     std::function<void()> on_delete_callback);
   SharedDeviceBuffer(absl::Span<se::DeviceMemoryBase const> device_memory,
                      std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
                      std::shared_ptr<BufferDefinitionEvent> definition_event);
@@ -152,6 +156,9 @@ class SharedDeviceBuffer {
   // single-stream execution case where events are not necessary for buffer
   // event sequencing.
   std::shared_ptr<BufferDefinitionEvent> definition_event_;
+
+  // A callback to call when the SharedDeviceBuffer is about to be destroyed.
+  std::function<void()> on_delete_callback_;
 };
 
 // Populates 'events' with the set of buffer definition events for all buffers
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index f6017397c2e..d83b2d97550 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/python/bfloat16.h" +#include "tensorflow/compiler/xla/python/dlpack.h" #include "tensorflow/compiler/xla/python/local_client.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" @@ -652,6 +653,9 @@ PYBIND11_MODULE(xla_extension, m) { .def("SetSharding", &XlaBuilder::SetSharding) .def("ClearSharding", &XlaBuilder::ClearSharding); + m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor); + m.def("DLPackManagedTensorToBuffer", DLPackManagedTensorToBuffer); + // ops submodule, containing free functions that add operators to an // XlaBuilder. py::module ops = m.def_submodule("ops", "XLA operations"); diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 0fd0813bdcb..05a64dd0f76 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -1,3 +1,4 @@ +# Lint as: python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,12 +24,12 @@ import itertools import threading from absl.testing import absltest +from absl.testing import parameterized import numpy as np from tensorflow.compiler.xla.python import custom_call_for_test from tensorflow.compiler.xla.python import xla_client - bfloat16 = xla_client.bfloat16 @@ -1420,24 +1421,24 @@ class SingleOpTest(ComputationTest): # FFT c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:]) - self._ExecuteAndCompareClose(c, expected=np.fft.fftn(a, axes=(1, 2, 3)), - rtol=1e-4) + self._ExecuteAndCompareClose( + c, expected=np.fft.fftn(a, axes=(1, 2, 3)), rtol=1e-4) # IFFT c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:]) - self._ExecuteAndCompareClose(c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), - rtol=1e-4) + self._ExecuteAndCompareClose( + c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), rtol=1e-4) # RFFT b = rng.randn(*shape).astype(np.float32) c = self._NewComputation() c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:]) - self._ExecuteAndCompareClose(c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), - rtol=1e-4) + self._ExecuteAndCompareClose( + c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), rtol=1e-4) # IRFFT c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8]) - self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), - rtol=1e-4) + self._ExecuteAndCompareClose( + c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), rtol=1e-4) def testNextAfter(self): c = self._NewComputation() @@ -1454,8 +1455,8 @@ class SingleOpTest(ComputationTest): b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677]) c = self._NewComputation() c.RegularizedIncompleteBeta(c.Constant(a), c.Constant(b), c.Constant(x)) - expected = np.array([0.98923271, 0.48575411, 0.57952568, 0.12579775, - 0.96989155]) + expected = np.array( + [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) self._ExecuteAndCompareClose(c, expected=expected, rtol=1e-4) @@ -1974,7 +1975,7 @@ class ErrorTest(ComputationTest): def TestFun(): return c.Build().Compile(compile_options=options) - self.assertRaisesRegexp( + self.assertRaisesRegex( RuntimeError, r".*Invalid argument shape.*" r"expected s32\[\], got f32\[\].*", TestFun) @@ -1988,7 +1989,7 @@ class 
       return xla_client.execute_with_python_values(c.Build().Compile(),
                                                    [self.f32_scalar_2])
 
-    self.assertRaisesRegexp(
+    self.assertRaisesRegex(
         RuntimeError, r"Invalid argument: Argument does not match.*"
        r"want s32\[\], got f32\[\].*", TestFun)
 
@@ -2031,5 +2032,47 @@ class SetShardingTest(ComputationTest):
     np.testing.assert_allclose(ans, 4.14)
 
 
+dlpack_dtypes = [
+    np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
+    np.uint64, np.float16, np.float32, np.float64, bfloat16
+]
+
+
+class DLPackTest(parameterized.TestCase):
+
+  # pylint: disable=g-complex-comprehension
+  @parameterized.named_parameters({
+      "testcase_name":
+          "_{}[{}]".format(dtype.__name__, ",".join(map(str, shape))),
+      "dtype":
+          dtype,
+      "shape":
+          shape
+  } for dtype in dlpack_dtypes for shape in [(), (1,), (2, 3), (4, 1, 2)])
+  def testRoundTrip(self, dtype, shape):
+    x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
+    backend = xla_client.get_local_backend()
+    buffer = xla_client.Buffer.from_pyval(x, backend=backend)
+    dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
+    del buffer  # Free "buffer" to make sure dlt retains ownership.
+    self.assertEqual(type(dlt).__name__, "PyCapsule")
+    y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client)
+    np.testing.assert_array_equal(x, y.to_py())
+
+  def testTensorsCanBeConsumedOnceOnly(self):
+    x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
+    backend = xla_client.get_local_backend()
+    buffer = xla_client.Buffer.from_pyval(x, backend=backend)
+    dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
+
+    def ConsumeDLPackTensor():
+      _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client)
+
+    ConsumeDLPackTensor()
+    self.assertRaisesRegex(RuntimeError,
+                           ".*a DLPack tensor may be consumed at most once.*",
+                           ConsumeDLPackTensor)
+
+
 if __name__ == "__main__":
   absltest.main()
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index d43df54a6ae..b71a298bada 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -26,6 +26,7 @@ load("//third_party/FXdiv:workspace.bzl", FXdiv = "repo")
 load("//third_party/aws:workspace.bzl", aws = "repo")
 load("//third_party/clog:workspace.bzl", clog = "repo")
 load("//third_party/cpuinfo:workspace.bzl", cpuinfo = "repo")
+load("//third_party/dlpack:workspace.bzl", dlpack = "repo")
 load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
 load("//third_party/hexagon:workspace.bzl", hexagon_nn = "repo")
 load("//third_party/highwayhash:workspace.bzl", highwayhash = "repo")
@@ -48,6 +49,7 @@ def initialize_third_party():
     aws()
     clog()
     cpuinfo()
+    dlpack()
     flatbuffers()
     hexagon_nn()
     highwayhash()
diff --git a/third_party/dlpack/BUILD.bazel b/third_party/dlpack/BUILD.bazel
new file mode 100644
index 00000000000..cd52d710ebe
--- /dev/null
+++ b/third_party/dlpack/BUILD.bazel
@@ -0,0 +1,14 @@
+# Description:
+#   DLPack is a protocol for sharing arrays between deep learning frameworks.
+
+licenses(["notice"])  # Apache 2
+
+exports_files(["LICENSE"])
+
+cc_library(
+    name = "dlpack",
+    hdrs = [
+        "include/dlpack/dlpack.h",
+    ],
+    visibility = ["//visibility:public"],
+)
diff --git a/third_party/dlpack/workspace.bzl b/third_party/dlpack/workspace.bzl
new file mode 100644
index 00000000000..f82e88b129e
--- /dev/null
+++ b/third_party/dlpack/workspace.bzl
@@ -0,0 +1,15 @@
+"""DLPack is a protocol for sharing arrays between deep learning frameworks."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+    third_party_http_archive(
+        name = "dlpack",
+        strip_prefix = "dlpack-3efc489b55385936531a06ff83425b719387ec63",
+        sha256 = "b59586ce69bcf3efdbf3cf4803fadfeaae4948044e2b8d89cf912194cf28f233",
+        urls = [
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/dmlc/dlpack/archive/3efc489b55385936531a06ff83425b719387ec63.tar.gz",
+            "https://github.com/dmlc/dlpack/archive/3efc489b55385936531a06ff83425b719387ec63.tar.gz",
+        ],
+        build_file = "//third_party/dlpack:BUILD.bazel",
+    )
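
For reference, the round trip exercised by DLPackTest above looks like this from the Python side. This is a usage sketch rather than part of the patch; it assumes a build in which the xla_extension module is reachable as xla_client._xla, as in the tests:

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    backend = xla_client.get_local_backend()
    x = np.arange(12, dtype=np.float32).reshape(3, 4)
    buf = xla_client.Buffer.from_pyval(x, backend=backend)

    # Export: wrap the device buffer in a DLManagedTensor capsule named
    # "dltensor". The capsule shares the device memory rather than copying it.
    capsule = xla_client._xla.BufferToDLPackManagedTensor(buf)

    # Import: consume the capsule and adopt the memory as a new buffer. The
    # capsule is renamed to "used_dltensor", so consuming it a second time
    # raises RuntimeError.
    y = xla_client._xla.DLPackManagedTensorToBuffer(capsule, backend.client)

    np.testing.assert_array_equal(x, y.to_py())

Ownership travels with the capsule: the export side keeps the SharedDeviceBuffer alive through the capsule's manager_ctx, and the import side installs the DLPack deleter as the new buffer's on_delete_callback, so the memory is released exactly once, by whichever holder drops the last reference.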