diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index e401a798d68..695ba9dee93 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -149,6 +149,7 @@ cc_library( "//tensorflow/stream_executor/host:host_platform_id", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", @@ -156,6 +157,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index b4f0363e69a..e341a11d64f 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -76,11 +76,13 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" @@ -861,10 +863,10 @@ StatusOr> PjRtBuffer::Release( if (device_buffer_ == nullptr) { return std::shared_ptr(); } - // Set host_value_ and device_buffer_ to null now so that no other thread - // can add a hold while we are in WaitForOutstandingUsageHolds() + // Clear host_values_ and set device_buffer_ to null now so that no other + // thread can add a hold while we are in WaitForOutstandingUsageHolds() // below. - host_value_ = nullptr; + host_values_.clear(); std::swap(device_buffer_, device_buffer); WaitForOutstandingUsageHolds(); // Now that all holds have completed and no more can be added, we can get @@ -999,7 +1001,7 @@ void PjRtBuffer::ConfirmDonation(TrackedDeviceBuffer* device_buffer) { device_buffer->ReleaseDeviceMemory(); // Make *this invalid so it can't be used again. Any threads blocking in // Release or GetBufferWithHold will see an invalid buffer and return. - host_value_ = nullptr; + host_values_.clear(); device_buffer_.reset(); } // Unblock another thread, if any, trying to get a donation hold. 
@@ -1019,7 +1021,14 @@ void PjRtBuffer::DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer) { } } -Status PjRtBuffer::CopyToHostAsync() { +Status PjRtBuffer::CopyToHostAsync(absl::optional layout) { + return CopyToHostAsyncInternal(/*discard_cached_copy=*/false, layout) + .status(); +} + +StatusOr> +PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy, + absl::optional layout) { if (IsEmptyTuple()) { return InvalidArgument("CopyToHostAsync called on empty tuple"); } @@ -1027,6 +1036,8 @@ Status PjRtBuffer::CopyToHostAsync() { std::shared_ptr host_value; LocalDeviceState* local_device = device_->local_device_state(); se::Stream* stream = local_device->GetDeviceToHostStream(); + const xla::Layout& host_layout = + layout.has_value() ? layout.value() : on_host_shape_.layout(); { absl::MutexLock lock(&mu_); // We can't perform any other action while a donation hold is in progress. @@ -1034,17 +1045,36 @@ Status PjRtBuffer::CopyToHostAsync() { if (device_buffer_ == nullptr) { return InvalidArgument("CopyToHostAsync() called on invalid buffer."); } - if (host_value_) { - // The host value has already been requested or is available. 
- return Status::OK(); + if (discard_cached_copy) { + auto it = host_values_.find(host_layout); + if (it != host_values_.end()) { + host_value = it->second; + host_values_.erase(it); + return host_value; + } else { + host_value = std::make_shared(); + } + } else { + std::shared_ptr& host_value_ref = host_values_[host_layout]; + if (host_value_ref) { + return host_value_ref; + } + host_value = host_value_ref = std::make_shared(); } - host_value = host_value_ = std::make_shared(); AcquireHoldLocked(&device_buffer); } WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); - host_value->value = std::make_shared(on_host_shape_); + Shape host_shape; + if (layout.has_value()) { + host_shape = ShapeUtil::MakeShape(on_host_shape_.element_type(), + on_host_shape_.dimensions()); + *host_shape.mutable_layout() = host_layout; + } else { + host_shape = on_host_shape_; + } + host_value->value = std::make_shared(host_shape); ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer( - on_host_shape_, on_device_shape_, client_->client()->platform()); + host_shape, on_device_shape_, client_->client()->platform()); client_->client()->backend().transfer_manager()->TransferLiteralFromDevice( stream, shaped_buffer, host_value->value.get(), [host_value](Status done_status) { @@ -1074,21 +1104,14 @@ Status PjRtBuffer::CopyToHostAsync() { RecordUsage(std::move(device_buffer), local_device, local_device, usage_event, stream, /*prefer_to_retain_reference=*/true); - return Status::OK(); + return host_value; } StatusOr> PjRtBuffer::ToLiteral( - const bool discard_cached_copy) { + const bool discard_cached_copy, absl::optional layout) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral"); - TF_RETURN_IF_ERROR(CopyToHostAsync()); - std::shared_ptr host_value; - { - absl::MutexLock lock(&mu_); - host_value = host_value_; - if (discard_cached_copy) { - host_value_ = nullptr; - } - } + TF_ASSIGN_OR_RETURN(std::shared_ptr host_value, + CopyToHostAsyncInternal(discard_cached_copy, 
layout)); if (host_value == nullptr) { return InvalidArgument("ToLiteral called on invalid buffer"); } diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 8f74e6244d6..c50d09f631c 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -20,15 +20,18 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" @@ -481,14 +484,17 @@ class PjRtBuffer { // copies the buffer to the host. Blocks until the value is ready. If // `discard_cached_copy` is true then buffer will no longer keep hold of a // cached copy of the literal (i.e. The reference to the host value will be - // removed.) + // removed.) If a layout is passed then a literal with this layout will be + // returned. StatusOr> ToLiteral( - bool discard_cached_copy = false); + bool discard_cached_copy = false, + absl::optional layout = {}); // Initiates a copy of the buffer to the host. Does not block waiting for // the transfer to complete. The value can be retrieved by a later call to - // ToLiteral(). - Status CopyToHostAsync(); + // ToLiteral(). If a layout is passed then a cached copy with this layout will + // be created. 
+ Status CopyToHostAsync(absl::optional layout = {}); // Drops the buffer's reference to its associated device memory, leaving the // buffer in an invalid state. The memory will be freed lazily when all async @@ -596,6 +602,14 @@ class PjRtBuffer { // successfully donated to an execution. void ConfirmDonation(TrackedDeviceBuffer* device_buffer); + // Initiates a copy of the buffer to the host. Does not block waiting for + // the transfer to complete. A host value is returned and, if + // `discard_cached_copy` is false, stored in an internal buffer so that future + // transfers don't have to transfer the data from the device again. If a layout is + // passed then a literal of this layout will be returned and possibly cached. + StatusOr> CopyToHostAsyncInternal( + bool discard_cached_copy, absl::optional layout); + // Drops a hold without taking any other action. Does a sanity check that // buffer==device_buffer_ or device_buffer_==nullptr. void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer); @@ -614,6 +628,8 @@ class PjRtBuffer { mutable absl::Mutex mu_; std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); + absl::flat_hash_map> host_values_ + TF_GUARDED_BY(mu_); std::shared_ptr host_value_ TF_GUARDED_BY(mu_); // Count of holds on the buffer. std::array holds_ TF_GUARDED_BY(mu_);