Add an unbounded work queue based on the existing UnboundedThreadPool implementation.

This change adds `UnboundedWorkQueue` to tensorflow/core/platform for general use in the TensorFlow runtime. The implementation is essentially the same as the existing tf.data unbounded thread pool. After this change, `UnboundedThreadPool` is a thin wrapper around `UnboundedWorkQueue`.

PiperOrigin-RevId: 259668662
parent 805b28132e
commit 2a4b5a3f23
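For orientation, a minimal usage sketch of the new API (the `UnboundedWorkQueue` constructor and `Schedule` are introduced by this change; the driver function, thread name, and closure count below are illustrative assumptions, not part of the commit):

#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"

namespace tensorflow {

// Hypothetical driver showing the intended call pattern.
void ExampleUsage() {
  UnboundedWorkQueue work_queue(Env::Default(), /*thread_name=*/"example");

  constexpr int kNumClosures = 8;
  BlockingCounter counter(kNumClosures);
  for (int i = 0; i < kNumClosures; ++i) {
    // Closures may block; the queue grows its pool of physical threads as
    // needed to keep making progress.
    work_queue.Schedule([&counter]() { counter.DecrementCount(); });
  }
  counter.Wait();  // Block until every scheduled closure has run.
  // ~UnboundedWorkQueue() joins the pooled threads when `work_queue` is
  // destroyed.
}

}  // namespace tensorflow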
@@ -626,6 +626,38 @@ filegroup(
     visibility = ["//visibility:private"],
 )
 
+cc_library(
+    name = "platform_unbounded_work_queue",
+    srcs = tf_platform_srcs([
+        "unbounded_work_queue.cc",
+    ]) + tf_platform_hdrs([
+        "unbounded_work_queue.h",
+    ]),
+    hdrs = ["platform/unbounded_work_queue.h"],
+    deps = [
+        ":core_cpu_internal",
+        ":framework",
+        ":lib",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
+tf_cc_test(
+    name = "platform_unbounded_work_queue_test",
+    srcs = ["platform/unbounded_work_queue_test.cc"],
+    deps = [
+        ":framework",
+        ":lib",
+        ":lib_internal",
+        ":lib_test_internal",
+        ":platform_unbounded_work_queue",
+        ":protos_all_cc",
+        ":test",
+        ":test_main",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
 # Headers that are not exported as part of ":lib".
 filegroup(
     name = "platform_other_internal_hdrs",
@@ -180,6 +180,7 @@ cc_library(
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:platform_unbounded_work_queue",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -16,8 +16,9 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
 
 #include "absl/memory/memory.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 namespace data {
@@ -30,7 +31,7 @@ class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
 
   std::unique_ptr<Thread> StartThread(const string& name,
                                       std::function<void()> fn) override {
-    return pool_->RunOnPooledThread(std::move(fn));
+    return pool_->ScheduleOnWorkQueue(std::move(fn));
   }
 
  private:
@@ -52,8 +53,7 @@ class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
     // NOTE: The `Thread` destructor is expected to "join" the created thread,
     // but the physical thread may continue to execute after the work for this
     // thread is complete. We simulate this by waiting on a notification that
-    // the `CachedThreadFunc` will notify when the thread's work function is
-    // complete.
+    // the thread's work function will notify when it is complete.
     join_notification_->WaitForNotification();
   }
 
@@ -61,96 +61,25 @@ class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
   std::shared_ptr<Notification> join_notification_;
 };
 
-UnboundedThreadPool::~UnboundedThreadPool() {
-  {
-    mutex_lock l(work_queue_mu_);
-    // Wake up all `CachedThreadFunc` threads and cause them to terminate before
-    // joining them when `threads_` is cleared.
-    cancelled_ = true;
-    work_queue_cv_.notify_all();
-    if (!work_queue_.empty()) {
-      LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was "
-                 << "deleted with pending work in its queue. This may indicate "
-                 << "a potential use-after-free bug.";
-    }
-  }
-
-  {
-    mutex_lock l(thread_pool_mu_);
-    // Clear the list of pooled threads, which will eventually terminate due to
-    // the previous notification.
-    //
-    // NOTE: It is safe to do this while holding `pooled_threads_mu_`, because
-    // no subsequent calls to `this->StartThread()` should be issued after the
-    // destructor starts.
-    thread_pool_.clear();
-  }
-}
-
 std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
   return std::make_shared<LogicalThreadFactory>(this);
 }
 
-size_t UnboundedThreadPool::size() {
-  tf_shared_lock l(thread_pool_mu_);
-  return thread_pool_.size();
+namespace {
+void WorkQueueFunc(const std::function<void()>& fn,
+                   std::shared_ptr<Notification> notification) {
+  fn();
+  notification->Notify();
 }
+}  // namespace
 
-std::unique_ptr<Thread> UnboundedThreadPool::RunOnPooledThread(
+std::unique_ptr<Thread> UnboundedThreadPool::ScheduleOnWorkQueue(
     std::function<void()> fn) {
   auto join_notification = std::make_shared<Notification>();
-  bool all_threads_busy;
-  {
-    // Enqueue a work item for the new thread's function, and wake up a
-    // cached thread to process it.
-    mutex_lock l(work_queue_mu_);
-    work_queue_.push_back({std::move(fn), join_notification});
-    work_queue_cv_.notify_one();
-    // NOTE: The queue may be non-empty, so we must account for queued work when
-    // considering how many threads are free.
-    all_threads_busy = work_queue_.size() > num_idle_threads_;
-  }
-
-  if (all_threads_busy) {
-    // Spawn a new physical thread to process the given function.
-    // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
-    // at the beginning of its work loop.
-    Thread* new_thread = env_->StartThread(
-        {}, thread_name_,
-        std::bind(&UnboundedThreadPool::PooledThreadFunc, this));
-
-    mutex_lock l(thread_pool_mu_);
-    thread_pool_.emplace_back(new_thread);
-  }
-
+  unbounded_work_queue_.Schedule(
+      std::bind(&WorkQueueFunc, std::move(fn), join_notification));
   return absl::make_unique<LogicalThreadWrapper>(std::move(join_notification));
 }
 
-void UnboundedThreadPool::PooledThreadFunc() {
-  while (true) {
-    WorkItem work_item;
-    {
-      mutex_lock l(work_queue_mu_);
-      ++num_idle_threads_;
-      while (!cancelled_ && work_queue_.empty()) {
-        // Wait for a new work function to be submitted, or the cache to be
-        // destroyed.
-        work_queue_cv_.wait(l);
-      }
-      if (cancelled_) {
-        return;
-      }
-      work_item = std::move(work_queue_.front());
-      work_queue_.pop_front();
-      --num_idle_threads_;
-    }
-
-    work_item.work_function();
-
-    // Notify any thread that has "joined" the cached thread for this work item.
-    work_item.done_notification->Notify();
-  }
-}
-
 }  // namespace data
 }  // namespace tensorflow
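The hunk above replaces the hand-rolled queue with a wrapper: the work function is bound together with a `Notification` so that the returned `Thread`'s destructor can still "join" the logical thread. A hedged standalone sketch of that pattern follows (the function and variable names are illustrative; only the `Notification`, `Schedule`, and `std::bind` usage mirror the code above):

#include <functional>
#include <memory>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"

namespace tensorflow {
namespace {

// Runs `fn`, then signals `done` so a caller can wait for completion.
void NotifyWhenDone(const std::function<void()>& fn,
                    std::shared_ptr<Notification> done) {
  fn();
  done->Notify();
}

// Illustrative caller: schedules work, then later "joins" it.
void ExampleJoinableWork(UnboundedWorkQueue* queue,
                         std::function<void()> work) {
  auto done = std::make_shared<Notification>();
  queue->Schedule(std::bind(&NotifyWhenDone, std::move(work), done));
  // ... do other things ...
  done->WaitForNotification();  // The logical "join".
}

}  // namespace
}  // namespace tensorflow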
@@ -20,55 +20,33 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/framework/thread_factory.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/unbounded_work_queue.h"
 
 namespace tensorflow {
 namespace data {
 
 // An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a
 // potentially large number of "logical" threads onto a smaller number of
-// "physical" threads. The multiplexing is achieved by maintaining an internal
-// pool of long-running "physical" threads that are used to execute the
-// "logical" threads. Like a regular thread, a "logical" thread may block on
-// other threads, and the size of the pool will increase to ensure that progress
-// is made. This mechanism is recommended in situations where short-lived
-// threads are created repeatedly, to avoid the overhead and memory
-// fragmentation that can result from excessive thread creation.
+// "physical" threads. The multiplexing is achieved by using an
+// `UnboundedWorkQueue`.
 class UnboundedThreadPool {
  public:
   UnboundedThreadPool(Env* env, const string& thread_name)
-      : env_(env), thread_name_(thread_name) {}
-  ~UnboundedThreadPool();
+      : unbounded_work_queue_(env, thread_name) {}
+  ~UnboundedThreadPool() = default;
 
   // Returns an implementation of `ThreadFactory` that can be used to create
   // logical threads in this pool.
   std::shared_ptr<ThreadFactory> get_thread_factory();
 
-  // Returns the current number of threads in this pool.
-  size_t size();
-
  private:
   class LogicalThreadFactory;
   class LogicalThreadWrapper;
-  struct WorkItem {
-    std::function<void()> work_function;
-    std::shared_ptr<Notification> done_notification;
-  };
 
-  std::unique_ptr<Thread> RunOnPooledThread(std::function<void()> fn);
-  void PooledThreadFunc();
+  std::unique_ptr<Thread> ScheduleOnWorkQueue(std::function<void()> fn);
 
-  Env* const env_;  // Not owned.
-  const string thread_name_;
-  mutex work_queue_mu_;
-  condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_);
-  size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0;
-  bool cancelled_ GUARDED_BY(work_queue_mu_) = false;
-  std::deque<WorkItem> work_queue_ GUARDED_BY(work_queue_mu_);
-  mutex thread_pool_mu_;
-  std::vector<std::unique_ptr<Thread>> thread_pool_ GUARDED_BY(thread_pool_mu_);
+  UnboundedWorkQueue unbounded_work_queue_;
 };
 
 }  // namespace data
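To show how the slimmed-down class above is still used, here is a hedged sketch of the `ThreadFactory` flow it declares (the driver function, thread name, and counter are illustrative; `get_thread_factory` and `StartThread` are the existing API):

#include <atomic>
#include <memory>

#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace data {

// Illustrative driver; not part of the commit.
void ExampleLogicalThread() {
  UnboundedThreadPool pool(Env::Default(), /*thread_name=*/"example_pool");
  std::shared_ptr<ThreadFactory> factory = pool.get_thread_factory();

  std::atomic<int> counter(0);
  {
    // The closure is scheduled on the shared `UnboundedWorkQueue`; the
    // returned `Thread`'s destructor waits on the join notification.
    std::unique_ptr<Thread> thread =
        factory->StartThread("logical_thread", [&counter]() { ++counter; });
  }  // Destroying `thread` here "joins" the logical thread.
  // The closure has therefore run to completion: counter == 1.
}

}  // namespace data
}  // namespace tensorflow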
@@ -23,59 +23,6 @@ namespace tensorflow {
 namespace data {
 namespace {
 
-TEST(UnboundedThreadPool, SingleThread) {
-  UnboundedThreadPool pool(Env::Default(), "test");
-  auto thread_factory = pool.get_thread_factory();
-
-  // Create a thread that updates a variable, and ensure that it runs to
-  // completion.
-  std::atomic<int> i(0);
-  auto thread = thread_factory->StartThread("", [&i]() { ++i; });
-  thread.reset();
-
-  EXPECT_GE(pool.size(), 1);
-  EXPECT_EQ(1, i);
-}
-
-TEST(UnboundedThreadPool, MultipleThreads) {
-  UnboundedThreadPool pool(Env::Default(), "test");
-  auto thread_factory = pool.get_thread_factory();
-
-  // Create ten threads that update a variable, and ensure that they all run
-  // to completion.
-  std::vector<std::unique_ptr<Thread>> threads;
-  const int kNumThreadsToCreate = 10;
-  std::atomic<int> i(0);
-  for (int j = 0; j < kNumThreadsToCreate; ++j) {
-    threads.push_back(thread_factory->StartThread("", [&i]() { ++i; }));
-  }
-  threads.clear();
-
-  EXPECT_GE(pool.size(), 1);
-  EXPECT_EQ(i, kNumThreadsToCreate);
-}
-
-TEST(UnboundedThreadPool, MultipleThreadsSleepingRandomly) {
-  UnboundedThreadPool pool(Env::Default(), "test");
-  auto thread_factory = pool.get_thread_factory();
-
-  // Create 1000 threads that sleep for a random period of time then update a
-  // variable, and ensure that they all run to completion.
-  std::vector<std::unique_ptr<Thread>> threads;
-  const int kNumThreadsToCreate = 1000;
-  std::atomic<int> i(0);
-  for (int j = 0; j < kNumThreadsToCreate; ++j) {
-    threads.push_back(thread_factory->StartThread("", [&i]() {
-      Env::Default()->SleepForMicroseconds(random::New64() % 10);
-      ++i;
-    }));
-  }
-  threads.clear();
-
-  EXPECT_GE(pool.size(), 1);
-  EXPECT_EQ(i, kNumThreadsToCreate);
-}
-
 TEST(UnboundedThreadPool, ConcurrentThreadCreation) {
   UnboundedThreadPool pool(Env::Default(), "test");
   auto thread_factory = pool.get_thread_factory();
@@ -97,7 +44,6 @@ TEST(UnboundedThreadPool, ConcurrentThreadCreation) {
   }
   threads.clear();
 
-  EXPECT_GE(pool.size(), 1);
   EXPECT_EQ(i, kNumThreadsToCreate * kNumThreadsToCreate);
 }
 
@@ -108,9 +54,7 @@ TEST(UnboundedThreadPool, MultipleBlockingThreads) {
   std::vector<std::unique_ptr<Thread>> threads;
 
   // Create multiple waves (with increasing sizes) of threads that all block
-  // before returning, and
-  // ensure that we create the appropriate number of threads and terminate
-  // correctly.
+  // before returning, and ensure that we terminate correctly.
   std::vector<int> round_sizes = {5, 10, 15, 20};
 
   for (const int round_size : round_sizes) {
@@ -129,10 +73,6 @@ TEST(UnboundedThreadPool, MultipleBlockingThreads) {
     // wave is increasing, we should have at least that number of threads in the
     // pool.
     bc.Wait();
-    // NOTE: There is a benign race between a new round starting and the
-    // physical threads from the previous round returning to the pool, so we may
-    // create more threads than the round_size.
-    EXPECT_GE(pool.size(), round_size);
    n.Notify();
    threads.clear();
  }
tensorflow/core/platform/default/unbounded_work_queue.cc (new file, 101 lines)

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/unbounded_work_queue.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

UnboundedWorkQueue::UnboundedWorkQueue(Env* env, const string& thread_name)
    : env_(env), thread_name_(thread_name) {}

UnboundedWorkQueue::~UnboundedWorkQueue() {
  {
    mutex_lock l(work_queue_mu_);
    // Wake up all `PooledThreadFunc` threads and cause them to terminate before
    // joining them when `threads_` is cleared.
    cancelled_ = true;
    work_queue_cv_.notify_all();
    if (!work_queue_.empty()) {
      LOG(ERROR) << "UnboundedWorkQueue named \"" << thread_name_ << "\" was "
                 << "deleted with pending work in its queue. This may indicate "
                 << "a potential use-after-free bug.";
    }
  }

  {
    mutex_lock l(thread_pool_mu_);
    // Clear the list of pooled threads, which will eventually terminate due to
    // the previous notification.
    //
    // NOTE: It is safe to do this while holding `thread_pool_mu_`, because
    // no subsequent calls to `this->Schedule()` should be issued after the
    // destructor starts.
    thread_pool_.clear();
  }
}

void UnboundedWorkQueue::Schedule(WorkFunction fn) {
  bool all_threads_busy;
  {
    // Enqueue a work item for the new thread's function, and wake up a
    // cached thread to process it.
    mutex_lock l(work_queue_mu_);
    work_queue_.push_back(std::move(fn));
    work_queue_cv_.notify_one();
    // NOTE: The queue may be non-empty, so we must account for queued work when
    // considering how many threads are free.
    all_threads_busy = work_queue_.size() > num_idle_threads_;
  }

  if (all_threads_busy) {
    // Spawn a new physical thread to process the given function.
    // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
    // at the beginning of its work loop.
    Thread* new_thread =
        env_->StartThread({}, thread_name_, [this]() { PooledThreadFunc(); });

    mutex_lock l(thread_pool_mu_);
    thread_pool_.emplace_back(new_thread);
  }
}

void UnboundedWorkQueue::PooledThreadFunc() {
  while (true) {
    WorkFunction fn;
    {
      mutex_lock l(work_queue_mu_);
      ++num_idle_threads_;
      while (!cancelled_ && work_queue_.empty()) {
        // Wait for a new work function to be submitted, or the cache to be
        // destroyed.
        work_queue_cv_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      fn = std::move(work_queue_.front());
      work_queue_.pop_front();
      --num_idle_threads_;
    }

    fn();
  }
}

}  // namespace tensorflow
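As a concrete reading of the spawn decision in `Schedule` above: if two pooled threads are idle (`num_idle_threads_ == 2`) and the push leaves three items in the queue, then `work_queue_.size() > num_idle_threads_` (3 > 2) holds and a new physical thread is started; if the push left only one item queued, the idle threads would absorb it and no thread would be created.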
tensorflow/core/platform/default/unbounded_work_queue.h (new file, 65 lines)

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_

#include <deque>
#include <memory>
#include <vector>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

// An `UnboundedWorkQueue` provides a mechanism for temporally multiplexing a
// potentially large number of "logical" threads onto a smaller number of
// "physical" threads. The multiplexing is achieved by maintaining an internal
// pool of long-running "physical" threads that are used to execute the
// "logical" threads. Like a regular thread, a "logical" thread may block on
// other threads, and the size of the pool will increase to ensure that progress
// is made. This mechanism is recommended in situations where short-lived
// threads are created repeatedly, to avoid the overhead and memory
// fragmentation that can result from excessive thread creation.
class UnboundedWorkQueue {
 public:
  UnboundedWorkQueue(Env* env, const string& thread_name);
  ~UnboundedWorkQueue();

  using WorkFunction = std::function<void()>;

  // Schedule `fn` on a thread. `fn` may perform blocking work, so if all the
  // existing threads are blocked or busy, this may spawn a new thread which
  // will be added to the thread pool managed by this work queue.
  void Schedule(WorkFunction fn);

 private:
  void PooledThreadFunc();

  Env* const env_;  // Not owned.
  const string thread_name_;
  mutex work_queue_mu_;
  condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_);
  size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0;
  bool cancelled_ GUARDED_BY(work_queue_mu_) = false;
  std::deque<WorkFunction> work_queue_ GUARDED_BY(work_queue_mu_);
  mutex thread_pool_mu_;
  std::vector<std::unique_ptr<Thread>> thread_pool_ GUARDED_BY(thread_pool_mu_);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_
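A hedged sketch of the blocking behavior promised by the `Schedule` comment above (the driver function, thread name, and counts are illustrative; only the `UnboundedWorkQueue` API is taken from this header):

#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"

namespace tensorflow {

// Illustrative driver; not part of the commit.
void ExampleBlockingClosures() {
  UnboundedWorkQueue work_queue(Env::Default(), /*thread_name=*/"example");

  constexpr int kNumBlocked = 4;
  Notification start;
  BlockingCounter all_done(kNumBlocked);
  for (int i = 0; i < kNumBlocked; ++i) {
    work_queue.Schedule([&start, &all_done]() {
      // All closures block here, so the queue must grow to roughly
      // `kNumBlocked` physical threads to keep making progress.
      start.WaitForNotification();
      all_done.DecrementCount();
    });
  }
  start.Notify();   // Unblock every closure.
  all_done.Wait();  // Wait for all of them to finish.
}

}  // namespace tensorflow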
tensorflow/core/platform/unbounded_work_queue.h (new file, 33 lines)

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_
#define TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_

#include "tensorflow/core/platform/platform.h"

// An `UnboundedWorkQueue` feeds potentially-blocking work into a thread-pool
// whose size automatically increases with demand.

#if defined(PLATFORM_GOOGLE)
#include "tensorflow/core/platform/google/unbounded_work_queue.h"
#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
    defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_WINDOWS)
#include "tensorflow/core/platform/default/unbounded_work_queue.h"
#else
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif

#endif  // TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_
tensorflow/core/platform/unbounded_work_queue_test.cc
Normal file
104
tensorflow/core/platform/unbounded_work_queue_test.cc
Normal file
@ -0,0 +1,104 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/unbounded_work_queue.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class UnboundedWorkQueueTest : public ::testing::Test {
|
||||
protected:
|
||||
UnboundedWorkQueueTest()
|
||||
: work_queue_(
|
||||
absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test")) {}
|
||||
~UnboundedWorkQueueTest() override = default;
|
||||
|
||||
void RunMultipleCopiesOfClosure(const int num_closures,
|
||||
std::function<void()> fn) {
|
||||
for (int i = 0; i < num_closures; ++i) {
|
||||
work_queue_->Schedule([this, fn]() {
|
||||
fn();
|
||||
mutex_lock l(mu_);
|
||||
++closure_count_;
|
||||
cond_var_.notify_all();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void BlockUntilClosuresDone(const int num_closures) {
|
||||
mutex_lock l(mu_);
|
||||
while (closure_count_ < num_closures) {
|
||||
cond_var_.wait(l);
|
||||
}
|
||||
}
|
||||
|
||||
void ResetQueue() { work_queue_.reset(); }
|
||||
|
||||
int NumClosuresExecuted() {
|
||||
mutex_lock l(mu_);
|
||||
return closure_count_;
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
int closure_count_ GUARDED_BY(mu_) = 0;
|
||||
condition_variable cond_var_;
|
||||
std::unique_ptr<UnboundedWorkQueue> work_queue_;
|
||||
};
|
||||
|
||||
TEST_F(UnboundedWorkQueueTest, SingleClosure) {
|
||||
constexpr int num_closures = 1;
|
||||
RunMultipleCopiesOfClosure(num_closures, []() {});
|
||||
BlockUntilClosuresDone(num_closures);
|
||||
}
|
||||
|
||||
TEST_F(UnboundedWorkQueueTest, MultipleClosures) {
|
||||
constexpr int num_closures = 10;
|
||||
RunMultipleCopiesOfClosure(num_closures, []() {});
|
||||
BlockUntilClosuresDone(num_closures);
|
||||
}
|
||||
|
||||
TEST_F(UnboundedWorkQueueTest, MultipleClosuresSleepingRandomly) {
|
||||
constexpr int num_closures = 1000;
|
||||
RunMultipleCopiesOfClosure(num_closures, []() {
|
||||
Env::Default()->SleepForMicroseconds(random::New64() % 10);
|
||||
});
|
||||
BlockUntilClosuresDone(num_closures);
|
||||
}
|
||||
|
||||
TEST_F(UnboundedWorkQueueTest, NestedClosures) {
|
||||
constexpr int num_closures = 10;
|
||||
// Run `num_closures` closures, each of which runs `num_closures` closures.
|
||||
RunMultipleCopiesOfClosure(num_closures, [this]() {
|
||||
RunMultipleCopiesOfClosure(num_closures, []() {});
|
||||
});
|
||||
BlockUntilClosuresDone(num_closures * num_closures + num_closures);
|
||||
}
|
||||
|
||||
TEST_F(UnboundedWorkQueueTest, RacyDestructor) {
|
||||
constexpr int num_closures = 100;
|
||||
// Run `num_closures` closures, then delete `work_queue_`.
|
||||
RunMultipleCopiesOfClosure(num_closures, []() {});
|
||||
ResetQueue();
|
||||
EXPECT_LE(NumClosuresExecuted(), num_closures);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|