Allow a shape to be passed to CopyToHostAsync
PiperOrigin-RevId: 317611333 Change-Id: I4526f9dbd1b223eb23fe928326afca0eb133c2f5
This commit is contained in:
parent
149a0a1d5a
commit
0868ca7bb2
@ -149,6 +149,7 @@ cc_library(
|
||||
"//tensorflow/stream_executor/host:host_platform_id",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -156,6 +157,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -76,11 +76,13 @@ limitations under the License.
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/layout.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
|
||||
@ -861,10 +863,10 @@ StatusOr<std::shared_ptr<TrackedDeviceBuffer>> PjRtBuffer::Release(
|
||||
if (device_buffer_ == nullptr) {
|
||||
return std::shared_ptr<TrackedDeviceBuffer>();
|
||||
}
|
||||
// Set host_value_ and device_buffer_ to null now so that no other thread
|
||||
// can add a hold while we are in WaitForOutstandingUsageHolds()
|
||||
// Clear host_values_ and set device_buffer_ to null now so that no other
|
||||
// thread can add a hold while we are in WaitForOutstandingUsageHolds()
|
||||
// below.
|
||||
host_value_ = nullptr;
|
||||
host_values_.clear();
|
||||
std::swap(device_buffer_, device_buffer);
|
||||
WaitForOutstandingUsageHolds();
|
||||
// Now that all holds have completed and no more can be added, we can get
|
||||
@ -999,7 +1001,7 @@ void PjRtBuffer::ConfirmDonation(TrackedDeviceBuffer* device_buffer) {
|
||||
device_buffer->ReleaseDeviceMemory();
|
||||
// Make *this invalid so it can't be used again. Any threads blocking in
|
||||
// Release or GetBufferWithHold will see an invalid buffer and return.
|
||||
host_value_ = nullptr;
|
||||
host_values_.clear();
|
||||
device_buffer_.reset();
|
||||
}
|
||||
// Unblock another thread, if any, trying to get a donation hold.
|
||||
@ -1019,7 +1021,14 @@ void PjRtBuffer::DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer) {
|
||||
}
|
||||
}
|
||||
|
||||
Status PjRtBuffer::CopyToHostAsync() {
|
||||
Status PjRtBuffer::CopyToHostAsync(absl::optional<xla::Layout> layout) {
|
||||
return CopyToHostAsyncInternal(/*discard_cached_copy=*/false, layout)
|
||||
.status();
|
||||
}
|
||||
|
||||
StatusOr<std::shared_ptr<PjRtBuffer::HostValue>>
|
||||
PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
|
||||
absl::optional<xla::Layout> layout) {
|
||||
if (IsEmptyTuple()) {
|
||||
return InvalidArgument("CopyToHostAsync called on empty tuple");
|
||||
}
|
||||
@ -1027,6 +1036,8 @@ Status PjRtBuffer::CopyToHostAsync() {
|
||||
std::shared_ptr<HostValue> host_value;
|
||||
LocalDeviceState* local_device = device_->local_device_state();
|
||||
se::Stream* stream = local_device->GetDeviceToHostStream();
|
||||
const xla::Layout& host_layout =
|
||||
layout.has_value() ? layout.value() : on_host_shape_.layout();
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
// We can't perform any other action while a donation hold is in progress.
|
||||
@ -1034,17 +1045,36 @@ Status PjRtBuffer::CopyToHostAsync() {
|
||||
if (device_buffer_ == nullptr) {
|
||||
return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
|
||||
}
|
||||
if (host_value_) {
|
||||
// The host value has already been requested or is available.
|
||||
return Status::OK();
|
||||
if (discard_cached_copy) {
|
||||
auto it = host_values_.find(host_layout);
|
||||
if (it != host_values_.end()) {
|
||||
host_value = it->second;
|
||||
host_values_.erase(it);
|
||||
return host_value;
|
||||
} else {
|
||||
host_value = std::make_shared<HostValue>();
|
||||
}
|
||||
} else {
|
||||
std::shared_ptr<HostValue>& host_value_ref = host_values_[host_layout];
|
||||
if (host_value_ref) {
|
||||
return host_value_ref;
|
||||
}
|
||||
host_value = host_value_ref = std::make_shared<HostValue>();
|
||||
}
|
||||
host_value = host_value_ = std::make_shared<HostValue>();
|
||||
AcquireHoldLocked(&device_buffer);
|
||||
}
|
||||
WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
|
||||
host_value->value = std::make_shared<Literal>(on_host_shape_);
|
||||
Shape host_shape;
|
||||
if (layout.has_value()) {
|
||||
host_shape = ShapeUtil::MakeShape(on_host_shape_.element_type(),
|
||||
on_host_shape_.dimensions());
|
||||
*host_shape.mutable_layout() = host_layout;
|
||||
} else {
|
||||
host_shape = on_host_shape_;
|
||||
}
|
||||
host_value->value = std::make_shared<Literal>(host_shape);
|
||||
ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(
|
||||
on_host_shape_, on_device_shape_, client_->client()->platform());
|
||||
host_shape, on_device_shape_, client_->client()->platform());
|
||||
client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
|
||||
stream, shaped_buffer, host_value->value.get(),
|
||||
[host_value](Status done_status) {
|
||||
@ -1074,21 +1104,14 @@ Status PjRtBuffer::CopyToHostAsync() {
|
||||
RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
|
||||
stream,
|
||||
/*prefer_to_retain_reference=*/true);
|
||||
return Status::OK();
|
||||
return host_value;
|
||||
}
|
||||
|
||||
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
|
||||
const bool discard_cached_copy) {
|
||||
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
|
||||
TF_RETURN_IF_ERROR(CopyToHostAsync());
|
||||
std::shared_ptr<HostValue> host_value;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
host_value = host_value_;
|
||||
if (discard_cached_copy) {
|
||||
host_value_ = nullptr;
|
||||
}
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
|
||||
CopyToHostAsyncInternal(discard_cached_copy, layout));
|
||||
if (host_value == nullptr) {
|
||||
return InvalidArgument("ToLiteral called on invalid buffer");
|
||||
}
|
||||
|
@ -20,15 +20,18 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/layout.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
@ -481,14 +484,17 @@ class PjRtBuffer {
|
||||
// copies the buffer to the host. Blocks until the value is ready. If
|
||||
// `discard_cached_copy` is true then buffer will no longer keep hold of a
|
||||
// cached copy of the literal (i.e. The reference to the host value will be
|
||||
// removed.)
|
||||
// removed.) If a layout is passed than a literal with this layout will be
|
||||
// returned.
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
||||
bool discard_cached_copy = false);
|
||||
bool discard_cached_copy = false,
|
||||
absl::optional<xla::Layout> layout = {});
|
||||
|
||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||
// the transfer to complete. The value can be retrieved by a later call to
|
||||
// ToLiteral().
|
||||
Status CopyToHostAsync();
|
||||
// ToLiteral(). If a layout is passed then a cached copy with this layout will
|
||||
// be created.
|
||||
Status CopyToHostAsync(absl::optional<xla::Layout> layout = {});
|
||||
|
||||
// Drops the buffer's reference to its associated device memory, leaving the
|
||||
// buffer in an invalid state. The memory will be freed lazily when all async
|
||||
@ -596,6 +602,14 @@ class PjRtBuffer {
|
||||
// successfully donated to an execution.
|
||||
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
|
||||
|
||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||
// the transfer to complete. A host value is returned and if
|
||||
// `discard_cached_copy` is false stored in an internal buffer so that future
|
||||
// transfers don't have to transfer the data from host again. If a layout is
|
||||
// passed then a literal of this layout will be returned and possibly cached.
|
||||
StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
|
||||
bool discard_cached_copy, absl::optional<xla::Layout> layout);
|
||||
|
||||
// Drops a hold without taking any other action. Does a sanity check that
|
||||
// buffer==device_buffer_ or device_buffer_==nullptr.
|
||||
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
|
||||
@ -614,6 +628,8 @@ class PjRtBuffer {
|
||||
|
||||
mutable absl::Mutex mu_;
|
||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
|
||||
absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
|
||||
TF_GUARDED_BY(mu_);
|
||||
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
|
||||
// Count of holds on the buffer.
|
||||
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
|
||||
|
Loading…
Reference in New Issue
Block a user