[PJRT] Move host literal cache out of PjRtBuffer::ToLiteral() and into the XLA Python bindings. Change ToLiteral() to have an asynchronous API that writes its output into a caller-provided buffer. Delete CopyToHostAsync() because it now serves no purpose.
Caching host transfers is a policy decision that PJRT should not be making on behalf of clients. Instead, clients can cache the transfer results if they want. The original motivation for the cache was the Python bindings; this change moves the cache into the Python bindings. This simplifies the PJRT API. PiperOrigin-RevId: 352858903 Change-Id: If17c69268e5f5c8690baa2f2ec88109376fc9c19
This commit is contained in:
parent
9e5c6794af
commit
c7e983d2c4
@ -132,8 +132,6 @@ cc_library(
|
||||
hdrs = ["pjrt_client.h"],
|
||||
visibility = ["//tensorflow/compiler/xla:friends"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -141,13 +139,11 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
|
@ -31,8 +31,8 @@ EventPool::Handle::~Handle() {
|
||||
EventPool::EventPool(bool allow_reuse)
|
||||
: allow_reuse_(allow_reuse), next_sequence_number_(0) {}
|
||||
|
||||
StatusOr<EventPool::Handle> EventPool::ThenAllocateAndRecordEvent(
|
||||
se::Stream* stream) {
|
||||
StatusOr<EventPool::Handle> EventPool::AllocateEvent(
|
||||
se::StreamExecutor* executor) {
|
||||
Handle event;
|
||||
|
||||
if (allow_reuse_) {
|
||||
@ -44,15 +44,24 @@ StatusOr<EventPool::Handle> EventPool::ThenAllocateAndRecordEvent(
|
||||
}
|
||||
}
|
||||
if (!event.event_) {
|
||||
event.event_ = absl::make_unique<se::Event>(stream->parent());
|
||||
event.event_ = absl::make_unique<se::Event>(executor);
|
||||
TF_RET_CHECK(event.event_->Init()) << "Event initialization failed";
|
||||
}
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
stream->ThenRecordEvent(event.event_.get());
|
||||
event.sequence_number_ = next_sequence_number_++;
|
||||
}
|
||||
return event;
|
||||
}
|
||||
|
||||
void EventPool::ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
stream->ThenRecordEvent(handle.event_.get());
|
||||
handle.sequence_number_ = next_sequence_number_++;
|
||||
}
|
||||
|
||||
StatusOr<EventPool::Handle> EventPool::ThenAllocateAndRecordEvent(
|
||||
se::Stream* stream) {
|
||||
TF_ASSIGN_OR_RETURN(EventPool::Handle handle,
|
||||
AllocateEvent(stream->parent()));
|
||||
ThenRecordEvent(stream, handle);
|
||||
return handle;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -77,6 +77,11 @@ class EventPool {
|
||||
// cudaEventRecord.
|
||||
StatusOr<Handle> ThenAllocateAndRecordEvent(se::Stream* stream);
|
||||
|
||||
// Version of ThenAllocateAndRecordEvent split into two phases; this is
|
||||
// sometimes helpful if we want to avoid failures by preallocating events.
|
||||
StatusOr<Handle> AllocateEvent(se::StreamExecutor* executor);
|
||||
void ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle);
|
||||
|
||||
private:
|
||||
const bool allow_reuse_;
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
@ -284,28 +285,31 @@ class PjRtBuffer {
|
||||
virtual StatusOr<std::unique_ptr<ExternalReferenceHold>>
|
||||
AcquireExternalReference() = 0;
|
||||
|
||||
// Returns the buffer's value as an XLA Literal. If the value has previously
|
||||
// been prefetched to the host, then returns the prefetched version, otherwise
|
||||
// copies the buffer to the host. Blocks until the value is ready. If
|
||||
// `discard_cached_copy` is true then buffer will no longer keep hold of a
|
||||
// cached copy of the literal (i.e. The reference to the host value will be
|
||||
// removed.) If a layout is passed than a literal with this layout will be
|
||||
// returned.
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral() {
|
||||
return ToLiteral(/*discard_cached_copy=*/false, /*layout=*/{});
|
||||
}
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral(bool discard_cached_copy) {
|
||||
return ToLiteral(discard_cached_copy, /*layout=*/{});
|
||||
}
|
||||
virtual StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
||||
bool discard_cached_copy, absl::optional<xla::Layout> layout) = 0;
|
||||
// Copies the buffer's value into `literal`. Calls `on_ready` when the value
|
||||
// (or an error) is ready. The transfer respects the layout of `literal`; to
|
||||
// specify a particular layout, set the layout before calling `ToLiteral`.
|
||||
virtual void ToLiteral(MutableLiteralBase* literal,
|
||||
std::function<void(Status)> on_ready) = 0;
|
||||
|
||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||
// the transfer to complete. The value can be retrieved by a later call to
|
||||
// ToLiteral(). If a layout is passed then a cached copy with this layout will
|
||||
// be created.
|
||||
Status CopyToHostAsync() { return CopyToHostAsync(/*layout=*/{}); }
|
||||
virtual Status CopyToHostAsync(absl::optional<xla::Layout> layout) = 0;
|
||||
// Synchronous overload of ToLiteral, as a convenience.
|
||||
Status ToLiteral(MutableLiteralBase* literal) {
|
||||
absl::Notification done;
|
||||
Status status;
|
||||
ToLiteral(literal, [&](Status s) {
|
||||
status = std::move(s);
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
return status;
|
||||
}
|
||||
|
||||
// Convenience synchronous overload that allocates a literal with a default
|
||||
// layout.
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral() {
|
||||
auto literal = std::make_shared<Literal>(on_host_shape());
|
||||
TF_RETURN_IF_ERROR(ToLiteral(literal.get()));
|
||||
return literal;
|
||||
}
|
||||
|
||||
// Drops the buffer's reference to its associated device memory, leaving the
|
||||
// buffer in an invalid state. The memory will be freed lazily when all async
|
||||
|
@ -986,10 +986,9 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
|
||||
if (device_buffer_ == nullptr) {
|
||||
return std::shared_ptr<TrackedDeviceBuffer>();
|
||||
}
|
||||
// Clear host_values_ and set device_buffer_ to null now so that no other
|
||||
// Set device_buffer_ to null now so that no other
|
||||
// thread can add a hold while we are in WaitForOutstandingUsageHolds()
|
||||
// below.
|
||||
host_values_.clear();
|
||||
std::swap(device_buffer_, device_buffer);
|
||||
WaitForOutstandingUsageHolds();
|
||||
// Now that all holds have completed and no more can be added, we can get
|
||||
@ -1126,7 +1125,6 @@ void PjRtStreamExecutorBuffer::ConfirmDonation(
|
||||
device_buffer->ReleaseDeviceMemory();
|
||||
// Make *this invalid so it can't be used again. Any threads blocking in
|
||||
// Release or GetBufferWithHold will see an invalid buffer and return.
|
||||
host_values_.clear();
|
||||
device_buffer_.reset();
|
||||
}
|
||||
// Unblock another thread, if any, trying to get a donation hold.
|
||||
@ -1147,84 +1145,47 @@ void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
|
||||
}
|
||||
}
|
||||
|
||||
Status PjRtStreamExecutorBuffer::CopyToHostAsync(
|
||||
absl::optional<xla::Layout> layout) {
|
||||
return CopyToHostAsyncInternal(/*discard_cached_copy=*/false, layout)
|
||||
.status();
|
||||
}
|
||||
|
||||
StatusOr<std::shared_ptr<PjRtStreamExecutorBuffer::HostValue>>
|
||||
PjRtStreamExecutorBuffer::CopyToHostAsyncInternal(
|
||||
bool discard_cached_copy, absl::optional<xla::Layout> layout) {
|
||||
void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
|
||||
std::function<void(Status)> on_ready) {
|
||||
if (IsEmptyTuple()) {
|
||||
return InvalidArgument("CopyToHostAsync called on empty tuple");
|
||||
on_ready(InvalidArgument("ToLiteral called on empty tuple"));
|
||||
return;
|
||||
}
|
||||
ScopedHold device_buffer(this, ScopedHold::kUsage);
|
||||
std::shared_ptr<HostValue> host_value;
|
||||
LocalDeviceState* local_device =
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||
->local_device_state();
|
||||
se::Stream* stream = local_device->GetDeviceToHostStream();
|
||||
const xla::Layout& host_layout =
|
||||
layout.has_value() ? layout.value() : on_host_shape_.layout();
|
||||
ScopedHold device_buffer(this, ScopedHold::kUsage);
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
// We can't perform any other action while a donation hold is in progress.
|
||||
WaitForOutstandingDonationHold();
|
||||
if (device_buffer_ == nullptr) {
|
||||
return InvalidArgument(
|
||||
"CopyToHostAsync() called on deleted or donated buffer");
|
||||
}
|
||||
if (discard_cached_copy) {
|
||||
auto it = host_values_.find(host_layout);
|
||||
if (it != host_values_.end()) {
|
||||
host_value = it->second;
|
||||
host_values_.erase(it);
|
||||
return host_value;
|
||||
} else {
|
||||
host_value = std::make_shared<HostValue>();
|
||||
}
|
||||
} else {
|
||||
std::shared_ptr<HostValue>& host_value_ref = host_values_[host_layout];
|
||||
if (host_value_ref) {
|
||||
return host_value_ref;
|
||||
}
|
||||
host_value = host_value_ref = std::make_shared<HostValue>();
|
||||
on_ready(InvalidArgument(
|
||||
"CopyToHostAsync() called on deleted or donated buffer"));
|
||||
return;
|
||||
}
|
||||
AcquireHoldLocked(&device_buffer);
|
||||
}
|
||||
|
||||
WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
|
||||
Shape host_shape;
|
||||
if (layout.has_value()) {
|
||||
host_shape = ShapeUtil::MakeShape(on_host_shape_.element_type(),
|
||||
on_host_shape_.dimensions());
|
||||
*host_shape.mutable_layout() = host_layout;
|
||||
} else {
|
||||
host_shape = on_host_shape_;
|
||||
}
|
||||
host_value->value = std::make_shared<Literal>(host_shape);
|
||||
ShapedBuffer shaped_buffer =
|
||||
device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
|
||||
device_buffer->AsShapedBuffer(literal->shape(), on_device_shape_);
|
||||
StatusOr<EventPool::Handle> event_or =
|
||||
local_device->event_pool().AllocateEvent(stream->parent());
|
||||
if (!event_or.ok()) {
|
||||
on_ready(event_or.status());
|
||||
return;
|
||||
}
|
||||
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
|
||||
->client()
|
||||
->backend()
|
||||
.transfer_manager()
|
||||
->TransferLiteralFromDevice(stream, shaped_buffer,
|
||||
host_value->value.get(),
|
||||
[host_value](Status done_status) {
|
||||
host_value->status = done_status;
|
||||
host_value->ready.Notify();
|
||||
});
|
||||
->TransferLiteralFromDevice(stream, shaped_buffer, literal,
|
||||
std::move(on_ready));
|
||||
|
||||
auto usage_event = std::make_shared<BufferSequencingEvent>();
|
||||
StatusOr<EventPool::Handle> event_or =
|
||||
local_device->event_pool().ThenAllocateAndRecordEvent(stream);
|
||||
if (!event_or.ok()) {
|
||||
// Allocating the event failed, so synchronize
|
||||
// the host on the copy and then drop the device buffer hold.
|
||||
StallStreamOnError(local_device, stream);
|
||||
return event_or.status();
|
||||
}
|
||||
local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
|
||||
usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
|
||||
// When using the ComputeSynchronized allocation model, retain a reference to
|
||||
// the device_buffer until the copy completes, to ensure that the buffer isn't
|
||||
@ -1238,20 +1199,6 @@ PjRtStreamExecutorBuffer::CopyToHostAsyncInternal(
|
||||
RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
|
||||
stream,
|
||||
/*prefer_to_retain_reference=*/true);
|
||||
return host_value;
|
||||
}
|
||||
|
||||
StatusOr<std::shared_ptr<Literal>> PjRtStreamExecutorBuffer::ToLiteral(
|
||||
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
|
||||
CopyToHostAsyncInternal(discard_cached_copy, layout));
|
||||
if (host_value == nullptr) {
|
||||
return InvalidArgument("ToLiteral called on deleted or donated buffer");
|
||||
}
|
||||
host_value->ready.WaitForNotification();
|
||||
TF_RETURN_IF_ERROR(host_value->status);
|
||||
return host_value->value;
|
||||
}
|
||||
|
||||
StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
|
||||
|
@ -497,11 +497,8 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
||||
bool wait_for_operations_to_complete) override;
|
||||
|
||||
using PjRtBuffer::ToLiteral;
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
||||
bool discard_cached_copy, absl::optional<xla::Layout> layout) override;
|
||||
|
||||
using PjRtBuffer::CopyToHostAsync;
|
||||
Status CopyToHostAsync(absl::optional<xla::Layout> layout) override;
|
||||
void ToLiteral(MutableLiteralBase* literal,
|
||||
std::function<void(Status)> on_ready) override;
|
||||
|
||||
// Drops the buffer's reference to its associated device memory, leaving the
|
||||
// buffer in an invalid state. The memory will be freed lazily when all async
|
||||
@ -558,16 +555,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
||||
|
||||
private:
|
||||
friend class PjRtClient;
|
||||
// The cached value of the buffer on the host, produced either from a call to
|
||||
// CopyToHost or from a call to ToLiteral. Once a value has been fetched to
|
||||
// the host, it persists Delete() is called or the PjRtBuffer is destroyed.
|
||||
struct HostValue {
|
||||
absl::Notification ready;
|
||||
// status and value are valid for reading only after `ready` has been
|
||||
// notified.
|
||||
Status status;
|
||||
std::shared_ptr<Literal> value;
|
||||
};
|
||||
|
||||
// Blocks in mu_.Await until there are no more usage holds.
|
||||
void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
@ -598,14 +585,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
||||
// successfully donated to an execution.
|
||||
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
|
||||
|
||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||
// the transfer to complete. A host value is returned and if
|
||||
// `discard_cached_copy` is false stored in an internal buffer so that future
|
||||
// transfers don't have to transfer the data from host again. If a layout is
|
||||
// passed then a literal of this layout will be returned and possibly cached.
|
||||
StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
|
||||
bool discard_cached_copy, absl::optional<xla::Layout> layout);
|
||||
|
||||
// Drops a hold without taking any other action. Does a sanity check that
|
||||
// buffer==device_buffer_ or device_buffer_==nullptr.
|
||||
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
|
||||
@ -624,9 +603,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
||||
|
||||
mutable absl::Mutex mu_;
|
||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
|
||||
absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
|
||||
TF_GUARDED_BY(mu_);
|
||||
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
|
||||
// Count of holds on the buffer.
|
||||
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
|
||||
// Semaphore used to ensure there is only one outstanding donation hold.
|
||||
|
@ -200,6 +200,7 @@ cc_library(
|
||||
":types",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/core/platform:fingerprint",
|
||||
"//tensorflow/core/profiler:protos_all_cc",
|
||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -85,6 +88,64 @@ Status PyBuffer::BlockHostUntilReady() {
|
||||
return buffer_->BlockHostUntilReady();
|
||||
}
|
||||
|
||||
Status PyBuffer::CopyToHostAsync() {
|
||||
if (!buffer_->IsOnCpu() && !host_value_) {
|
||||
host_value_ = std::make_shared<HostValue>();
|
||||
host_value_->value = std::make_shared<Literal>(buffer_->on_host_shape());
|
||||
buffer_->ToLiteral(host_value_->value.get(),
|
||||
[host_value{host_value_}](Status status) {
|
||||
host_value->status = std::move(status);
|
||||
host_value->ready.Notify();
|
||||
});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
|
||||
if (buffer_->IsDeleted()) {
|
||||
return InvalidArgument("DeviceArray has been deleted.");
|
||||
}
|
||||
TF_RET_CHECK(buffer_->on_device_shape().IsArray());
|
||||
// On CPU, we can return the value in a zero-copy way.
|
||||
if (buffer_->IsOnCpu()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
py::dtype dtype,
|
||||
PrimitiveTypeToDtype(buffer_->on_host_shape().element_type()));
|
||||
// Objects that must be kept alive while the array is alive.
|
||||
struct Hold {
|
||||
py::object buffer;
|
||||
std::unique_ptr<PjRtBuffer::ExternalReferenceHold>
|
||||
external_reference_hold;
|
||||
};
|
||||
auto hold = std::make_unique<Hold>();
|
||||
TF_ASSIGN_OR_RETURN(hold->external_reference_hold,
|
||||
buffer_->AcquireExternalReference());
|
||||
hold->buffer = py::reinterpret_borrow<py::object>(this_obj);
|
||||
void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
|
||||
py::capsule hold_capsule(hold.release(),
|
||||
[](void* h) { delete static_cast<Hold*>(h); });
|
||||
py::array array(dtype, buffer_->on_host_shape().dimensions(),
|
||||
ByteStridesForShape(buffer_->on_host_shape()), data,
|
||||
hold_capsule);
|
||||
array.attr("flags").attr("writeable") = Py_False;
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady());
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(CopyToHostAsync());
|
||||
if (!host_value_->ready.HasBeenNotified()) {
|
||||
py::gil_scoped_release gil;
|
||||
host_value_->ready.WaitForNotification();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(host_value_->status);
|
||||
TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value));
|
||||
array.attr("flags").attr("writeable") = Py_False;
|
||||
return array;
|
||||
}
|
||||
|
||||
// TODO(zhangqiaorjc): Delete UnsafeBufferPointer.
|
||||
StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
|
||||
if (buffer_->on_device_shape().IsTuple()) {
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
@ -70,12 +71,12 @@ class PyBuffer : public DeviceArrayBase {
|
||||
|
||||
void Delete() {
|
||||
buffer_->Delete();
|
||||
npy_value_ = pybind11::none();
|
||||
host_value_ = nullptr;
|
||||
}
|
||||
|
||||
// Returns xla::InvalidArgument if the buffer has been deleted.
|
||||
Status BlockHostUntilReady();
|
||||
Status CopyToHostAsync() { return buffer_->CopyToHostAsync(); }
|
||||
Status CopyToHostAsync();
|
||||
|
||||
const Shape& shape() { return buffer_->on_host_shape(); }
|
||||
|
||||
@ -102,8 +103,7 @@ class PyBuffer : public DeviceArrayBase {
|
||||
void SetStickyDevice(pybind11::object sticky_device);
|
||||
pybind11::object GetStickyDevice() const { return sticky_device_.value(); }
|
||||
|
||||
void SetNpyValue(pybind11::object npy_value) { npy_value_ = npy_value; }
|
||||
pybind11::object GetNpyValue() const { return npy_value_; }
|
||||
StatusOr<pybind11::object> AsNumPyArray(pybind11::handle this_obj);
|
||||
|
||||
void SetAval(pybind11::object aval);
|
||||
pybind11::object GetAval() const { return aval_.value(); }
|
||||
@ -111,11 +111,16 @@ class PyBuffer : public DeviceArrayBase {
|
||||
private:
|
||||
friend class PyClient;
|
||||
|
||||
struct HostValue {
|
||||
absl::Notification ready;
|
||||
Status status;
|
||||
std::shared_ptr<xla::Literal> value;
|
||||
};
|
||||
std::shared_ptr<PyClient> client_;
|
||||
std::unique_ptr<PjRtBuffer> buffer_;
|
||||
std::shared_ptr<Traceback> traceback_;
|
||||
// The host numpy array caching the value when it has been copied to the host.
|
||||
pybind11::object npy_value_ = pybind11::none();
|
||||
std::shared_ptr<HostValue> host_value_; // Protected by the GIL.
|
||||
|
||||
absl::optional<pybind11::object> sticky_device_ = absl::nullopt;
|
||||
// TODO(jblespiau): It's currently there for convenience but maybe we can do
|
||||
// without it (adding `weak_type` instead).
|
||||
|
@ -68,25 +68,6 @@ bool IsOptimizedBuild() {
|
||||
#endif // NDEBUG
|
||||
}
|
||||
|
||||
StatusOr<py::object> BufferToPython(PyBuffer* buffer, py::handle& buffer_obj) {
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
if (buffer->buffer()->IsOnCpu() &&
|
||||
buffer->buffer()->on_device_shape().IsArray() &&
|
||||
buffer->buffer()->on_device_shape().element_type() != BF16) {
|
||||
py::object out =
|
||||
py::reinterpret_steal<py::object>(PyArray_FROM_O(buffer_obj.ptr()));
|
||||
CHECK(out.ptr() != nullptr) << buffer->buffer()->on_host_shape().ToString(
|
||||
/*print_layout=*/true);
|
||||
return out;
|
||||
}
|
||||
std::shared_ptr<Literal> literal;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(literal, buffer->buffer()->ToLiteral());
|
||||
}
|
||||
return LiteralToPython(std::move(literal));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(xla_extension, m) {
|
||||
@ -328,20 +309,10 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def_property_readonly("ndim", &PyBuffer::ndim)
|
||||
.def_property_readonly(
|
||||
"_value",
|
||||
[](py::handle buffer_obj) -> pybind11::object {
|
||||
[](py::handle buffer_obj) -> StatusOr<pybind11::object> {
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
|
||||
if (buffer->is_deleted()) {
|
||||
throw std::runtime_error("DeviceArray has been deleted.");
|
||||
}
|
||||
py::object npy_value_ = buffer->GetNpyValue();
|
||||
if (npy_value_.is_none()) {
|
||||
npy_value_ = BufferToPython(buffer, buffer_obj).ValueOrDie();
|
||||
// TODO(jblspiau): Change `LiteralToPython` to return a
|
||||
// `py::array`, so we can set more easily the attribute.
|
||||
npy_value_.attr("flags").attr("writeable") = Py_False;
|
||||
buffer->SetNpyValue(npy_value_);
|
||||
}
|
||||
return npy_value_;
|
||||
return buffer->AsNumPyArray(buffer_obj);
|
||||
})
|
||||
.def("copy_to_device", &PyBuffer::CopyToDevice)
|
||||
.def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes)
|
||||
@ -359,7 +330,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def("to_py",
|
||||
[](py::handle buffer_obj) {
|
||||
PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
|
||||
return BufferToPython(buffer, buffer_obj);
|
||||
return buffer->AsNumPyArray(buffer_obj);
|
||||
})
|
||||
.def("xla_shape", &PyBuffer::shape)
|
||||
.def_property_readonly("client", &PyBuffer::client)
|
||||
|
Loading…
Reference in New Issue
Block a user