Allow a shape to be passed to CopyToHostAsync

PiperOrigin-RevId: 317611333
Change-Id: I4526f9dbd1b223eb23fe928326afca0eb133c2f5
This commit is contained in:
Tamara Norman 2020-06-22 01:49:32 -07:00 committed by TensorFlower Gardener
parent 149a0a1d5a
commit 0868ca7bb2
3 changed files with 67 additions and 26 deletions

View File

@ -149,6 +149,7 @@ cc_library(
"//tensorflow/stream_executor/host:host_platform_id",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
@ -156,6 +157,7 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -76,11 +76,13 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
@ -861,10 +863,10 @@ StatusOr<std::shared_ptr<TrackedDeviceBuffer>> PjRtBuffer::Release(
if (device_buffer_ == nullptr) {
return std::shared_ptr<TrackedDeviceBuffer>();
}
// Set host_value_ and device_buffer_ to null now so that no other thread
// can add a hold while we are in WaitForOutstandingUsageHolds()
// Clear host_values_ and set device_buffer_ to null now so that no other
// thread can add a hold while we are in WaitForOutstandingUsageHolds()
// below.
host_value_ = nullptr;
host_values_.clear();
std::swap(device_buffer_, device_buffer);
WaitForOutstandingUsageHolds();
// Now that all holds have completed and no more can be added, we can get
@ -999,7 +1001,7 @@ void PjRtBuffer::ConfirmDonation(TrackedDeviceBuffer* device_buffer) {
device_buffer->ReleaseDeviceMemory();
// Make *this invalid so it can't be used again. Any threads blocking in
// Release or GetBufferWithHold will see an invalid buffer and return.
host_value_ = nullptr;
host_values_.clear();
device_buffer_.reset();
}
// Unblock another thread, if any, trying to get a donation hold.
@ -1019,7 +1021,14 @@ void PjRtBuffer::DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer) {
}
}
Status PjRtBuffer::CopyToHostAsync() {
Status PjRtBuffer::CopyToHostAsync(absl::optional<xla::Layout> layout) {
return CopyToHostAsyncInternal(/*discard_cached_copy=*/false, layout)
.status();
}
StatusOr<std::shared_ptr<PjRtBuffer::HostValue>>
PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
absl::optional<xla::Layout> layout) {
if (IsEmptyTuple()) {
return InvalidArgument("CopyToHostAsync called on empty tuple");
}
@ -1027,6 +1036,8 @@ Status PjRtBuffer::CopyToHostAsync() {
std::shared_ptr<HostValue> host_value;
LocalDeviceState* local_device = device_->local_device_state();
se::Stream* stream = local_device->GetDeviceToHostStream();
const xla::Layout& host_layout =
layout.has_value() ? layout.value() : on_host_shape_.layout();
{
absl::MutexLock lock(&mu_);
// We can't perform any other action while a donation hold is in progress.
@ -1034,17 +1045,36 @@ Status PjRtBuffer::CopyToHostAsync() {
if (device_buffer_ == nullptr) {
return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
}
if (host_value_) {
// The host value has already been requested or is available.
return Status::OK();
if (discard_cached_copy) {
auto it = host_values_.find(host_layout);
if (it != host_values_.end()) {
host_value = it->second;
host_values_.erase(it);
return host_value;
} else {
host_value = std::make_shared<HostValue>();
}
} else {
std::shared_ptr<HostValue>& host_value_ref = host_values_[host_layout];
if (host_value_ref) {
return host_value_ref;
}
host_value = host_value_ref = std::make_shared<HostValue>();
}
host_value = host_value_ = std::make_shared<HostValue>();
AcquireHoldLocked(&device_buffer);
}
WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
host_value->value = std::make_shared<Literal>(on_host_shape_);
Shape host_shape;
if (layout.has_value()) {
host_shape = ShapeUtil::MakeShape(on_host_shape_.element_type(),
on_host_shape_.dimensions());
*host_shape.mutable_layout() = host_layout;
} else {
host_shape = on_host_shape_;
}
host_value->value = std::make_shared<Literal>(host_shape);
ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(
on_host_shape_, on_device_shape_, client_->client()->platform());
host_shape, on_device_shape_, client_->client()->platform());
client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
stream, shaped_buffer, host_value->value.get(),
[host_value](Status done_status) {
@ -1074,21 +1104,14 @@ Status PjRtBuffer::CopyToHostAsync() {
RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
stream,
/*prefer_to_retain_reference=*/true);
return Status::OK();
return host_value;
}
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
const bool discard_cached_copy) {
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
TF_RETURN_IF_ERROR(CopyToHostAsync());
std::shared_ptr<HostValue> host_value;
{
absl::MutexLock lock(&mu_);
host_value = host_value_;
if (discard_cached_copy) {
host_value_ = nullptr;
}
}
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
CopyToHostAsyncInternal(discard_cached_copy, layout));
if (host_value == nullptr) {
return InvalidArgument("ToLiteral called on invalid buffer");
}

View File

@ -20,15 +20,18 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
@ -481,14 +484,17 @@ class PjRtBuffer {
// copies the buffer to the host. Blocks until the value is ready. If
// `discard_cached_copy` is true then buffer will no longer keep hold of a
// cached copy of the literal (i.e. The reference to the host value will be
// removed.)
// removed.) If a layout is passed than a literal with this layout will be
// returned.
StatusOr<std::shared_ptr<Literal>> ToLiteral(
bool discard_cached_copy = false);
bool discard_cached_copy = false,
absl::optional<xla::Layout> layout = {});
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. The value can be retrieved by a later call to
// ToLiteral().
Status CopyToHostAsync();
// ToLiteral(). If a layout is passed then a cached copy with this layout will
// be created.
Status CopyToHostAsync(absl::optional<xla::Layout> layout = {});
// Drops the buffer's reference to its associated device memory, leaving the
// buffer in an invalid state. The memory will be freed lazily when all async
@ -596,6 +602,14 @@ class PjRtBuffer {
// successfully donated to an execution.
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. A host value is returned and if
// `discard_cached_copy` is false stored in an internal buffer so that future
// transfers don't have to transfer the data from host again. If a layout is
// passed then a literal of this layout will be returned and possibly cached.
StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
bool discard_cached_copy, absl::optional<xla::Layout> layout);
// Drops a hold without taking any other action. Does a sanity check that
// buffer==device_buffer_ or device_buffer_==nullptr.
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
@ -614,6 +628,8 @@ class PjRtBuffer {
mutable absl::Mutex mu_;
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
TF_GUARDED_BY(mu_);
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
// Count of holds on the buffer.
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);