From eccc1d4c102b1b3b03b98dbc362f799cb540a1da Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Thu, 12 Jul 2018 03:33:08 -0700
Subject: [PATCH] [XLA:GPU] Unify infeed and outfeed queue implementations

This makes the infeed queue behave like the outfeed queue and merges the
two implementations. It shouldn't change functionality. There was also
quite a bit of unused code in infeed_manager that's gone now.

PiperOrigin-RevId: 204273197
---
 tensorflow/compiler/xla/service/gpu/BUILD     | 11 ++-
 .../xla/service/gpu/gpu_transfer_manager.cc   | 46 ++++------
 .../xla/service/gpu/gpu_transfer_manager.h    |  7 +-
 .../xla/service/gpu/infeed_manager.cc         | 69 +-------------
 .../compiler/xla/service/gpu/infeed_manager.h | 82 ++++-------------
 .../compiler/xla/service/gpu/infeed_thunk.cc  | 16 ++--
 .../xla/service/gpu/outfeed_manager.cc        | 19 ----
 .../xla/service/gpu/outfeed_manager.h         | 27 +-----
 .../compiler/xla/service/gpu/outfeed_thunk.cc |  2 +-
 .../compiler/xla/service/gpu/xfeed_queue.h    | 89 +++++++++++++++++++
 10 files changed, 146 insertions(+), 222 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/gpu/xfeed_queue.h

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 9fca3a51c8f..59172e53d3f 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -542,6 +542,7 @@ cc_library(
         ":outfeed_manager",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_tree",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
@@ -639,14 +640,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "xfeed_queue",
+    hdrs = ["xfeed_queue.h"],
+    deps = ["//tensorflow/core:lib"],
+)
+
 cc_library(
     name = "infeed_manager",
     srcs = ["infeed_manager.cc"],
     hdrs = ["infeed_manager.h"],
     deps = [
+        ":xfeed_queue",
+        "//tensorflow/compiler/xla:shape_tree",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
-        "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
     ],
 )
@@ -656,6 +664,7 @@ cc_library(
     srcs = ["outfeed_manager.cc"],
     hdrs = ["outfeed_manager.h"],
     deps = [
+        ":xfeed_queue",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_tree",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 3c8018a0309..63466539fae 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 
 namespace xla {
+namespace gpu {
 
 // TODO(b/30467474) Once GPU infeed implementation settles, consider
 // folding back the cpu and gpu infeed implementations into a generic
@@ -52,48 +53,37 @@ Status GpuTransferManager::TransferLiteralToInfeed(
   VLOG(2) << "Transferring literal to infeed with shape: "
           << ShapeUtil::HumanString(shape);
 
-  if (!ShapeUtil::IsTuple(shape)) {
-    int64 size = GetByteSizeRequirement(shape);
-    return TransferBufferToInfeed(executor, size, literal.untyped_data());
-  }
-
   // For a tuple, we transfer each of its elements to the device and
   // enqueue the resulting destination device addresses with the
   // infeed manager.
-  std::vector<gpu::InfeedBuffer*> buffers;
-  auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
-    for (gpu::InfeedBuffer* b : buffers) {
-      b->Done();
-    }
-  });
+  ShapeTree<InfeedBuffer> buffer_tree(shape);
 
   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
       shape, [&](const Shape& literal_subshape, const ShapeIndex& index) {
         if (ShapeUtil::IsArray(literal_subshape)) {
           int64 tuple_element_size = GetByteSizeRequirement(literal_subshape);
           TF_ASSIGN_OR_RETURN(
-              gpu::InfeedBuffer * buffer,
+              *buffer_tree.mutable_element(index),
               TransferBufferToInfeedInternal(executor, tuple_element_size,
                                              literal.untyped_data(index)));
-          buffers.push_back(buffer);
         }
         return Status::OK();
       }));
 
-  cleanup.release();
-  return EnqueueBuffersToInfeed(executor, buffers);
+  return EnqueueBuffersToInfeed(executor, std::move(buffer_tree));
 }
 
 Status GpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor,
                                                   int64 size,
                                                   const void* source) {
-  TF_ASSIGN_OR_RETURN(gpu::InfeedBuffer * buffer,
-                      TransferBufferToInfeedInternal(executor, size, source));
-  return EnqueueBuffersToInfeed(executor, {buffer});
+  return InternalError(
+      "Attempted to transfer data to infeed on a GPU device using "
+      "TransferBufferToInfeed. This should be done using "
+      "TransferLiteralToInfeed instead.");
 }
 
 Status GpuTransferManager::EnqueueBuffersToInfeed(
-    se::StreamExecutor* executor, std::vector<gpu::InfeedBuffer*> buffers) {
+    se::StreamExecutor* executor, ShapeTree<InfeedBuffer> buffers) {
   gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager();
   se::Stream* stream = infeed_manager->GetStream(executor);
 
@@ -103,21 +93,18 @@ Status GpuTransferManager::EnqueueBuffersToInfeed(
   // possible.
   Status block_status = stream->BlockHostUntilDone();
   if (!block_status.ok()) {
-    for (gpu::InfeedBuffer* b : buffers) {
-      b->Done();
-    }
     return InternalError("Failed to complete data transfer on stream %p: %s",
                          stream, block_status.error_message().c_str());
   }
 
-  infeed_manager->EnqueueBuffers(buffers);
+  infeed_manager->EnqueueDestination(std::move(buffers));
 
   VLOG(2) << "Infeed data transferred";
 
   return Status::OK();
 }
 
-StatusOr<gpu::InfeedBuffer*> GpuTransferManager::TransferBufferToInfeedInternal(
+StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
     se::StreamExecutor* executor, int64 size, const void* source) {
   if (size > std::numeric_limits<int32>::max()) {
     return InvalidArgument("Infeed shape is too large: needs %lld bytes", size);
@@ -133,12 +120,12 @@ StatusOr<gpu::InfeedBuffer*> GpuTransferManager::TransferBufferToInfeedInternal(
     return InternalError("Failed to obtain a stream");
   }
 
-  gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size);
-  stream->ThenMemcpy(buffer->device_memory(), source, size);
+  InfeedBuffer buffer(executor, size);
+  stream->ThenMemcpy(buffer.device_memory(), source, size);
 
   VLOG(2) << "Queued infeed data on stream " << stream;
 
-  return buffer;
+  return std::move(buffer);
 }
 
 static std::unique_ptr<Literal> ShapeTreeToLiteral(
@@ -191,17 +178,18 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
   // Give the tree of buffers to the outfeed mananger. The device will fill it
   // while we're waiting for it below.
   gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager();
-  outfeed_manager->EnqueueOutfeedDestination(&outfeed_buffers);
+  outfeed_manager->EnqueueDestination(&outfeed_buffers);
 
   // Now turn the tree of buffers back into a literal.
   *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers));
   return Status::OK();
 }
 
+}  // namespace gpu
 }  // namespace xla
 
 static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() {
-  return xla::MakeUnique<xla::GpuTransferManager>();
+  return xla::MakeUnique<xla::gpu::GpuTransferManager>();
 }
 
 static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index 9dff1e5a507..7a5fe6979f3 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
 #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/macros.h"
@@ -28,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
+namespace gpu {
 
 // An implementation of the XLA GenericTransferManager that
 // handles GPU-specific infeed.
@@ -47,17 +49,18 @@ class GpuTransferManager : public GenericTransferManager {
  private:
   // Initiates the infeed data transfers. InfeedBuffer->Done() must be
   // called to clean up the memory allocated for InfeedBuffer.
-  StatusOr<gpu::InfeedBuffer*> TransferBufferToInfeedInternal(
+  StatusOr<InfeedBuffer> TransferBufferToInfeedInternal(
       se::StreamExecutor* executor, int64 size, const void* source);
 
   // Enqueues infeed data buffers with the infeed manager after their
   // transfer completes.
   Status EnqueueBuffersToInfeed(se::StreamExecutor* executor,
-                                std::vector<gpu::InfeedBuffer*> buffers);
+                                ShapeTree<InfeedBuffer> buffers);
 
   TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager);
 };
 
+}  // namespace gpu
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
index ae310beefad..c5f0cdf6cd5 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
@@ -15,76 +15,13 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
 
-#include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace xla {
 namespace gpu {
 
-InfeedManager::InfeedManager() : host_to_device_executor_(nullptr) {}
-
-void InfeedManager::Reset() {
-  tensorflow::mutex_lock l(mu_);
-  CHECK(dequeued_buffer_.empty());
-  for (auto buffer : enqueued_buffer_) {
-    buffer->Done();
-  }
-  enqueued_buffer_.clear();
-}
-
-void InfeedManager::EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers) {
-  tensorflow::mutex_lock l(mu_);
-  bool was_empty = enqueued_buffer_.empty();
-  for (gpu::InfeedBuffer* b : buffers) {
-    enqueued_buffer_.push_back(b);
-  }
-  if (was_empty) {
-    // This has the potential to suffer from the notified thread
-    // immediately trying and failing to acquire mu_, but seems
-    // preferable to the alternative of notifying outside the lock
-    // on every enqueue.
-    cv_.notify_one();
-  }
-}
-
-InfeedBuffer* InfeedManager::BlockingDequeueBuffer() {
-  bool became_empty = false;
-  InfeedBuffer* current_buffer;
-  {
-    tensorflow::mutex_lock l(mu_);
-    while (enqueued_buffer_.empty()) {
-      cv_.wait(l);
-    }
-    current_buffer = enqueued_buffer_.front();
-    enqueued_buffer_.pop_front();
-    dequeued_buffer_.insert(current_buffer);
-    if (enqueued_buffer_.empty()) {
-      became_empty = true;
-    }
-  }
-  if (became_empty) {
-    for (const auto& callback : on_empty_callbacks_) {
-      callback();
-    }
-  }
-  return current_buffer;
-}
-
-void InfeedManager::ReleaseBuffers(const std::vector<InfeedBuffer*>& buffers) {
-  {
-    tensorflow::mutex_lock l(mu_);
-    for (gpu::InfeedBuffer* b : buffers) {
-      CHECK(ContainsKey(dequeued_buffer_, b));
-      dequeued_buffer_.erase(b);
-    }
-  }
-  for (gpu::InfeedBuffer* b : buffers) {
-    b->Done();
-  }
-}
-
 se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
+  tensorflow::mutex_lock l(host_to_device_stream_mu_);
   if (host_to_device_executor_ == nullptr) {
     host_to_device_executor_ = executor;
     host_to_device_stream_ = MakeUnique<se::Stream>(executor);
@@ -100,10 +37,6 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
   return host_to_device_stream_.get();
 }
 
-void InfeedManager::RegisterOnEmptyCallback(std::function<void()> callback) {
-  on_empty_callbacks_.push_back(std::move(callback));
-}
-
 InfeedManager* GetOrCreateInfeedManager() {
   static InfeedManager* manager = new InfeedManager;
   return manager;
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h
index a3fc15cfe36..7e418882e05 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h
@@ -20,12 +20,9 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_
 
-#include <deque>
-#include <vector>
-
+#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 
 namespace xla {
@@ -47,90 +44,41 @@ namespace gpu {
 // the client. The client manages the memory of the buffer.
 class InfeedBuffer {
  public:
+  InfeedBuffer() = default;
   InfeedBuffer(se::StreamExecutor* executor, int64 length)
-      : executor_(executor), length_(length) {
-    device_memory_ = executor_->AllocateArray<uint8>(length);
-    CHECK(!device_memory_.is_null());
+      : device_memory_(executor, executor->AllocateArray<uint8>(length)),
+        length_(length) {
+    CHECK(!device_memory_->is_null());
   }
 
-  ~InfeedBuffer() { executor_->Deallocate(&device_memory_); }
-
   int64 length() const { return length_; }
 
-  // Callback to signal that this buffer is consumed. This helps the
-  // client to manage memory for the infeed buffers.
-  void Done() { delete this; }
-
-  se::DeviceMemoryBase* device_memory() { return &device_memory_; }
+  se::DeviceMemoryBase* device_memory() { return device_memory_.ptr(); }
 
  private:
-  se::StreamExecutor* executor_;  // Not owned.
-  const int64 length_;
-  se::DeviceMemoryBase device_memory_;
+  se::ScopedDeviceMemory<uint8> device_memory_;
+  int64 length_;
 };
 
 // Client-side class used to enqueue infeed buffers.
-class InfeedManager {
+class InfeedManager : public XfeedQueue<ShapeTree<InfeedBuffer>> {
  public:
-  InfeedManager();
-
-  // Calls the completion callback for any enqueued buffers that have
-  // not been dequeued by the runtime, and empties the infeed
-  // queue. Reset may not be called while a runtime computation is
-  // processing a dequeued buffer. The only safe way to ensure this
-  // condition is to call Reset when no computation is taking place.
-  void Reset();
-
-  // Adds a set of buffers to the infeed queue atomically. buffer->Done
-  // will be called when the buffer will no longer be accessed by the
-  // InfeedManager, either as a result of a call to Reset or because the
-  // runtime has dequeued and used the buffer.
-  void EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers);
-
-  // Blocks until the infeed queue is non-empty, then returns the
-  // buffer at the head of the queue. Adds the current buffer to the
-  // to-be released set.
-  InfeedBuffer* BlockingDequeueBuffer();
-
-  // Releases a set of buffers from the to-be released set.
-  void ReleaseBuffers(const std::vector<InfeedBuffer*>& buffers);
-
   // Returns a cached stream associated with an executor. Allocates a
   // new stream on the first invocation. On subsequent invocations, if
   // the cached executor is not the same as the requested executor,
   // returns null.
   se::Stream* GetStream(se::StreamExecutor* executor);
 
-  // Registers a callback that will be called when 'enqueued_buffer_' becomes
-  // empty.
-  void RegisterOnEmptyCallback(std::function<void()> callback);
-
  private:
-  // TODO(b/30467474): Revisit if this mutex becomes a point of
-  // contention.
-  tensorflow::mutex mu_;
-
-  // Condition variable that is signaled every time a buffer is
-  // enqueued to an empty queue.
-  tensorflow::condition_variable cv_;
-
-  // InfeedBuffer* queue contents are not owned, but buffer->Done must
-  // be called when the buffer is no longer needed by the runtime.
-  std::deque<InfeedBuffer*> enqueued_buffer_;
-
-  // Buffers that are dequeued and currently being processed by the
-  // runtime. Not owned.
-  tensorflow::gtl::FlatSet<InfeedBuffer*> dequeued_buffer_;
+  // Mutex for serializing the creation of host_to_device_stream_.
+  tensorflow::mutex host_to_device_stream_mu_;
 
   // Cached host to device stream for queuing infeed data.
-  std::unique_ptr<se::Stream> host_to_device_stream_;
+  std::unique_ptr<se::Stream> host_to_device_stream_
+      GUARDED_BY(host_to_device_stream_mu_);
 
   // Executor that the host_to_device_stream belongs to. Not owned.
-  se::StreamExecutor* host_to_device_executor_;
-
-  // List of callbacks which will be called when 'enqueued_buffer_' becomes
-  // empty.
-  std::vector<std::function<void()>> on_empty_callbacks_;
+  se::StreamExecutor* host_to_device_executor_ = nullptr;
 };
 
 // Singleton creator-or-accessor: Returns the GPU infeed manager.
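A note on the InfeedBuffer rewrite above: the old class required callers to invoke Done() (and the manager to call ReleaseBuffers) to free device memory; the new one delegates ownership to se::ScopedDeviceMemory, so a ShapeTree<InfeedBuffer> releases its allocations automatically, including on error paths. The following sketch illustrates the same move-only RAII idea with plain host memory; ScopedBuffer and its malloc-based "allocator" are illustrative stand-ins, not the StreamExecutor API.

// Toy illustration of the ownership change: a move-only RAII wrapper frees
// its allocation in the destructor, so a container of buffers (standing in
// for ShapeTree<InfeedBuffer>) cleans up without any Done()/ReleaseBuffers.
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

class ScopedBuffer {
 public:
  ScopedBuffer() = default;
  explicit ScopedBuffer(size_t length)
      : data_(std::malloc(length)), length_(length) {}
  ~ScopedBuffer() { std::free(data_); }  // freeing nullptr is a no-op

  // Move-only, mirroring the unique-ownership semantics of
  // se::ScopedDeviceMemory; the moved-from buffer is left empty.
  ScopedBuffer(ScopedBuffer&& other) noexcept
      : data_(std::exchange(other.data_, nullptr)),
        length_(std::exchange(other.length_, 0)) {}
  ScopedBuffer& operator=(ScopedBuffer&& other) noexcept {
    std::free(data_);
    data_ = std::exchange(other.data_, nullptr);
    length_ = std::exchange(other.length_, 0);
    return *this;
  }

  void* data() { return data_; }
  size_t length() const { return length_; }

 private:
  void* data_ = nullptr;
  size_t length_ = 0;
};

int main() {
  std::vector<ScopedBuffer> buffers;  // stands in for ShapeTree<InfeedBuffer>
  buffers.emplace_back(ScopedBuffer(128));
  buffers.emplace_back(ScopedBuffer(256));
  std::cout << "second buffer holds " << buffers[1].length() << " bytes\n";
  // Every allocation is released when `buffers` goes out of scope.
}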
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index 62915febb11..964efd36573 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -38,9 +38,10 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
   se::DeviceMemoryBase data_address =
       buffer_allocations.GetDeviceAddress(infeed_slices_.element({0}));
   InfeedManager* infeed_manager = GetOrCreateInfeedManager();
-  std::vector<InfeedBuffer*> infeed_buffers;
   const Shape& data_shape =
       ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0);
+  ShapeTree<InfeedBuffer> infeed_buffers =
+      infeed_manager->BlockingGetNextDestination();
   if (ShapeUtil::IsTuple(data_shape)) {
     CHECK(!ShapeUtil::IsNestedTuple(data_shape));
     // Transfer the tuple elements first.
@@ -51,8 +52,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
       se::DeviceMemoryBase tuple_element_address =
           buffer_allocations.GetDeviceAddress(tuple_element_buffer);
 
-      InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
-      infeed_buffers.push_back(buffer);
+      InfeedBuffer* buffer = infeed_buffers.mutable_element({i});
       stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()),
                          buffer->length());
       tuple_element_addresses.push_back(tuple_element_address.opaque());
@@ -62,19 +62,17 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
     stream->ThenMemcpy(&data_address, tuple_element_addresses.data(),
                        host_size);
   } else {
-    InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
-    infeed_buffers.push_back(buffer);
+    InfeedBuffer* buffer = infeed_buffers.mutable_element({});
     stream->ThenMemcpy(&data_address, *(buffer->device_memory()),
                        buffer->length());
   }
 
   // Construct top-level tuple of infeed containing the data and the token. Use
   // a nullptr for the token, it should never be dereferenced.
-  std::vector<void*> infeed_addresses = {data_address.opaque(), nullptr};
+  void* infeed_addresses[] = {data_address.opaque(), nullptr};
   se::DeviceMemoryBase top_level_address =
       buffer_allocations.GetDeviceAddress(infeed_slices_.element({}));
-  stream->ThenMemcpy(&top_level_address, infeed_addresses.data(),
-                     2 * sizeof(void*));
+  stream->ThenMemcpy(&top_level_address, infeed_addresses, 2 * sizeof(void*));
 
   Status block_status = stream->BlockHostUntilDone();
   if (!block_status.ok()) {
@@ -82,8 +80,6 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
                          stream, block_status.error_message().c_str());
   }
 
-  infeed_manager->ReleaseBuffers(infeed_buffers);
-
   VLOG(2) << "Infeeding to GPU complete";
 
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
index 47744548b99..4aaf0c9e142 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
@@ -23,25 +23,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-void OutfeedManager::EnqueueOutfeedDestination(
-    ShapeTree<std::unique_ptr<OutfeedBuffer>>* buffers) {
-  tensorflow::mutex_lock l(mu_);
-  enqueued_buffers_.push_back(buffers);
-  cv_.notify_one();
-}
-
-ShapeTree<std::unique_ptr<OutfeedBuffer>>*
-OutfeedManager::BlockingGetNextOutfeedDestination() {
-  tensorflow::mutex_lock l(mu_);
-  while (enqueued_buffers_.empty()) {
-    cv_.wait(l);
-  }
-  ShapeTree<std::unique_ptr<OutfeedBuffer>>* current_buffer =
-      enqueued_buffers_.front();
-  enqueued_buffers_.pop_front();
-  return current_buffer;
-}
-
 OutfeedManager* GetOrCreateOutfeedManager() {
   static auto* manager = new OutfeedManager;
   return manager;
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
index f580c24e17f..a752eb70119 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
@@ -16,10 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
 
-#include <deque>
-#include <vector>
-
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h"
 #include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/notification.h"
@@ -60,28 +58,7 @@ class OutfeedBuffer {
 
 // Manages a thread-safe queue of buffers. The buffers are supposed to be
 // produced by the transfer manager and consumed by the device.
-class OutfeedManager {
- public:
-  // Adds a tree of buffers to the queue. The individual buffers correspond to
-  // the elements of a tuple and may be nullptr if the buffer is a tuple index
-  // buffer.
-  void EnqueueOutfeedDestination(
-      ShapeTree<std::unique_ptr<OutfeedBuffer>>* buffers);
-
-  // Blocks until the queue is non-empty, then returns the buffer at the head
-  // of the queue.
-  ShapeTree<std::unique_ptr<OutfeedBuffer>>*
-  BlockingGetNextOutfeedDestination();
-
- private:
-  tensorflow::mutex mu_;
-
-  // Condition variable that is signaled every time a buffer is enqueued.
-  tensorflow::condition_variable cv_;
-
-  // The queue of trees of buffers. OutfeedBuffer* queue contents are not owned.
-  std::deque<ShapeTree<std::unique_ptr<OutfeedBuffer>>*> enqueued_buffers_;
-};
+using OutfeedManager = XfeedQueue<ShapeTree<std::unique_ptr<OutfeedBuffer>>*>;
 
 // Singleton creator-or-accessor: Returns the GPU outfeed manager.
 OutfeedManager* GetOrCreateOutfeedManager();
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
index 4c0f1421e9f..7986e63f43e 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
@@ -36,7 +36,7 @@ Status OutfeedThunk::ExecuteOnStream(
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
   OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
   ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
-      outfeed_manager->BlockingGetNextOutfeedDestination();
+      outfeed_manager->BlockingGetNextDestination();
 
   // Nothing to be done for empty tuples.
   if (ShapeUtil::IsEmptyTuple(hlo_instruction()->operand(0)->shape())) {
diff --git a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h
new file mode 100644
index 00000000000..737c7eb0253
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+namespace gpu {
+
+// TODO(b/30467474) Once GPU outfeed implementation settles, consider
+// folding back the cpu and gpu outfeed implementations into a generic
+// one if possible.
+
+// Manages a thread-safe queue of buffers.
+template <typename BufferType>
+class XfeedQueue {
+ public:
+  // Adds a tree of buffers to the queue. The individual buffers correspond to
+  // the elements of a tuple and may be nullptr if the buffer is a tuple index
+  // buffer.
+  void EnqueueDestination(BufferType buffers) {
+    tensorflow::mutex_lock l(mu_);
+    enqueued_buffers_.push_back(std::move(buffers));
+    cv_.notify_one();
+  }
+
+  // Blocks until the queue is non-empty, then returns the buffer at the head
+  // of the queue.
+  BufferType BlockingGetNextDestination() {
+    bool became_empty;
+    BufferType current_buffer;
+    {
+      tensorflow::mutex_lock l(mu_);
+      while (enqueued_buffers_.empty()) {
+        cv_.wait(l);
+      }
+      current_buffer = std::move(enqueued_buffers_.front());
+      enqueued_buffers_.pop_front();
+      became_empty = enqueued_buffers_.empty();
+    }
+    if (became_empty) {
+      for (const auto& callback : on_empty_callbacks_) {
+        callback();
+      }
+    }
+    return current_buffer;
+  }
+
+  void RegisterOnEmptyCallback(std::function<void()> callback) {
+    on_empty_callbacks_.push_back(std::move(callback));
+  }
+
+ private:
+  tensorflow::mutex mu_;
+
+  // Condition variable that is signaled every time a buffer is enqueued.
+  tensorflow::condition_variable cv_;
+
+  // The queue of trees of buffers. Buffer* queue contents are not owned.
+  std::deque<BufferType> enqueued_buffers_ GUARDED_BY(mu_);
+
+  // List of callbacks which will be called when 'enqueued_buffers_' becomes
+  // empty.
+  std::vector<std::function<void()>> on_empty_callbacks_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
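The new XfeedQueue<BufferType> is the heart of the unification: the infeed manager instantiates it with an owning payload (ShapeTree<InfeedBuffer>) and the outfeed manager with a non-owning pointer payload, but both share one blocking FIFO with on-empty callbacks. The sketch below is a minimal standalone rendering of that same pattern, with std::mutex and std::condition_variable substituted for the tensorflow:: wrappers so it compiles by itself; the std::string payload and the producer thread in main are illustrative stand-ins, not XLA code.

// Standalone sketch of the XfeedQueue pattern: a blocking FIFO whose
// consumers can register callbacks that fire when the queue drains.
#include <condition_variable>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

template <typename BufferType>
class XfeedQueue {
 public:
  void EnqueueDestination(BufferType buffers) {
    {
      std::lock_guard<std::mutex> l(mu_);
      enqueued_buffers_.push_back(std::move(buffers));
    }
    cv_.notify_one();
  }

  BufferType BlockingGetNextDestination() {
    bool became_empty;
    BufferType current;
    {
      std::unique_lock<std::mutex> l(mu_);
      cv_.wait(l, [this] { return !enqueued_buffers_.empty(); });
      current = std::move(enqueued_buffers_.front());
      enqueued_buffers_.pop_front();
      became_empty = enqueued_buffers_.empty();
    }
    // Run callbacks outside the lock, as in the patch, so a callback that
    // re-enqueues work cannot deadlock against mu_.
    if (became_empty) {
      for (const auto& cb : on_empty_callbacks_) cb();
    }
    return current;
  }

  // Not thread-safe; register callbacks before producers/consumers start.
  void RegisterOnEmptyCallback(std::function<void()> cb) {
    on_empty_callbacks_.push_back(std::move(cb));
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<BufferType> enqueued_buffers_;
  std::vector<std::function<void()>> on_empty_callbacks_;
};

int main() {
  // A std::string payload stands in for XfeedQueue<ShapeTree<InfeedBuffer>>.
  XfeedQueue<std::string> infeed;
  infeed.RegisterOnEmptyCallback([] { std::cout << "queue drained\n"; });

  std::thread producer([&] {
    infeed.EnqueueDestination("batch-0");
    infeed.EnqueueDestination("batch-1");
  });
  for (int i = 0; i < 2; ++i) {
    std::cout << "got " << infeed.BlockingGetNextDestination() << "\n";
  }
  producer.join();
}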