[PJRT] Move host literal cache out of PjRtBuffer::ToLiteral() and into the XLA Python bindings. Change ToLiteral() to have an asynchronous API that writes its output into a caller-provided buffer. Delete CopyToHostAsync() because it now serves no purpose.
Caching host transfers is a policy decision that PJRT should not be making on behalf of clients. Instead, clients can cache the transfer results if they want. The original motivation for the cache was the Python bindings; this change moves the cache into the Python bindings. This simplifies the PJRT API. PiperOrigin-RevId: 352858903 Change-Id: If17c69268e5f5c8690baa2f2ec88109376fc9c19
This commit is contained in:
parent
9e5c6794af
commit
c7e983d2c4
@ -132,8 +132,6 @@ cc_library(
|
|||||||
hdrs = ["pjrt_client.h"],
|
hdrs = ["pjrt_client.h"],
|
||||||
visibility = ["//tensorflow/compiler/xla:friends"],
|
visibility = ["//tensorflow/compiler/xla:friends"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla:executable_run_options",
|
|
||||||
"//tensorflow/compiler/xla:literal",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status",
|
"//tensorflow/compiler/xla:status",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
@ -141,13 +139,11 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||||
"//tensorflow/compiler/xla/client:xla_computation",
|
"//tensorflow/compiler/xla/client:xla_computation",
|
||||||
"//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/base",
|
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
|
@ -31,8 +31,8 @@ EventPool::Handle::~Handle() {
|
|||||||
EventPool::EventPool(bool allow_reuse)
|
EventPool::EventPool(bool allow_reuse)
|
||||||
: allow_reuse_(allow_reuse), next_sequence_number_(0) {}
|
: allow_reuse_(allow_reuse), next_sequence_number_(0) {}
|
||||||
|
|
||||||
StatusOr<EventPool::Handle> EventPool::ThenAllocateAndRecordEvent(
|
StatusOr<EventPool::Handle> EventPool::AllocateEvent(
|
||||||
se::Stream* stream) {
|
se::StreamExecutor* executor) {
|
||||||
Handle event;
|
Handle event;
|
||||||
|
|
||||||
if (allow_reuse_) {
|
if (allow_reuse_) {
|
||||||
@ -44,15 +44,24 @@ StatusOr<EventPool::Handle> EventPool::ThenAllocateAndRecordEvent(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!event.event_) {
|
if (!event.event_) {
|
||||||
event.event_ = absl::make_unique<se::Event>(stream->parent());
|
event.event_ = absl::make_unique<se::Event>(executor);
|
||||||
TF_RET_CHECK(event.event_->Init()) << "Event initialization failed";
|
TF_RET_CHECK(event.event_->Init()) << "Event initialization failed";
|
||||||
}
|
}
|
||||||
{
|
|
||||||
absl::MutexLock lock(&mu_);
|
|
||||||
stream->ThenRecordEvent(event.event_.get());
|
|
||||||
event.sequence_number_ = next_sequence_number_++;
|
|
||||||
}
|
|
||||||
return event;
|
return event;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EventPool::ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle) {
|
||||||
|
absl::MutexLock lock(&mu_);
|
||||||
|
stream->ThenRecordEvent(handle.event_.get());
|
||||||
|
handle.sequence_number_ = next_sequence_number_++;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<EventPool::Handle> EventPool::ThenAllocateAndRecordEvent(
|
||||||
|
se::Stream* stream) {
|
||||||
|
TF_ASSIGN_OR_RETURN(EventPool::Handle handle,
|
||||||
|
AllocateEvent(stream->parent()));
|
||||||
|
ThenRecordEvent(stream, handle);
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -77,6 +77,11 @@ class EventPool {
|
|||||||
// cudaEventRecord.
|
// cudaEventRecord.
|
||||||
StatusOr<Handle> ThenAllocateAndRecordEvent(se::Stream* stream);
|
StatusOr<Handle> ThenAllocateAndRecordEvent(se::Stream* stream);
|
||||||
|
|
||||||
|
// Version of ThenAllocateAndRecordEvent split into two phases; this is
|
||||||
|
// sometimes helpful if we want to avoid failures by preallocating events.
|
||||||
|
StatusOr<Handle> AllocateEvent(se::StreamExecutor* executor);
|
||||||
|
void ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const bool allow_reuse_;
|
const bool allow_reuse_;
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/synchronization/notification.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||||
@ -284,28 +285,31 @@ class PjRtBuffer {
|
|||||||
virtual StatusOr<std::unique_ptr<ExternalReferenceHold>>
|
virtual StatusOr<std::unique_ptr<ExternalReferenceHold>>
|
||||||
AcquireExternalReference() = 0;
|
AcquireExternalReference() = 0;
|
||||||
|
|
||||||
// Returns the buffer's value as an XLA Literal. If the value has previously
|
// Copies the buffer's value into `literal`. Calls `on_ready` when the value
|
||||||
// been prefetched to the host, then returns the prefetched version, otherwise
|
// (or an error) is ready. The transfer respects the layout of `literal`; to
|
||||||
// copies the buffer to the host. Blocks until the value is ready. If
|
// specify a particular layout, set the layout before calling `ToLiteral`.
|
||||||
// `discard_cached_copy` is true then buffer will no longer keep hold of a
|
virtual void ToLiteral(MutableLiteralBase* literal,
|
||||||
// cached copy of the literal (i.e. The reference to the host value will be
|
std::function<void(Status)> on_ready) = 0;
|
||||||
// removed.) If a layout is passed than a literal with this layout will be
|
|
||||||
// returned.
|
|
||||||
StatusOr<std::shared_ptr<Literal>> ToLiteral() {
|
|
||||||
return ToLiteral(/*discard_cached_copy=*/false, /*layout=*/{});
|
|
||||||
}
|
|
||||||
StatusOr<std::shared_ptr<Literal>> ToLiteral(bool discard_cached_copy) {
|
|
||||||
return ToLiteral(discard_cached_copy, /*layout=*/{});
|
|
||||||
}
|
|
||||||
virtual StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
|
||||||
bool discard_cached_copy, absl::optional<xla::Layout> layout) = 0;
|
|
||||||
|
|
||||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
// Synchronous overload of ToLiteral, as a convenience.
|
||||||
// the transfer to complete. The value can be retrieved by a later call to
|
Status ToLiteral(MutableLiteralBase* literal) {
|
||||||
// ToLiteral(). If a layout is passed then a cached copy with this layout will
|
absl::Notification done;
|
||||||
// be created.
|
Status status;
|
||||||
Status CopyToHostAsync() { return CopyToHostAsync(/*layout=*/{}); }
|
ToLiteral(literal, [&](Status s) {
|
||||||
virtual Status CopyToHostAsync(absl::optional<xla::Layout> layout) = 0;
|
status = std::move(s);
|
||||||
|
done.Notify();
|
||||||
|
});
|
||||||
|
done.WaitForNotification();
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience synchronous overload that allocates a literal with a default
|
||||||
|
// layout.
|
||||||
|
StatusOr<std::shared_ptr<Literal>> ToLiteral() {
|
||||||
|
auto literal = std::make_shared<Literal>(on_host_shape());
|
||||||
|
TF_RETURN_IF_ERROR(ToLiteral(literal.get()));
|
||||||
|
return literal;
|
||||||
|
}
|
||||||
|
|
||||||
// Drops the buffer's reference to its associated device memory, leaving the
|
// Drops the buffer's reference to its associated device memory, leaving the
|
||||||
// buffer in an invalid state. The memory will be freed lazily when all async
|
// buffer in an invalid state. The memory will be freed lazily when all async
|
||||||
|
@ -986,10 +986,9 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
|
|||||||
if (device_buffer_ == nullptr) {
|
if (device_buffer_ == nullptr) {
|
||||||
return std::shared_ptr<TrackedDeviceBuffer>();
|
return std::shared_ptr<TrackedDeviceBuffer>();
|
||||||
}
|
}
|
||||||
// Clear host_values_ and set device_buffer_ to null now so that no other
|
// Set device_buffer_ to null now so that no other
|
||||||
// thread can add a hold while we are in WaitForOutstandingUsageHolds()
|
// thread can add a hold while we are in WaitForOutstandingUsageHolds()
|
||||||
// below.
|
// below.
|
||||||
host_values_.clear();
|
|
||||||
std::swap(device_buffer_, device_buffer);
|
std::swap(device_buffer_, device_buffer);
|
||||||
WaitForOutstandingUsageHolds();
|
WaitForOutstandingUsageHolds();
|
||||||
// Now that all holds have completed and no more can be added, we can get
|
// Now that all holds have completed and no more can be added, we can get
|
||||||
@ -1126,7 +1125,6 @@ void PjRtStreamExecutorBuffer::ConfirmDonation(
|
|||||||
device_buffer->ReleaseDeviceMemory();
|
device_buffer->ReleaseDeviceMemory();
|
||||||
// Make *this invalid so it can't be used again. Any threads blocking in
|
// Make *this invalid so it can't be used again. Any threads blocking in
|
||||||
// Release or GetBufferWithHold will see an invalid buffer and return.
|
// Release or GetBufferWithHold will see an invalid buffer and return.
|
||||||
host_values_.clear();
|
|
||||||
device_buffer_.reset();
|
device_buffer_.reset();
|
||||||
}
|
}
|
||||||
// Unblock another thread, if any, trying to get a donation hold.
|
// Unblock another thread, if any, trying to get a donation hold.
|
||||||
@ -1147,84 +1145,47 @@ void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status PjRtStreamExecutorBuffer::CopyToHostAsync(
|
void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
|
||||||
absl::optional<xla::Layout> layout) {
|
std::function<void(Status)> on_ready) {
|
||||||
return CopyToHostAsyncInternal(/*discard_cached_copy=*/false, layout)
|
|
||||||
.status();
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<std::shared_ptr<PjRtStreamExecutorBuffer::HostValue>>
|
|
||||||
PjRtStreamExecutorBuffer::CopyToHostAsyncInternal(
|
|
||||||
bool discard_cached_copy, absl::optional<xla::Layout> layout) {
|
|
||||||
if (IsEmptyTuple()) {
|
if (IsEmptyTuple()) {
|
||||||
return InvalidArgument("CopyToHostAsync called on empty tuple");
|
on_ready(InvalidArgument("ToLiteral called on empty tuple"));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
ScopedHold device_buffer(this, ScopedHold::kUsage);
|
|
||||||
std::shared_ptr<HostValue> host_value;
|
|
||||||
LocalDeviceState* local_device =
|
LocalDeviceState* local_device =
|
||||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||||
->local_device_state();
|
->local_device_state();
|
||||||
se::Stream* stream = local_device->GetDeviceToHostStream();
|
se::Stream* stream = local_device->GetDeviceToHostStream();
|
||||||
const xla::Layout& host_layout =
|
ScopedHold device_buffer(this, ScopedHold::kUsage);
|
||||||
layout.has_value() ? layout.value() : on_host_shape_.layout();
|
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
// We can't perform any other action while a donation hold is in progress.
|
// We can't perform any other action while a donation hold is in progress.
|
||||||
WaitForOutstandingDonationHold();
|
WaitForOutstandingDonationHold();
|
||||||
if (device_buffer_ == nullptr) {
|
if (device_buffer_ == nullptr) {
|
||||||
return InvalidArgument(
|
on_ready(InvalidArgument(
|
||||||
"CopyToHostAsync() called on deleted or donated buffer");
|
"CopyToHostAsync() called on deleted or donated buffer"));
|
||||||
}
|
return;
|
||||||
if (discard_cached_copy) {
|
|
||||||
auto it = host_values_.find(host_layout);
|
|
||||||
if (it != host_values_.end()) {
|
|
||||||
host_value = it->second;
|
|
||||||
host_values_.erase(it);
|
|
||||||
return host_value;
|
|
||||||
} else {
|
|
||||||
host_value = std::make_shared<HostValue>();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
std::shared_ptr<HostValue>& host_value_ref = host_values_[host_layout];
|
|
||||||
if (host_value_ref) {
|
|
||||||
return host_value_ref;
|
|
||||||
}
|
|
||||||
host_value = host_value_ref = std::make_shared<HostValue>();
|
|
||||||
}
|
}
|
||||||
AcquireHoldLocked(&device_buffer);
|
AcquireHoldLocked(&device_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
|
WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
|
||||||
Shape host_shape;
|
|
||||||
if (layout.has_value()) {
|
|
||||||
host_shape = ShapeUtil::MakeShape(on_host_shape_.element_type(),
|
|
||||||
on_host_shape_.dimensions());
|
|
||||||
*host_shape.mutable_layout() = host_layout;
|
|
||||||
} else {
|
|
||||||
host_shape = on_host_shape_;
|
|
||||||
}
|
|
||||||
host_value->value = std::make_shared<Literal>(host_shape);
|
|
||||||
ShapedBuffer shaped_buffer =
|
ShapedBuffer shaped_buffer =
|
||||||
device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
|
device_buffer->AsShapedBuffer(literal->shape(), on_device_shape_);
|
||||||
|
StatusOr<EventPool::Handle> event_or =
|
||||||
|
local_device->event_pool().AllocateEvent(stream->parent());
|
||||||
|
if (!event_or.ok()) {
|
||||||
|
on_ready(event_or.status());
|
||||||
|
return;
|
||||||
|
}
|
||||||
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
|
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
|
||||||
->client()
|
->client()
|
||||||
->backend()
|
->backend()
|
||||||
.transfer_manager()
|
.transfer_manager()
|
||||||
->TransferLiteralFromDevice(stream, shaped_buffer,
|
->TransferLiteralFromDevice(stream, shaped_buffer, literal,
|
||||||
host_value->value.get(),
|
std::move(on_ready));
|
||||||
[host_value](Status done_status) {
|
|
||||||
host_value->status = done_status;
|
|
||||||
host_value->ready.Notify();
|
|
||||||
});
|
|
||||||
|
|
||||||
auto usage_event = std::make_shared<BufferSequencingEvent>();
|
auto usage_event = std::make_shared<BufferSequencingEvent>();
|
||||||
StatusOr<EventPool::Handle> event_or =
|
local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
|
||||||
local_device->event_pool().ThenAllocateAndRecordEvent(stream);
|
|
||||||
if (!event_or.ok()) {
|
|
||||||
// Allocating the event failed, so synchronize
|
|
||||||
// the host on the copy and then drop the device buffer hold.
|
|
||||||
StallStreamOnError(local_device, stream);
|
|
||||||
return event_or.status();
|
|
||||||
}
|
|
||||||
usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
|
usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
|
||||||
// When using the ComputeSynchronized allocation model, retain a reference to
|
// When using the ComputeSynchronized allocation model, retain a reference to
|
||||||
// the device_buffer until the copy completes, to ensure that the buffer isn't
|
// the device_buffer until the copy completes, to ensure that the buffer isn't
|
||||||
@ -1238,20 +1199,6 @@ PjRtStreamExecutorBuffer::CopyToHostAsyncInternal(
|
|||||||
RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
|
RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
|
||||||
stream,
|
stream,
|
||||||
/*prefer_to_retain_reference=*/true);
|
/*prefer_to_retain_reference=*/true);
|
||||||
return host_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<std::shared_ptr<Literal>> PjRtStreamExecutorBuffer::ToLiteral(
|
|
||||||
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
|
|
||||||
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
|
|
||||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
|
|
||||||
CopyToHostAsyncInternal(discard_cached_copy, layout));
|
|
||||||
if (host_value == nullptr) {
|
|
||||||
return InvalidArgument("ToLiteral called on deleted or donated buffer");
|
|
||||||
}
|
|
||||||
host_value->ready.WaitForNotification();
|
|
||||||
TF_RETURN_IF_ERROR(host_value->status);
|
|
||||||
return host_value->value;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
|
StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
|
||||||
|
@ -497,11 +497,8 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
|||||||
bool wait_for_operations_to_complete) override;
|
bool wait_for_operations_to_complete) override;
|
||||||
|
|
||||||
using PjRtBuffer::ToLiteral;
|
using PjRtBuffer::ToLiteral;
|
||||||
StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
void ToLiteral(MutableLiteralBase* literal,
|
||||||
bool discard_cached_copy, absl::optional<xla::Layout> layout) override;
|
std::function<void(Status)> on_ready) override;
|
||||||
|
|
||||||
using PjRtBuffer::CopyToHostAsync;
|
|
||||||
Status CopyToHostAsync(absl::optional<xla::Layout> layout) override;
|
|
||||||
|
|
||||||
// Drops the buffer's reference to its associated device memory, leaving the
|
// Drops the buffer's reference to its associated device memory, leaving the
|
||||||
// buffer in an invalid state. The memory will be freed lazily when all async
|
// buffer in an invalid state. The memory will be freed lazily when all async
|
||||||
@ -558,16 +555,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
friend class PjRtClient;
|
friend class PjRtClient;
|
||||||
// The cached value of the buffer on the host, produced either from a call to
|
|
||||||
// CopyToHost or from a call to ToLiteral. Once a value has been fetched to
|
|
||||||
// the host, it persists Delete() is called or the PjRtBuffer is destroyed.
|
|
||||||
struct HostValue {
|
|
||||||
absl::Notification ready;
|
|
||||||
// status and value are valid for reading only after `ready` has been
|
|
||||||
// notified.
|
|
||||||
Status status;
|
|
||||||
std::shared_ptr<Literal> value;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Blocks in mu_.Await until there are no more usage holds.
|
// Blocks in mu_.Await until there are no more usage holds.
|
||||||
void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
@ -598,14 +585,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
|||||||
// successfully donated to an execution.
|
// successfully donated to an execution.
|
||||||
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
|
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
|
||||||
|
|
||||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
|
||||||
// the transfer to complete. A host value is returned and if
|
|
||||||
// `discard_cached_copy` is false stored in an internal buffer so that future
|
|
||||||
// transfers don't have to transfer the data from host again. If a layout is
|
|
||||||
// passed then a literal of this layout will be returned and possibly cached.
|
|
||||||
StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
|
|
||||||
bool discard_cached_copy, absl::optional<xla::Layout> layout);
|
|
||||||
|
|
||||||
// Drops a hold without taking any other action. Does a sanity check that
|
// Drops a hold without taking any other action. Does a sanity check that
|
||||||
// buffer==device_buffer_ or device_buffer_==nullptr.
|
// buffer==device_buffer_ or device_buffer_==nullptr.
|
||||||
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
|
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
|
||||||
@ -624,9 +603,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
|||||||
|
|
||||||
mutable absl::Mutex mu_;
|
mutable absl::Mutex mu_;
|
||||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
|
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
|
||||||
absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
|
|
||||||
TF_GUARDED_BY(mu_);
|
|
||||||
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
|
|
||||||
// Count of holds on the buffer.
|
// Count of holds on the buffer.
|
||||||
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
|
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
|
||||||
// Semaphore used to ensure there is only one outstanding donation hold.
|
// Semaphore used to ensure there is only one outstanding donation hold.
|
||||||
|
@ -200,6 +200,7 @@ cc_library(
|
|||||||
":types",
|
":types",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||||
"//tensorflow/core/platform:fingerprint",
|
"//tensorflow/core/platform:fingerprint",
|
||||||
"//tensorflow/core/profiler:protos_all_cc",
|
"//tensorflow/core/profiler:protos_all_cc",
|
||||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||||
|
|
||||||
#include "absl/base/casts.h"
|
#include "absl/base/casts.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/pytypes.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||||
#include "tensorflow/compiler/xla/python/types.h"
|
#include "tensorflow/compiler/xla/python/types.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -85,6 +88,64 @@ Status PyBuffer::BlockHostUntilReady() {
|
|||||||
return buffer_->BlockHostUntilReady();
|
return buffer_->BlockHostUntilReady();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status PyBuffer::CopyToHostAsync() {
|
||||||
|
if (!buffer_->IsOnCpu() && !host_value_) {
|
||||||
|
host_value_ = std::make_shared<HostValue>();
|
||||||
|
host_value_->value = std::make_shared<Literal>(buffer_->on_host_shape());
|
||||||
|
buffer_->ToLiteral(host_value_->value.get(),
|
||||||
|
[host_value{host_value_}](Status status) {
|
||||||
|
host_value->status = std::move(status);
|
||||||
|
host_value->ready.Notify();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
|
||||||
|
if (buffer_->IsDeleted()) {
|
||||||
|
return InvalidArgument("DeviceArray has been deleted.");
|
||||||
|
}
|
||||||
|
TF_RET_CHECK(buffer_->on_device_shape().IsArray());
|
||||||
|
// On CPU, we can return the value in a zero-copy way.
|
||||||
|
if (buffer_->IsOnCpu()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
py::dtype dtype,
|
||||||
|
PrimitiveTypeToDtype(buffer_->on_host_shape().element_type()));
|
||||||
|
// Objects that must be kept alive while the array is alive.
|
||||||
|
struct Hold {
|
||||||
|
py::object buffer;
|
||||||
|
std::unique_ptr<PjRtBuffer::ExternalReferenceHold>
|
||||||
|
external_reference_hold;
|
||||||
|
};
|
||||||
|
auto hold = std::make_unique<Hold>();
|
||||||
|
TF_ASSIGN_OR_RETURN(hold->external_reference_hold,
|
||||||
|
buffer_->AcquireExternalReference());
|
||||||
|
hold->buffer = py::reinterpret_borrow<py::object>(this_obj);
|
||||||
|
void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
|
||||||
|
py::capsule hold_capsule(hold.release(),
|
||||||
|
[](void* h) { delete static_cast<Hold*>(h); });
|
||||||
|
py::array array(dtype, buffer_->on_host_shape().dimensions(),
|
||||||
|
ByteStridesForShape(buffer_->on_host_shape()), data,
|
||||||
|
hold_capsule);
|
||||||
|
array.attr("flags").attr("writeable") = Py_False;
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil;
|
||||||
|
TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady());
|
||||||
|
}
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(CopyToHostAsync());
|
||||||
|
if (!host_value_->ready.HasBeenNotified()) {
|
||||||
|
py::gil_scoped_release gil;
|
||||||
|
host_value_->ready.WaitForNotification();
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(host_value_->status);
|
||||||
|
TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value));
|
||||||
|
array.attr("flags").attr("writeable") = Py_False;
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(zhangqiaorjc): Delete UnsafeBufferPointer.
|
// TODO(zhangqiaorjc): Delete UnsafeBufferPointer.
|
||||||
StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
|
StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
|
||||||
if (buffer_->on_device_shape().IsTuple()) {
|
if (buffer_->on_device_shape().IsTuple()) {
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/synchronization/notification.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "pybind11/numpy.h"
|
#include "pybind11/numpy.h"
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
@ -70,12 +71,12 @@ class PyBuffer : public DeviceArrayBase {
|
|||||||
|
|
||||||
void Delete() {
|
void Delete() {
|
||||||
buffer_->Delete();
|
buffer_->Delete();
|
||||||
npy_value_ = pybind11::none();
|
host_value_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns xla::InvalidArgument if the buffer has been deleted.
|
// Returns xla::InvalidArgument if the buffer has been deleted.
|
||||||
Status BlockHostUntilReady();
|
Status BlockHostUntilReady();
|
||||||
Status CopyToHostAsync() { return buffer_->CopyToHostAsync(); }
|
Status CopyToHostAsync();
|
||||||
|
|
||||||
const Shape& shape() { return buffer_->on_host_shape(); }
|
const Shape& shape() { return buffer_->on_host_shape(); }
|
||||||
|
|
||||||
@ -102,8 +103,7 @@ class PyBuffer : public DeviceArrayBase {
|
|||||||
void SetStickyDevice(pybind11::object sticky_device);
|
void SetStickyDevice(pybind11::object sticky_device);
|
||||||
pybind11::object GetStickyDevice() const { return sticky_device_.value(); }
|
pybind11::object GetStickyDevice() const { return sticky_device_.value(); }
|
||||||
|
|
||||||
void SetNpyValue(pybind11::object npy_value) { npy_value_ = npy_value; }
|
StatusOr<pybind11::object> AsNumPyArray(pybind11::handle this_obj);
|
||||||
pybind11::object GetNpyValue() const { return npy_value_; }
|
|
||||||
|
|
||||||
void SetAval(pybind11::object aval);
|
void SetAval(pybind11::object aval);
|
||||||
pybind11::object GetAval() const { return aval_.value(); }
|
pybind11::object GetAval() const { return aval_.value(); }
|
||||||
@ -111,11 +111,16 @@ class PyBuffer : public DeviceArrayBase {
|
|||||||
private:
|
private:
|
||||||
friend class PyClient;
|
friend class PyClient;
|
||||||
|
|
||||||
|
struct HostValue {
|
||||||
|
absl::Notification ready;
|
||||||
|
Status status;
|
||||||
|
std::shared_ptr<xla::Literal> value;
|
||||||
|
};
|
||||||
std::shared_ptr<PyClient> client_;
|
std::shared_ptr<PyClient> client_;
|
||||||
std::unique_ptr<PjRtBuffer> buffer_;
|
std::unique_ptr<PjRtBuffer> buffer_;
|
||||||
std::shared_ptr<Traceback> traceback_;
|
std::shared_ptr<Traceback> traceback_;
|
||||||
// The host numpy array caching the value when it has been copied to the host.
|
std::shared_ptr<HostValue> host_value_; // Protected by the GIL.
|
||||||
pybind11::object npy_value_ = pybind11::none();
|
|
||||||
absl::optional<pybind11::object> sticky_device_ = absl::nullopt;
|
absl::optional<pybind11::object> sticky_device_ = absl::nullopt;
|
||||||
// TODO(jblespiau): It's currently there for convenience but maybe we can do
|
// TODO(jblespiau): It's currently there for convenience but maybe we can do
|
||||||
// without it (adding `weak_type` instead).
|
// without it (adding `weak_type` instead).
|
||||||
|
@ -68,25 +68,6 @@ bool IsOptimizedBuild() {
|
|||||||
#endif // NDEBUG
|
#endif // NDEBUG
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<py::object> BufferToPython(PyBuffer* buffer, py::handle& buffer_obj) {
|
|
||||||
GlobalPyRefManager()->CollectGarbage();
|
|
||||||
if (buffer->buffer()->IsOnCpu() &&
|
|
||||||
buffer->buffer()->on_device_shape().IsArray() &&
|
|
||||||
buffer->buffer()->on_device_shape().element_type() != BF16) {
|
|
||||||
py::object out =
|
|
||||||
py::reinterpret_steal<py::object>(PyArray_FROM_O(buffer_obj.ptr()));
|
|
||||||
CHECK(out.ptr() != nullptr) << buffer->buffer()->on_host_shape().ToString(
|
|
||||||
/*print_layout=*/true);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
std::shared_ptr<Literal> literal;
|
|
||||||
{
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
TF_ASSIGN_OR_RETURN(literal, buffer->buffer()->ToLiteral());
|
|
||||||
}
|
|
||||||
return LiteralToPython(std::move(literal));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
PYBIND11_MODULE(xla_extension, m) {
|
PYBIND11_MODULE(xla_extension, m) {
|
||||||
@ -328,20 +309,10 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
.def_property_readonly("ndim", &PyBuffer::ndim)
|
.def_property_readonly("ndim", &PyBuffer::ndim)
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"_value",
|
"_value",
|
||||||
[](py::handle buffer_obj) -> pybind11::object {
|
[](py::handle buffer_obj) -> StatusOr<pybind11::object> {
|
||||||
|
GlobalPyRefManager()->CollectGarbage();
|
||||||
PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
|
PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
|
||||||
if (buffer->is_deleted()) {
|
return buffer->AsNumPyArray(buffer_obj);
|
||||||
throw std::runtime_error("DeviceArray has been deleted.");
|
|
||||||
}
|
|
||||||
py::object npy_value_ = buffer->GetNpyValue();
|
|
||||||
if (npy_value_.is_none()) {
|
|
||||||
npy_value_ = BufferToPython(buffer, buffer_obj).ValueOrDie();
|
|
||||||
// TODO(jblspiau): Change `LiteralToPython` to return a
|
|
||||||
// `py::array`, so we can set more easily the attribute.
|
|
||||||
npy_value_.attr("flags").attr("writeable") = Py_False;
|
|
||||||
buffer->SetNpyValue(npy_value_);
|
|
||||||
}
|
|
||||||
return npy_value_;
|
|
||||||
})
|
})
|
||||||
.def("copy_to_device", &PyBuffer::CopyToDevice)
|
.def("copy_to_device", &PyBuffer::CopyToDevice)
|
||||||
.def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes)
|
.def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes)
|
||||||
@ -359,7 +330,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
.def("to_py",
|
.def("to_py",
|
||||||
[](py::handle buffer_obj) {
|
[](py::handle buffer_obj) {
|
||||||
PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
|
PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
|
||||||
return BufferToPython(buffer, buffer_obj);
|
return buffer->AsNumPyArray(buffer_obj);
|
||||||
})
|
})
|
||||||
.def("xla_shape", &PyBuffer::shape)
|
.def("xla_shape", &PyBuffer::shape)
|
||||||
.def_property_readonly("client", &PyBuffer::client)
|
.def_property_readonly("client", &PyBuffer::client)
|
||||||
|
Loading…
Reference in New Issue
Block a user