[XLA:GPU] Unify infeed and outfeed queue implementations
This makes the infeed queue behave like the outfeed queue and merges the two implementations. It shouldn't change functionality. There was also quite a bit of unused code in infeed_manager that's gone now.

PiperOrigin-RevId: 204273197
commit eccc1d4c10
parent 6cc6383922
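For orientation before the diff: after this change both feed managers share a single queue implementation; the infeed manager inherits from it and the outfeed manager becomes a type alias for it. The sketch below is a compile-only illustration of that relationship with the XLA types stubbed out as empty placeholders; the real declarations are in infeed_manager.h, outfeed_manager.h, and the new xfeed_queue.h further down.

```cpp
#include <memory>

// Empty stand-ins for the real XLA types; only the relationships below are
// taken from this change, everything else here is a placeholder.
template <typename T> class ShapeTree {};
class InfeedBuffer {};
class OutfeedBuffer {};

// Stand-in for the new XfeedQueue template (the real one is in xfeed_queue.h,
// shown at the end of this diff).
template <typename BufferType> class XfeedQueue {};

// The unified picture after this change:
class InfeedManager : public XfeedQueue<ShapeTree<InfeedBuffer>> {
  // Stream caching (GetStream) remains infeed-specific; see infeed_manager.h.
};
using OutfeedManager = XfeedQueue<ShapeTree<std::unique_ptr<OutfeedBuffer>>*>;

int main() { return 0; }
```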
@@ -542,6 +542,7 @@ cc_library(
        ":outfeed_manager",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_tree",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
@@ -639,14 +640,21 @@ cc_library(
    ],
)

cc_library(
    name = "xfeed_queue",
    hdrs = ["xfeed_queue.h"],
    deps = ["//tensorflow/core:lib"],
)

cc_library(
    name = "infeed_manager",
    srcs = ["infeed_manager.cc"],
    hdrs = ["infeed_manager.h"],
    deps = [
        ":xfeed_queue",
        "//tensorflow/compiler/xla:shape_tree",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
    ],
)
@@ -656,6 +664,7 @@ cc_library(
    srcs = ["outfeed_manager.cc"],
    hdrs = ["outfeed_manager.h"],
    deps = [
        ":xfeed_queue",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_tree",
        "//tensorflow/compiler/xla:shape_util",
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

// TODO(b/30467474) Once GPU infeed implementation settles, consider
// folding back the cpu and gpu infeed implementations into a generic
@@ -52,48 +53,37 @@ Status GpuTransferManager::TransferLiteralToInfeed(
  VLOG(2) << "Transferring literal to infeed with shape: "
          << ShapeUtil::HumanString(shape);

  if (!ShapeUtil::IsTuple(shape)) {
    int64 size = GetByteSizeRequirement(shape);
    return TransferBufferToInfeed(executor, size, literal.untyped_data());
  }

  // For a tuple, we transfer each of its elements to the device and
  // enqueue the resulting destination device addresses with the
  // infeed manager.
  std::vector<gpu::InfeedBuffer*> buffers;
  auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
    for (gpu::InfeedBuffer* b : buffers) {
      b->Done();
    }
  });
  ShapeTree<InfeedBuffer> buffer_tree(shape);

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      shape, [&](const Shape& literal_subshape, const ShapeIndex& index) {
        if (ShapeUtil::IsArray(literal_subshape)) {
          int64 tuple_element_size = GetByteSizeRequirement(literal_subshape);
          TF_ASSIGN_OR_RETURN(
              gpu::InfeedBuffer * buffer,
              *buffer_tree.mutable_element(index),
              TransferBufferToInfeedInternal(executor, tuple_element_size,
                                             literal.untyped_data(index)));
          buffers.push_back(buffer);
        }
        return Status::OK();
      }));

  cleanup.release();
  return EnqueueBuffersToInfeed(executor, buffers);
  return EnqueueBuffersToInfeed(executor, std::move(buffer_tree));
}

Status GpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor,
                                                  int64 size,
                                                  const void* source) {
  TF_ASSIGN_OR_RETURN(gpu::InfeedBuffer * buffer,
                      TransferBufferToInfeedInternal(executor, size, source));
  return EnqueueBuffersToInfeed(executor, {buffer});
  return InternalError(
      "Attempted to transfer data to infeed on a GPU device using "
      "TransferBufferToInfeed. This should be done using "
      "TransferLiteralToInfeed instead.");
}

Status GpuTransferManager::EnqueueBuffersToInfeed(
    se::StreamExecutor* executor, std::vector<gpu::InfeedBuffer*> buffers) {
    se::StreamExecutor* executor, ShapeTree<InfeedBuffer> buffers) {
  gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager();
  se::Stream* stream = infeed_manager->GetStream(executor);

@@ -103,21 +93,18 @@ Status GpuTransferManager::EnqueueBuffersToInfeed(
  // possible.
  Status block_status = stream->BlockHostUntilDone();
  if (!block_status.ok()) {
    for (gpu::InfeedBuffer* b : buffers) {
      b->Done();
    }
    return InternalError("Failed to complete data transfer on stream %p: %s",
                         stream, block_status.error_message().c_str());
  }

  infeed_manager->EnqueueBuffers(buffers);
  infeed_manager->EnqueueDestination(std::move(buffers));

  VLOG(2) << "Infeed data transferred";

  return Status::OK();
}

StatusOr<gpu::InfeedBuffer*> GpuTransferManager::TransferBufferToInfeedInternal(
StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
    se::StreamExecutor* executor, int64 size, const void* source) {
  if (size > std::numeric_limits<int32>::max()) {
    return InvalidArgument("Infeed shape is too large: needs %lld bytes", size);
@@ -133,12 +120,12 @@ StatusOr<gpu::InfeedBuffer*> GpuTransferManager::TransferBufferToInfeedInternal(
    return InternalError("Failed to obtain a stream");
  }

  gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size);
  stream->ThenMemcpy(buffer->device_memory(), source, size);
  InfeedBuffer buffer(executor, size);
  stream->ThenMemcpy(buffer.device_memory(), source, size);

  VLOG(2) << "Queued infeed data on stream " << stream;

  return buffer;
  return std::move(buffer);
}

static std::unique_ptr<Literal> ShapeTreeToLiteral(
@@ -191,17 +178,18 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
  // Give the tree of buffers to the outfeed mananger. The device will fill it
  // while we're waiting for it below.
  gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager();
  outfeed_manager->EnqueueOutfeedDestination(&outfeed_buffers);
  outfeed_manager->EnqueueDestination(&outfeed_buffers);

  // Now turn the tree of buffers back into a literal.
  *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers));
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla

static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() {
  return xla::MakeUnique<xla::GpuTransferManager>();
  return xla::MakeUnique<xla::gpu::GpuTransferManager>();
}

static bool InitModule() {
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

// An implementation of the XLA GenericTransferManager that
// handles GPU-specific infeed.
@@ -47,17 +49,18 @@ class GpuTransferManager : public GenericTransferManager {
 private:
  // Initiates the infeed data transfers. InfeedBuffer->Done() must be
  // called to clean up the memory allocated for InfeedBuffer.
  StatusOr<gpu::InfeedBuffer*> TransferBufferToInfeedInternal(
  StatusOr<InfeedBuffer> TransferBufferToInfeedInternal(
      se::StreamExecutor* executor, int64 size, const void* source);

  // Enqueues infeed data buffers with the infeed manager after their
  // transfer completes.
  Status EnqueueBuffersToInfeed(se::StreamExecutor* executor,
                                std::vector<gpu::InfeedBuffer*> buffers);
                                ShapeTree<InfeedBuffer> buffers);

  TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager);
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
@@ -15,76 +15,13 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

InfeedManager::InfeedManager() : host_to_device_executor_(nullptr) {}

void InfeedManager::Reset() {
  tensorflow::mutex_lock l(mu_);
  CHECK(dequeued_buffer_.empty());
  for (auto buffer : enqueued_buffer_) {
    buffer->Done();
  }
  enqueued_buffer_.clear();
}

void InfeedManager::EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers) {
  tensorflow::mutex_lock l(mu_);
  bool was_empty = enqueued_buffer_.empty();
  for (gpu::InfeedBuffer* b : buffers) {
    enqueued_buffer_.push_back(b);
  }
  if (was_empty) {
    // This has the potential to suffer from the notified thread
    // immediately trying and failing to acquire mu_, but seems
    // preferable to the alternative of notifying outside the lock
    // on every enqueue.
    cv_.notify_one();
  }
}

InfeedBuffer* InfeedManager::BlockingDequeueBuffer() {
  bool became_empty = false;
  InfeedBuffer* current_buffer;
  {
    tensorflow::mutex_lock l(mu_);
    while (enqueued_buffer_.empty()) {
      cv_.wait(l);
    }
    current_buffer = enqueued_buffer_.front();
    enqueued_buffer_.pop_front();
    dequeued_buffer_.insert(current_buffer);
    if (enqueued_buffer_.empty()) {
      became_empty = true;
    }
  }
  if (became_empty) {
    for (const auto& callback : on_empty_callbacks_) {
      callback();
    }
  }
  return current_buffer;
}

void InfeedManager::ReleaseBuffers(const std::vector<InfeedBuffer*>& buffers) {
  {
    tensorflow::mutex_lock l(mu_);
    for (gpu::InfeedBuffer* b : buffers) {
      CHECK(ContainsKey(dequeued_buffer_, b));
      dequeued_buffer_.erase(b);
    }
  }
  for (gpu::InfeedBuffer* b : buffers) {
    b->Done();
  }
}

se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
  tensorflow::mutex_lock l(host_to_device_stream_mu_);
  if (host_to_device_executor_ == nullptr) {
    host_to_device_executor_ = executor;
    host_to_device_stream_ = MakeUnique<se::Stream>(executor);
@@ -100,10 +37,6 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
  return host_to_device_stream_.get();
}

void InfeedManager::RegisterOnEmptyCallback(std::function<void()> callback) {
  on_empty_callbacks_.push_back(std::move(callback));
}

InfeedManager* GetOrCreateInfeedManager() {
  static InfeedManager* manager = new InfeedManager;
  return manager;
@@ -20,12 +20,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_

#include <deque>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
@@ -47,90 +44,41 @@ namespace gpu {
// the client. The client manages the memory of the buffer.
class InfeedBuffer {
 public:
  InfeedBuffer() = default;
  InfeedBuffer(se::StreamExecutor* executor, int64 length)
      : executor_(executor), length_(length) {
    device_memory_ = executor_->AllocateArray<uint8>(length);
    CHECK(!device_memory_.is_null());
      : device_memory_(executor, executor->AllocateArray<uint8>(length)),
        length_(length) {
    CHECK(!device_memory_->is_null());
  }

  ~InfeedBuffer() { executor_->Deallocate(&device_memory_); }

  int64 length() const { return length_; }

  // Callback to signal that this buffer is consumed. This helps the
  // client to manage memory for the infeed buffers.
  void Done() { delete this; }

  se::DeviceMemoryBase* device_memory() { return &device_memory_; }
  se::DeviceMemoryBase* device_memory() { return device_memory_.ptr(); }

 private:
  se::StreamExecutor* executor_;  // Not owned.
  const int64 length_;
  se::DeviceMemoryBase device_memory_;
  se::ScopedDeviceMemory<uint8> device_memory_;
  int64 length_;
};

// Client-side class used to enqueue infeed buffers.
class InfeedManager {
class InfeedManager : public XfeedQueue<ShapeTree<InfeedBuffer>> {
 public:
  InfeedManager();

  // Calls the completion callback for any enqueued buffers that have
  // not been dequeued by the runtime, and empties the infeed
  // queue. Reset may not be called while a runtime computation is
  // processing a dequeued buffer. The only safe way to ensure this
  // condition is to call Reset when no computation is taking place.
  void Reset();

  // Adds a set of buffers to the infeed queue atomically. buffer->Done
  // will be called when the buffer will no longer be accessed by the
  // InfeedManager, either as a result of a call to Reset or because the
  // runtime has dequeued and used the buffer.
  void EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers);

  // Blocks until the infeed queue is non-empty, then returns the
  // buffer at the head of the queue. Adds the current buffer to the
  // to-be released set.
  InfeedBuffer* BlockingDequeueBuffer();

  // Releases a set of buffers from the to-be released set.
  void ReleaseBuffers(const std::vector<InfeedBuffer*>& buffers);

  // Returns a cached stream associated with an executor. Allocates a
  // new stream on the first invocation. On subsequent invocations, if
  // the cached executor is not the same as the requested executor,
  // returns null.
  se::Stream* GetStream(se::StreamExecutor* executor);

  // Registers a callback that will be called when 'enqueued_buffer_' becomes
  // empty.
  void RegisterOnEmptyCallback(std::function<void()> callback);

 private:
  // TODO(b/30467474): Revisit if this mutex becomes a point of
  // contention.
  tensorflow::mutex mu_;

  // Condition variable that is signaled every time a buffer is
  // enqueued to an empty queue.
  tensorflow::condition_variable cv_;

  // InfeedBuffer* queue contents are not owned, but buffer->Done must
  // be called when the buffer is no longer needed by the runtime.
  std::deque<InfeedBuffer*> enqueued_buffer_;

  // Buffers that are dequeued and currently being processed by the
  // runtime. Not owned.
  tensorflow::gtl::FlatSet<const InfeedBuffer*> dequeued_buffer_;
  // Mutex for serializing the creation of host_to_device_stream_.
  tensorflow::mutex host_to_device_stream_mu_;

  // Cached host to device stream for queuing infeed data.
  std::unique_ptr<se::Stream> host_to_device_stream_;
  std::unique_ptr<se::Stream> host_to_device_stream_
      GUARDED_BY(host_to_device_stream_mu_);

  // Executor that the host_to_device_stream belongs to. Not owned.
  se::StreamExecutor* host_to_device_executor_;

  // List of callbacks which will be called when 'enqueued_buffer_' becomes
  // empty.
  std::vector<std::function<void()>> on_empty_callbacks_;
  se::StreamExecutor* host_to_device_executor_ = nullptr;
};

// Singleton creator-or-accessor: Returns the GPU infeed manager.
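The InfeedBuffer rewrite above drops the manual Done()/delete-this protocol in favor of RAII ownership through se::ScopedDeviceMemory, so buffers are freed automatically when the ShapeTree holding them is destroyed. The following standalone sketch shows the same ownership pattern with a toy allocator; FakeExecutor and ScopedBuffer are illustrative stand-ins, not the StreamExecutor API.

```cpp
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

// Toy allocator standing in for se::StreamExecutor.
struct FakeExecutor {
  void* Allocate(std::size_t n) { return std::malloc(n); }
  void Deallocate(void* p) { std::free(p); }
};

// Move-only RAII wrapper, analogous in spirit to se::ScopedDeviceMemory:
// the destructor frees the allocation, so callers never call Done().
class ScopedBuffer {
 public:
  ScopedBuffer() = default;
  ScopedBuffer(FakeExecutor* executor, std::size_t length)
      : executor_(executor), data_(executor->Allocate(length)), length_(length) {}
  ScopedBuffer(ScopedBuffer&& other) noexcept { *this = std::move(other); }
  ScopedBuffer& operator=(ScopedBuffer&& other) noexcept {
    std::swap(executor_, other.executor_);
    std::swap(data_, other.data_);
    std::swap(length_, other.length_);
    return *this;
  }
  ~ScopedBuffer() {
    if (executor_ != nullptr) executor_->Deallocate(data_);
  }
  std::size_t length() const { return length_; }

 private:
  FakeExecutor* executor_ = nullptr;  // Not owned.
  void* data_ = nullptr;
  std::size_t length_ = 0;
};

int main() {
  FakeExecutor executor;
  // Buffers can be moved into a container (much like InfeedBuffer being moved
  // into a ShapeTree) and are freed when the container goes out of scope.
  std::vector<ScopedBuffer> buffers;
  buffers.emplace_back(&executor, 1024);
  buffers.emplace_back(&executor, 2048);
  std::cout << "total buffers: " << buffers.size() << "\n";
  return 0;
}
```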
@@ -38,9 +38,10 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
  se::DeviceMemoryBase data_address =
      buffer_allocations.GetDeviceAddress(infeed_slices_.element({0}));
  InfeedManager* infeed_manager = GetOrCreateInfeedManager();
  std::vector<InfeedBuffer*> infeed_buffers;
  const Shape& data_shape =
      ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0);
  ShapeTree<InfeedBuffer> infeed_buffers =
      infeed_manager->BlockingGetNextDestination();
  if (ShapeUtil::IsTuple(data_shape)) {
    CHECK(!ShapeUtil::IsNestedTuple(data_shape));
    // Transfer the tuple elements first.
@@ -51,8 +52,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
      se::DeviceMemoryBase tuple_element_address =
          buffer_allocations.GetDeviceAddress(tuple_element_buffer);

      InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
      infeed_buffers.push_back(buffer);
      InfeedBuffer* buffer = infeed_buffers.mutable_element({i});
      stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()),
                         buffer->length());
      tuple_element_addresses.push_back(tuple_element_address.opaque());
@@ -62,19 +62,17 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
    stream->ThenMemcpy(&data_address, tuple_element_addresses.data(),
                       host_size);
  } else {
    InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
    infeed_buffers.push_back(buffer);
    InfeedBuffer* buffer = infeed_buffers.mutable_element({});
    stream->ThenMemcpy(&data_address, *(buffer->device_memory()),
                       buffer->length());
  }

  // Construct top-level tuple of infeed containing the data and the token. Use
  // a nullptr for the token, it should never be dereferenced.
  std::vector<void*> infeed_addresses = {data_address.opaque(), nullptr};
  void* infeed_addresses[] = {data_address.opaque(), nullptr};
  se::DeviceMemoryBase top_level_address =
      buffer_allocations.GetDeviceAddress(infeed_slices_.element({}));
  stream->ThenMemcpy(&top_level_address, infeed_addresses.data(),
                     2 * sizeof(void*));
  stream->ThenMemcpy(&top_level_address, infeed_addresses, 2 * sizeof(void*));

  Status block_status = stream->BlockHostUntilDone();
  if (!block_status.ok()) {
@@ -82,8 +80,6 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         stream, block_status.error_message().c_str());
  }

  infeed_manager->ReleaseBuffers(infeed_buffers);

  VLOG(2) << "Infeeding to GPU complete";
  return Status::OK();
}
@@ -23,25 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {

void OutfeedManager::EnqueueOutfeedDestination(
    ShapeTree<std::unique_ptr<OutfeedBuffer>>* buffers) {
  tensorflow::mutex_lock l(mu_);
  enqueued_buffers_.push_back(buffers);
  cv_.notify_one();
}

ShapeTree<std::unique_ptr<OutfeedBuffer>>*
OutfeedManager::BlockingGetNextOutfeedDestination() {
  tensorflow::mutex_lock l(mu_);
  while (enqueued_buffers_.empty()) {
    cv_.wait(l);
  }
  ShapeTree<std::unique_ptr<OutfeedBuffer>>* current_buffer =
      enqueued_buffers_.front();
  enqueued_buffers_.pop_front();
  return current_buffer;
}

OutfeedManager* GetOrCreateOutfeedManager() {
  static auto* manager = new OutfeedManager;
  return manager;
@@ -16,10 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_

#include <deque>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
@@ -60,28 +58,7 @@ class OutfeedBuffer {

// Manages a thread-safe queue of buffers. The buffers are supposed to be
// produced by the transfer manager and consumed by the device.
class OutfeedManager {
 public:
  // Adds a tree of buffers to the queue. The individual buffers correspond to
  // the elements of a tuple and may be nullptr if the buffer is a tuple index
  // buffer.
  void EnqueueOutfeedDestination(
      ShapeTree<std::unique_ptr<OutfeedBuffer>>* buffers);

  // Blocks until the queue is non-empty, then returns the buffer at the head of
  // the queue.
  ShapeTree<std::unique_ptr<OutfeedBuffer>>*
  BlockingGetNextOutfeedDestination();

 private:
  tensorflow::mutex mu_;

  // Condition variable that is signaled every time a buffer is enqueued.
  tensorflow::condition_variable cv_;

  // The queue of trees of buffers. OutfeedBuffer* queue contents are not owned.
  std::deque<ShapeTree<std::unique_ptr<OutfeedBuffer>>*> enqueued_buffers_;
};
using OutfeedManager = XfeedQueue<ShapeTree<std::unique_ptr<OutfeedBuffer>>*>;

// Singleton creator-or-accessor: Returns the GPU outfeed manager.
OutfeedManager* GetOrCreateOutfeedManager();
@@ -36,7 +36,7 @@ Status OutfeedThunk::ExecuteOnStream(
  auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
  ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
      outfeed_manager->BlockingGetNextOutfeedDestination();
      outfeed_manager->BlockingGetNextDestination();

  // Nothing to be done for empty tuples.
  if (ShapeUtil::IsEmptyTuple(hlo_instruction()->operand(0)->shape())) {
tensorflow/compiler/xla/service/gpu/xfeed_queue.h (new file, 89 lines)
@@ -0,0 +1,89 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_

#include <deque>
#include <vector>

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {
namespace gpu {

// TODO(b/30467474) Once GPU outfeed implementation settles, consider
// folding back the cpu and gpu outfeed implementations into a generic
// one if possible.

// Manages a thread-safe queue of buffers.
template <typename BufferType>
class XfeedQueue {
 public:
  // Adds a tree of buffers to the queue. The individual buffers correspond to
  // the elements of a tuple and may be nullptr if the buffer is a tuple index
  // buffer.
  void EnqueueDestination(BufferType buffers) {
    tensorflow::mutex_lock l(mu_);
    enqueued_buffers_.push_back(std::move(buffers));
    cv_.notify_one();
  }

  // Blocks until the queue is non-empty, then returns the buffer at the head of
  // the queue.
  BufferType BlockingGetNextDestination() {
    bool became_empty;
    BufferType current_buffer;
    {
      tensorflow::mutex_lock l(mu_);
      while (enqueued_buffers_.empty()) {
        cv_.wait(l);
      }
      current_buffer = std::move(enqueued_buffers_.front());
      enqueued_buffers_.pop_front();
      became_empty = enqueued_buffers_.empty();
    }
    if (became_empty) {
      for (const auto& callback : on_empty_callbacks_) {
        callback();
      }
    }
    return current_buffer;
  }

  void RegisterOnEmptyCallback(std::function<void()> callback) {
    on_empty_callbacks_.push_back(std::move(callback));
  }

 private:
  tensorflow::mutex mu_;

  // Condition variable that is signaled every time a buffer is enqueued.
  tensorflow::condition_variable cv_;

  // The queue of trees of buffers. Buffer* queue contents are not owned.
  std::deque<BufferType> enqueued_buffers_ GUARDED_BY(mu_);

  // List of callbacks which will be called when 'enqueued_buffers_' becomes
  // empty.
  std::vector<std::function<void()>> on_empty_callbacks_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
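To make the behavior of the new queue concrete, here is a minimal usage sketch. It mirrors XfeedQueue with standard-library primitives (std::mutex and std::condition_variable in place of the tensorflow:: wrappers) so it compiles outside the TensorFlow tree; SimpleXfeedQueue is an illustrative stand-in, not the production class.

```cpp
#include <condition_variable>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// Simplified mirror of XfeedQueue<BufferType> using std:: primitives.
template <typename BufferType>
class SimpleXfeedQueue {
 public:
  void EnqueueDestination(BufferType buffers) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      enqueued_buffers_.push_back(std::move(buffers));
    }
    cv_.notify_one();
  }

  BufferType BlockingGetNextDestination() {
    bool became_empty;
    BufferType current;
    {
      std::unique_lock<std::mutex> lock(mu_);
      cv_.wait(lock, [this] { return !enqueued_buffers_.empty(); });
      current = std::move(enqueued_buffers_.front());
      enqueued_buffers_.pop_front();
      became_empty = enqueued_buffers_.empty();
    }
    // Fire the on-empty callbacks outside the lock, as the real class does.
    if (became_empty) {
      for (const auto& callback : on_empty_callbacks_) callback();
    }
    return current;
  }

  void RegisterOnEmptyCallback(std::function<void()> callback) {
    on_empty_callbacks_.push_back(std::move(callback));
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<BufferType> enqueued_buffers_;
  std::vector<std::function<void()>> on_empty_callbacks_;
};

int main() {
  // Producer (transfer-manager role) enqueues; consumer (thunk role) blocks.
  SimpleXfeedQueue<std::string> queue;
  queue.RegisterOnEmptyCallback([] { std::cout << "queue drained\n"; });

  std::thread consumer([&queue] {
    std::cout << "got: " << queue.BlockingGetNextDestination() << "\n";
  });
  queue.EnqueueDestination("infeed payload");
  consumer.join();
  return 0;
}
```

In the real code above, the producer role is played by GpuTransferManager (TransferLiteralToInfeed / TransferLiteralFromOutfeed) and the consumer role by InfeedThunk and OutfeedThunk.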