[SE] Fix FIFO ordering of HostStream.

The HostStream implementation assumed that closures enqueued on a thread pool of size 1 ran in FIFO order. This is not a property of the TF thread pool implementation.

Change the implementation to use a single worker thread and an explicit std::queue<>.

PiperOrigin-RevId: 246929773
This commit is contained in:
Peter Hawkins 2019-05-06 17:32:45 -07:00 committed by TensorFlower Gardener
parent 91faf99ea7
commit bb405dfb93
5 changed files with 105 additions and 35 deletions

View File

@ -4,6 +4,7 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow/stream_executor:build_defs.bzl", "stream_executor_friends")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package_group(
name = "friends",
@ -111,3 +112,18 @@ cc_library(
],
alwayslink = True,
)
tf_cc_test(
name = "host_stream_test",
srcs = ["host_stream_test.cc"],
deps = [
":host_platform",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor:platform",
"//tensorflow/stream_executor:stream",
"@com_google_absl//absl/synchronization",
],
)

View File

@ -17,49 +17,55 @@ limitations under the License.
// the HostExecutor implementation.
#include "tensorflow/stream_executor/host/host_stream.h"
#include "absl/synchronization/notification.h"
namespace stream_executor {
namespace host {
HostStream::HostStream()
: host_executor_(new port::ThreadPool(port::Env::Default(),
port::ThreadOptions(),
"host_executor", kExecutorThreads)) {}
HostStream::~HostStream() {}
bool HostStream::EnqueueTask(std::function<void()> task) {
struct NotifiedTask {
HostStream* stream;
std::function<void()> task;
void operator()() {
task();
// Destroy the task before unblocking its waiters, as BlockHostUntilDone()
// should guarantee that all tasks are destroyed.
task = std::function<void()>();
{
absl::MutexLock lock(&stream->mu_);
--stream->pending_tasks_;
}
stream->completion_condition_.SignalAll();
}
};
: thread_(port::Env::Default()->StartThread(
port::ThreadOptions(), "host_executor", [this]() { WorkLoop(); })) {}
HostStream::~HostStream() {
{
absl::MutexLock lock(&mu_);
++pending_tasks_;
work_queue_.push(nullptr);
}
host_executor_->Schedule(NotifiedTask{this, std::move(task)});
// thread_'s destructor blocks until the thread finishes running.
thread_.reset();
}
bool HostStream::EnqueueTask(std::function<void()> fn) {
CHECK(fn != nullptr);
absl::MutexLock lock(&mu_);
work_queue_.push(std::move(fn));
return true;
}
void HostStream::BlockUntilDone() {
absl::MutexLock lock(&mu_);
while (pending_tasks_ != 0) {
completion_condition_.Wait(&mu_);
bool HostStream::WorkAvailable() { return !work_queue_.empty(); }
void HostStream::WorkLoop() {
while (true) {
std::function<void()> fn;
{
absl::MutexLock lock(&mu_);
mu_.Await(absl::Condition(this, &HostStream::WorkAvailable));
fn = std::move(work_queue_.front());
work_queue_.pop();
}
if (!fn) {
return;
}
fn();
}
}
void HostStream::BlockUntilDone() {
absl::Notification done;
EnqueueTask([&done]() { done.Notify(); });
done.WaitForNotification();
}
} // namespace host
} // namespace stream_executor

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <functional>
#include <memory>
#include <queue>
#include "absl/synchronization/mutex.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
@ -41,14 +42,12 @@ class HostStream : public internal::StreamInterface {
void BlockUntilDone();
private:
// Use only one thread and own task queue to preserve FIFO ordering
// for the operations enqueued by any given stream.
static const int kExecutorThreads = 1;
std::unique_ptr<port::ThreadPool> host_executor_;
bool WorkAvailable() EXCLUSIVE_LOCKS_REQUIRED(mu_);
void WorkLoop();
absl::Mutex mu_;
int pending_tasks_ GUARDED_BY(mu_) = 0;
absl::CondVar completion_condition_;
std::queue<std::function<void()>> work_queue_ GUARDED_BY(mu_);
std::unique_ptr<port::Thread> thread_;
};
} // namespace host

View File

@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace se = stream_executor;
TEST(HostStream, EnforcesFIFOOrder) {
se::Platform* platform =
se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
se::StreamExecutor* executor = platform->ExecutorForDevice(0).ValueOrDie();
se::Stream stream(executor);
stream.Init();
absl::Mutex mu;
int expected = 0;
bool ok = true;
for (int i = 0; i < 2000; ++i) {
stream.ThenDoHostCallback([i, &mu, &expected, &ok]() {
absl::MutexLock lock(&mu);
if (expected != i) {
ok = false;
}
++expected;
});
}
TF_ASSERT_OK(stream.BlockHostUntilDone());
absl::MutexLock lock(&mu);
EXPECT_TRUE(ok);
}

View File

@ -23,6 +23,7 @@ limitations under the License.
namespace stream_executor {
namespace port {
using tensorflow::Thread;
using tensorflow::thread::ThreadPool;
} // namespace port