From c7e983d2c43ddc355e565a8a98d4c2dc7540e771 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 20 Jan 2021 13:10:55 -0800 Subject: [PATCH] [PJRT] Move host literal cache out of PjRtBuffer::ToLiteral() and into the XLA Python bindings. Change ToLiteral() to have an asynchronous API that writes its output into a caller-provided buffer. Delete CopyToHostAsync() because it now serves no purpose. Caching host transfers is a policy decision that PJRT should not be making on behalf of clients. Instead, clients can cache the transfer results if they want. The original motivation for the cache was the Python bindings; this change moves the cache into the Python bindings. This simplifies the PJRT API. PiperOrigin-RevId: 352858903 Change-Id: If17c69268e5f5c8690baa2f2ec88109376fc9c19 --- tensorflow/compiler/xla/pjrt/BUILD | 6 +- tensorflow/compiler/xla/pjrt/event_pool.cc | 25 +++-- tensorflow/compiler/xla/pjrt/event_pool.h | 5 + tensorflow/compiler/xla/pjrt/pjrt_client.h | 46 ++++----- .../xla/pjrt/pjrt_stream_executor_client.cc | 93 ++++--------------- .../xla/pjrt/pjrt_stream_executor_client.h | 28 +----- tensorflow/compiler/xla/python/BUILD | 1 + tensorflow/compiler/xla/python/py_buffer.cc | 61 ++++++++++++ tensorflow/compiler/xla/python/py_buffer.h | 17 ++-- tensorflow/compiler/xla/python/xla.cc | 37 +------- 10 files changed, 147 insertions(+), 172 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index ed2f0bcf1f4..705d8b2f56b 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -132,8 +132,6 @@ cc_library( hdrs = ["pjrt_client.h"], visibility = ["//tensorflow/compiler/xla:friends"], deps = [ - "//tensorflow/compiler/xla:executable_run_options", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", @@ -141,13 +139,11 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", 
"//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/core:lib", - "@com_google_absl//absl/base", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/pjrt/event_pool.cc b/tensorflow/compiler/xla/pjrt/event_pool.cc index 86aa38cdd0f..292082d23ae 100644 --- a/tensorflow/compiler/xla/pjrt/event_pool.cc +++ b/tensorflow/compiler/xla/pjrt/event_pool.cc @@ -31,8 +31,8 @@ EventPool::Handle::~Handle() { EventPool::EventPool(bool allow_reuse) : allow_reuse_(allow_reuse), next_sequence_number_(0) {} -StatusOr EventPool::ThenAllocateAndRecordEvent( - se::Stream* stream) { +StatusOr EventPool::AllocateEvent( + se::StreamExecutor* executor) { Handle event; if (allow_reuse_) { @@ -44,15 +44,24 @@ StatusOr EventPool::ThenAllocateAndRecordEvent( } } if (!event.event_) { - event.event_ = absl::make_unique(stream->parent()); + event.event_ = absl::make_unique(executor); TF_RET_CHECK(event.event_->Init()) << "Event initialization failed"; } - { - absl::MutexLock lock(&mu_); - stream->ThenRecordEvent(event.event_.get()); - event.sequence_number_ = next_sequence_number_++; - } return event; } +void EventPool::ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle) { + absl::MutexLock lock(&mu_); + stream->ThenRecordEvent(handle.event_.get()); + handle.sequence_number_ = next_sequence_number_++; +} + +StatusOr EventPool::ThenAllocateAndRecordEvent( + se::Stream* stream) { + TF_ASSIGN_OR_RETURN(EventPool::Handle handle, + AllocateEvent(stream->parent())); + ThenRecordEvent(stream, handle); + return handle; +} + } // namespace xla diff --git 
a/tensorflow/compiler/xla/pjrt/event_pool.h b/tensorflow/compiler/xla/pjrt/event_pool.h index 47768c28fd9..c624763e65d 100644 --- a/tensorflow/compiler/xla/pjrt/event_pool.h +++ b/tensorflow/compiler/xla/pjrt/event_pool.h @@ -77,6 +77,11 @@ class EventPool { // cudaEventRecord. StatusOr ThenAllocateAndRecordEvent(se::Stream* stream); + // Version of ThenAllocateAndRecordEvent split into two phases; this is + // sometimes helpful if we want to avoid failures by preallocating events. + StatusOr AllocateEvent(se::StreamExecutor* executor); + void ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle); + private: const bool allow_reuse_; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index b54c93ba214..fd1715e53e6 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/synchronization/notification.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" @@ -284,28 +285,31 @@ class PjRtBuffer { virtual StatusOr> AcquireExternalReference() = 0; - // Returns the buffer's value as an XLA Literal. If the value has previously - // been prefetched to the host, then returns the prefetched version, otherwise - // copies the buffer to the host. Blocks until the value is ready. If - // `discard_cached_copy` is true then buffer will no longer keep hold of a - // cached copy of the literal (i.e. The reference to the host value will be - // removed.) If a layout is passed than a literal with this layout will be - // returned. 
- StatusOr> ToLiteral() { - return ToLiteral(/*discard_cached_copy=*/false, /*layout=*/{}); - } - StatusOr> ToLiteral(bool discard_cached_copy) { - return ToLiteral(discard_cached_copy, /*layout=*/{}); - } - virtual StatusOr> ToLiteral( - bool discard_cached_copy, absl::optional layout) = 0; + // Copies the buffer's value into `literal`. Calls `on_ready` when the value + // (or an error) is ready. The transfer respects the layout of `literal`; to + // specify a particular layout, set the layout before calling `ToLiteral`. + virtual void ToLiteral(MutableLiteralBase* literal, + std::function on_ready) = 0; - // Initiates a copy of the buffer to the host. Does not block waiting for - // the transfer to complete. The value can be retrieved by a later call to - // ToLiteral(). If a layout is passed then a cached copy with this layout will - // be created. - Status CopyToHostAsync() { return CopyToHostAsync(/*layout=*/{}); } - virtual Status CopyToHostAsync(absl::optional layout) = 0; + // Synchronous overload of ToLiteral, as a convenience. + Status ToLiteral(MutableLiteralBase* literal) { + absl::Notification done; + Status status; + ToLiteral(literal, [&](Status s) { + status = std::move(s); + done.Notify(); + }); + done.WaitForNotification(); + return status; + } + + // Convenience synchronous overload that allocates a literal with a default + // layout. + StatusOr> ToLiteral() { + auto literal = std::make_shared(on_host_shape()); + TF_RETURN_IF_ERROR(ToLiteral(literal.get())); + return literal; + } // Drops the buffer's reference to its associated device memory, leaving the // buffer in an invalid state. 
The memory will be freed lazily when all async diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index b766c581f7a..1852af3ab50 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -986,10 +986,9 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { if (device_buffer_ == nullptr) { return std::shared_ptr(); } - // Clear host_values_ and set device_buffer_ to null now so that no other + // Set device_buffer_ to null now so that no other // thread can add a hold while we are in WaitForOutstandingUsageHolds() // below. - host_values_.clear(); std::swap(device_buffer_, device_buffer); WaitForOutstandingUsageHolds(); // Now that all holds have completed and no more can be added, we can get @@ -1126,7 +1125,6 @@ void PjRtStreamExecutorBuffer::ConfirmDonation( device_buffer->ReleaseDeviceMemory(); // Make *this invalid so it can't be used again. Any threads blocking in // Release or GetBufferWithHold will see an invalid buffer and return. - host_values_.clear(); device_buffer_.reset(); } // Unblock another thread, if any, trying to get a donation hold. 
@@ -1147,84 +1145,47 @@ void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type, } } -Status PjRtStreamExecutorBuffer::CopyToHostAsync( - absl::optional layout) { - return CopyToHostAsyncInternal(/*discard_cached_copy=*/false, layout) - .status(); -} - -StatusOr> -PjRtStreamExecutorBuffer::CopyToHostAsyncInternal( - bool discard_cached_copy, absl::optional layout) { +void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal, + std::function on_ready) { if (IsEmptyTuple()) { - return InvalidArgument("CopyToHostAsync called on empty tuple"); + on_ready(InvalidArgument("ToLiteral called on empty tuple")); + return; } - ScopedHold device_buffer(this, ScopedHold::kUsage); - std::shared_ptr host_value; LocalDeviceState* local_device = tensorflow::down_cast(device_) ->local_device_state(); se::Stream* stream = local_device->GetDeviceToHostStream(); - const xla::Layout& host_layout = - layout.has_value() ? layout.value() : on_host_shape_.layout(); + ScopedHold device_buffer(this, ScopedHold::kUsage); { absl::MutexLock lock(&mu_); // We can't perform any other action while a donation hold is in progress. 
WaitForOutstandingDonationHold(); if (device_buffer_ == nullptr) { - return InvalidArgument( - "CopyToHostAsync() called on deleted or donated buffer"); - } - if (discard_cached_copy) { - auto it = host_values_.find(host_layout); - if (it != host_values_.end()) { - host_value = it->second; - host_values_.erase(it); - return host_value; - } else { - host_value = std::make_shared(); - } - } else { - std::shared_ptr& host_value_ref = host_values_[host_layout]; - if (host_value_ref) { - return host_value_ref; - } - host_value = host_value_ref = std::make_shared(); + on_ready(InvalidArgument( + "ToLiteral() called on deleted or donated buffer")); + return; } AcquireHoldLocked(&device_buffer); } + WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); - Shape host_shape; - if (layout.has_value()) { - host_shape = ShapeUtil::MakeShape(on_host_shape_.element_type(), - on_host_shape_.dimensions()); - *host_shape.mutable_layout() = host_layout; - } else { - host_shape = on_host_shape_; - } - host_value->value = std::make_shared(host_shape); ShapedBuffer shaped_buffer = - device_buffer->AsShapedBuffer(host_shape, on_device_shape_); + device_buffer->AsShapedBuffer(literal->shape(), on_device_shape_); + StatusOr event_or = + local_device->event_pool().AllocateEvent(stream->parent()); + if (!event_or.ok()) { + on_ready(event_or.status()); + return; + } tensorflow::down_cast(client_) ->client() ->backend() .transfer_manager() - ->TransferLiteralFromDevice(stream, shaped_buffer, - host_value->value.get(), - [host_value](Status done_status) { - host_value->status = done_status; - host_value->ready.Notify(); - }); + ->TransferLiteralFromDevice(stream, shaped_buffer, literal, + std::move(on_ready)); auto usage_event = std::make_shared(); - StatusOr event_or = - local_device->event_pool().ThenAllocateAndRecordEvent(stream); - if (!event_or.ok()) { - // Allocating the event failed, so synchronize - // the host on the copy and then drop the device buffer hold. 
- StallStreamOnError(local_device, stream); - return event_or.status(); - } + local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie()); usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream); // When using the ComputeSynchronized allocation model, retain a reference to // the device_buffer until the copy completes, to ensure that the buffer isn't @@ -1238,20 +1199,6 @@ PjRtStreamExecutorBuffer::CopyToHostAsyncInternal( RecordUsage(std::move(device_buffer), local_device, local_device, usage_event, stream, /*prefer_to_retain_reference=*/true); - return host_value; -} - -StatusOr> PjRtStreamExecutorBuffer::ToLiteral( - const bool discard_cached_copy, absl::optional layout) { - tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral"); - TF_ASSIGN_OR_RETURN(std::shared_ptr host_value, - CopyToHostAsyncInternal(discard_cached_copy, layout)); - if (host_value == nullptr) { - return InvalidArgument("ToLiteral called on deleted or donated buffer"); - } - host_value->ready.WaitForNotification(); - TF_RETURN_IF_ERROR(host_value->status); - return host_value->value; } StatusOr PjRtStreamExecutorBuffer::AsShapedBuffer() const { diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h index 2f55a71a564..9a1d64a6c96 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h @@ -497,11 +497,8 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { bool wait_for_operations_to_complete) override; using PjRtBuffer::ToLiteral; - StatusOr> ToLiteral( - bool discard_cached_copy, absl::optional layout) override; - - using PjRtBuffer::CopyToHostAsync; - Status CopyToHostAsync(absl::optional layout) override; + void ToLiteral(MutableLiteralBase* literal, + std::function on_ready) override; // Drops the buffer's reference to its associated device memory, leaving the // buffer in an 
invalid state. The memory will be freed lazily when all async @@ -558,16 +555,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { private: friend class PjRtClient; - // The cached value of the buffer on the host, produced either from a call to - // CopyToHost or from a call to ToLiteral. Once a value has been fetched to - // the host, it persists Delete() is called or the PjRtBuffer is destroyed. - struct HostValue { - absl::Notification ready; - // status and value are valid for reading only after `ready` has been - // notified. - Status status; - std::shared_ptr value; - }; // Blocks in mu_.Await until there are no more usage holds. void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -598,14 +585,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { // successfully donated to an execution. void ConfirmDonation(TrackedDeviceBuffer* device_buffer); - // Initiates a copy of the buffer to the host. Does not block waiting for - // the transfer to complete. A host value is returned and if - // `discard_cached_copy` is false stored in an internal buffer so that future - // transfers don't have to transfer the data from host again. If a layout is - // passed then a literal of this layout will be returned and possibly cached. - StatusOr> CopyToHostAsyncInternal( - bool discard_cached_copy, absl::optional layout); - // Drops a hold without taking any other action. Does a sanity check that // buffer==device_buffer_ or device_buffer_==nullptr. void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer); @@ -624,9 +603,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { mutable absl::Mutex mu_; std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); - absl::flat_hash_map> host_values_ - TF_GUARDED_BY(mu_); - std::shared_ptr host_value_ TF_GUARDED_BY(mu_); // Count of holds on the buffer. std::array holds_ TF_GUARDED_BY(mu_); // Semaphore used to ensure there is only one outstanding donation hold. 
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 1414ea5aa1f..3111e4607c5 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -200,6 +200,7 @@ cc_library( ":types", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core/platform:fingerprint", "//tensorflow/core/profiler:protos_all_cc", diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index 615f9708194..127521d523e 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -16,9 +16,12 @@ limitations under the License. #include "tensorflow/compiler/xla/python/py_buffer.h" #include "absl/base/casts.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -85,6 +88,64 @@ Status PyBuffer::BlockHostUntilReady() { return buffer_->BlockHostUntilReady(); } +Status PyBuffer::CopyToHostAsync() { + if (!buffer_->IsOnCpu() && !host_value_) { + host_value_ = std::make_shared(); + host_value_->value = std::make_shared(buffer_->on_host_shape()); + buffer_->ToLiteral(host_value_->value.get(), + [host_value{host_value_}](Status status) { + host_value->status = std::move(status); + host_value->ready.Notify(); + }); + } + return Status::OK(); +} + +StatusOr PyBuffer::AsNumPyArray(py::handle this_obj) { + if (buffer_->IsDeleted()) { + return InvalidArgument("DeviceArray has been deleted."); + } + TF_RET_CHECK(buffer_->on_device_shape().IsArray()); + // On CPU, we can return the value in a zero-copy way. 
+ if (buffer_->IsOnCpu()) { + TF_ASSIGN_OR_RETURN( + py::dtype dtype, + PrimitiveTypeToDtype(buffer_->on_host_shape().element_type())); + // Objects that must be kept alive while the array is alive. + struct Hold { + py::object buffer; + std::unique_ptr + external_reference_hold; + }; + auto hold = std::make_unique(); + TF_ASSIGN_OR_RETURN(hold->external_reference_hold, + buffer_->AcquireExternalReference()); + hold->buffer = py::reinterpret_borrow(this_obj); + void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer(); + py::capsule hold_capsule(hold.release(), + [](void* h) { delete static_cast(h); }); + py::array array(dtype, buffer_->on_host_shape().dimensions(), + ByteStridesForShape(buffer_->on_host_shape()), data, + hold_capsule); + array.attr("flags").attr("writeable") = Py_False; + { + py::gil_scoped_release gil; + TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady()); + } + return array; + } + + TF_RETURN_IF_ERROR(CopyToHostAsync()); + if (!host_value_->ready.HasBeenNotified()) { + py::gil_scoped_release gil; + host_value_->ready.WaitForNotification(); + } + TF_RETURN_IF_ERROR(host_value_->status); + TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value)); + array.attr("flags").attr("writeable") = Py_False; + return array; +} + // TODO(zhangqiaorjc): Delete UnsafeBufferPointer. StatusOr PyBuffer::UnsafeBufferPointer() const { if (buffer_->on_device_shape().IsTuple()) { diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h index d2b584cf147..374875702ba 100644 --- a/tensorflow/compiler/xla/python/py_buffer.h +++ b/tensorflow/compiler/xla/python/py_buffer.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/synchronization/notification.h" #include "absl/types/optional.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" @@ -70,12 +71,12 @@ class PyBuffer : public DeviceArrayBase { void Delete() { buffer_->Delete(); - npy_value_ = pybind11::none(); + host_value_ = nullptr; } // Returns xla::InvalidArgument if the buffer has been deleted. Status BlockHostUntilReady(); - Status CopyToHostAsync() { return buffer_->CopyToHostAsync(); } + Status CopyToHostAsync(); const Shape& shape() { return buffer_->on_host_shape(); } @@ -102,8 +103,7 @@ class PyBuffer : public DeviceArrayBase { void SetStickyDevice(pybind11::object sticky_device); pybind11::object GetStickyDevice() const { return sticky_device_.value(); } - void SetNpyValue(pybind11::object npy_value) { npy_value_ = npy_value; } - pybind11::object GetNpyValue() const { return npy_value_; } + StatusOr AsNumPyArray(pybind11::handle this_obj); void SetAval(pybind11::object aval); pybind11::object GetAval() const { return aval_.value(); } @@ -111,11 +111,16 @@ class PyBuffer : public DeviceArrayBase { private: friend class PyClient; + struct HostValue { + absl::Notification ready; + Status status; + std::shared_ptr value; + }; std::shared_ptr client_; std::unique_ptr buffer_; std::shared_ptr traceback_; - // The host numpy array caching the value when it has been copied to the host. - pybind11::object npy_value_ = pybind11::none(); + std::shared_ptr host_value_; // Protected by the GIL. + absl::optional sticky_device_ = absl::nullopt; // TODO(jblespiau): It's currently there for convenience but maybe we can do // without it (adding `weak_type` instead). 
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index a6e11308540..de343bc1ef3 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -68,25 +68,6 @@ bool IsOptimizedBuild() { #endif // NDEBUG } -StatusOr BufferToPython(PyBuffer* buffer, py::handle& buffer_obj) { - GlobalPyRefManager()->CollectGarbage(); - if (buffer->buffer()->IsOnCpu() && - buffer->buffer()->on_device_shape().IsArray() && - buffer->buffer()->on_device_shape().element_type() != BF16) { - py::object out = - py::reinterpret_steal(PyArray_FROM_O(buffer_obj.ptr())); - CHECK(out.ptr() != nullptr) << buffer->buffer()->on_host_shape().ToString( - /*print_layout=*/true); - return out; - } - std::shared_ptr literal; - { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(literal, buffer->buffer()->ToLiteral()); - } - return LiteralToPython(std::move(literal)); -} - } // namespace PYBIND11_MODULE(xla_extension, m) { @@ -328,20 +309,10 @@ PYBIND11_MODULE(xla_extension, m) { .def_property_readonly("ndim", &PyBuffer::ndim) .def_property_readonly( "_value", - [](py::handle buffer_obj) -> pybind11::object { + [](py::handle buffer_obj) -> StatusOr { + GlobalPyRefManager()->CollectGarbage(); PyBuffer* buffer = buffer_obj.cast(); - if (buffer->is_deleted()) { - throw std::runtime_error("DeviceArray has been deleted."); - } - py::object npy_value_ = buffer->GetNpyValue(); - if (npy_value_.is_none()) { - npy_value_ = BufferToPython(buffer, buffer_obj).ValueOrDie(); - // TODO(jblspiau): Change `LiteralToPython` to return a - // `py::array`, so we can set more easily the attribute. 
- npy_value_.attr("flags").attr("writeable") = Py_False; - buffer->SetNpyValue(npy_value_); - } - return npy_value_; + return buffer->AsNumPyArray(buffer_obj); }) .def("copy_to_device", &PyBuffer::CopyToDevice) .def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes) @@ -359,7 +330,7 @@ PYBIND11_MODULE(xla_extension, m) { .def("to_py", [](py::handle buffer_obj) { PyBuffer* buffer = buffer_obj.cast(); - return BufferToPython(buffer, buffer_obj); + return buffer->AsNumPyArray(buffer_obj); }) .def("xla_shape", &PyBuffer::shape) .def_property_readonly("client", &PyBuffer::client)