[TF:XLA] Delete unused producer consumer queue class.
PiperOrigin-RevId: 221265027
This commit is contained in:
parent
8cbd3301ec
commit
a466bbdb04
@ -544,25 +544,6 @@ cc_library(
|
|||||||
hdrs = ["union_find.h"],
|
hdrs = ["union_find.h"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "producer_consumer_queue",
|
|
||||||
hdrs = ["producer_consumer_queue.h"],
|
|
||||||
deps = ["//tensorflow/core:lib"],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cc_test(
|
|
||||||
name = "producer_consumer_queue_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["producer_consumer_queue_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":producer_consumer_queue",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/core:test_main",
|
|
||||||
"//tensorflow/core:testlib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "deadness_analysis_test",
|
name = "deadness_analysis_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -1,132 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
|
|
||||||
#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
|
|
||||||
|
|
||||||
#include <deque>
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
// A thread-safe, first-in-first-out queue.
|
|
||||||
template <typename T>
|
|
||||||
class ProducerConsumerQueue {
|
|
||||||
public:
|
|
||||||
ProducerConsumerQueue()
|
|
||||||
: capacity_(std::numeric_limits<std::size_t>::max()) {}
|
|
||||||
~ProducerConsumerQueue() = default;
|
|
||||||
|
|
||||||
// Wait until the queue is non-full, then append a copy of v.
|
|
||||||
void Put(const T &v);
|
|
||||||
|
|
||||||
// Wait until the queue is non-empty, then remove and return the head value.
|
|
||||||
T Get();
|
|
||||||
|
|
||||||
// If the queue is non-empty, remove the head value, placing it in *pv, and
|
|
||||||
// return true; otherwise return false.
|
|
||||||
bool TryGet(T *pv);
|
|
||||||
|
|
||||||
// Set the capacity of the queue; the queue is full whenever count() >=
|
|
||||||
// capacity(). The initial value is the maximum size_t. Requires size > 0.
|
|
||||||
void set_capacity(std::size_t size);
|
|
||||||
|
|
||||||
// Return the capacity of the queue.
|
|
||||||
std::size_t capacity() const;
|
|
||||||
|
|
||||||
// Return the number of elements in the queue.
|
|
||||||
std::size_t count() const;
|
|
||||||
|
|
||||||
// Implementation details follow. Clients should ignore.
|
|
||||||
private:
|
|
||||||
mutable tensorflow::mutex mu_; // protects all fields below
|
|
||||||
tensorflow::condition_variable non_empty_ GUARDED_BY(mu_);
|
|
||||||
tensorflow::condition_variable non_full_ GUARDED_BY(mu_);
|
|
||||||
std::size_t capacity_ GUARDED_BY(mu_);
|
|
||||||
std::deque<T> queue_ GUARDED_BY(mu_);
|
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue);
|
|
||||||
};
|
|
||||||
|
|
||||||
// ------------------------------------------------------
|
|
||||||
// Implementation details follow. Clients should ignore.
|
|
||||||
|
|
||||||
// Wait until the queue is non-full, then append a copy of v.
|
|
||||||
template <typename T>
|
|
||||||
void ProducerConsumerQueue<T>::Put(const T &v) {
|
|
||||||
mutex_lock lock(mu_);
|
|
||||||
while (queue_.size() >= capacity_) {
|
|
||||||
non_full_.wait(lock);
|
|
||||||
}
|
|
||||||
queue_.push_back(v);
|
|
||||||
non_empty_.notify_one();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait until the queue is non-empty, then remove and return the head value.
|
|
||||||
template <typename T>
|
|
||||||
T ProducerConsumerQueue<T>::Get() {
|
|
||||||
mutex_lock lock(mu_);
|
|
||||||
while (queue_.empty()) {
|
|
||||||
non_empty_.wait(lock);
|
|
||||||
}
|
|
||||||
non_full_.notify_one();
|
|
||||||
T result_value = queue_.front();
|
|
||||||
queue_.pop_front();
|
|
||||||
return result_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the queue is non-empty, remove the head value, placing it in *pv, and
|
|
||||||
// return true; otherwise return false.
|
|
||||||
template <typename T>
|
|
||||||
bool ProducerConsumerQueue<T>::TryGet(T *pv) {
|
|
||||||
mutex_lock lock(mu_);
|
|
||||||
bool got_element = !queue_.empty();
|
|
||||||
if (got_element) {
|
|
||||||
non_full_.notify_one();
|
|
||||||
*pv = queue_.front();
|
|
||||||
queue_.pop_front();
|
|
||||||
}
|
|
||||||
return got_element;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the capacity of the queue; the queue is full whenever count() >=
|
|
||||||
// capacity(). The initial value is the maximum size_t. Requires size > 0.
|
|
||||||
template <typename T>
|
|
||||||
void ProducerConsumerQueue<T>::set_capacity(std::size_t size) {
|
|
||||||
mutex_lock lock(mu_);
|
|
||||||
CHECK_NE(size, 0);
|
|
||||||
capacity_ = size;
|
|
||||||
non_full_.notify_all();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the capacity of the queue.
|
|
||||||
template <typename T>
|
|
||||||
std::size_t ProducerConsumerQueue<T>::capacity() const {
|
|
||||||
mutex_lock lock(mu_);
|
|
||||||
std::size_t max_elements = capacity_;
|
|
||||||
return max_elements;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the number of elements in the queue.
|
|
||||||
template <typename T>
|
|
||||||
std::size_t ProducerConsumerQueue<T>::count() const {
|
|
||||||
mutex_lock lock(mu_);
|
|
||||||
std::size_t num_elements = queue_.size();
|
|
||||||
return num_elements;
|
|
||||||
}
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
|
|
@ -1,139 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/jit/producer_consumer_queue.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
|
||||||
#include "tensorflow/core/platform/env.h"
|
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
typedef ProducerConsumerQueue<int> IntQueue;
|
|
||||||
|
|
||||||
// Insert integers between low inclusive and high exclusive into q.
|
|
||||||
void PushRange(IntQueue *q, int low, int high) {
|
|
||||||
while (low != high) {
|
|
||||||
q->Put(low);
|
|
||||||
VLOG(2) << "Pushing " << low;
|
|
||||||
++low;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Push the numbers between 0 and 999 inclusive from several threads in the
|
|
||||||
// pool.
|
|
||||||
void PushRanges(IntQueue *queue, thread::ThreadPool *pool) {
|
|
||||||
VLOG(1) << "Adding 20-36";
|
|
||||||
pool->Schedule([queue] { PushRange(queue, 20, 36); });
|
|
||||||
VLOG(1) << "Adding 7-20";
|
|
||||||
pool->Schedule([queue] { PushRange(queue, 7, 20); });
|
|
||||||
VLOG(1) << "Adding 36-501";
|
|
||||||
pool->Schedule([queue] { PushRange(queue, 36, 501); });
|
|
||||||
VLOG(1) << "Adding 501-1000";
|
|
||||||
pool->Schedule([queue] { PushRange(queue, 501, 1000); });
|
|
||||||
VLOG(1) << "Adding 0-5";
|
|
||||||
pool->Schedule([queue] { PushRange(queue, 0, 5); });
|
|
||||||
VLOG(1) << "Adding 5-7";
|
|
||||||
pool->Schedule([queue] { PushRange(queue, 5, 7); });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pop elements from queue using Get(). Make sure that exactly <high> elements
|
|
||||||
// were present and their values are all integers between 0 and high-1
|
|
||||||
// inclusive.
|
|
||||||
void GetRange(IntQueue *queue, int high) {
|
|
||||||
VLOG(1) << "Testing Wait";
|
|
||||||
std::vector<int> results;
|
|
||||||
for (int i = 0; i != high; ++i) {
|
|
||||||
int r = queue->Get();
|
|
||||||
VLOG(2) << "Waited and got " << r;
|
|
||||||
results.push_back(r);
|
|
||||||
}
|
|
||||||
CHECK_EQ(queue->count(), 0);
|
|
||||||
std::sort(results.begin(), results.end());
|
|
||||||
for (int i = 0; i != high; ++i) {
|
|
||||||
CHECK(results[i] == i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pop elements from queue using TryGet(). Make sure that exactly <high>
|
|
||||||
// elements were present and their values are all integers between 0 and high-1
|
|
||||||
// inclusive.
|
|
||||||
void TryGetRange(IntQueue *queue, int high) {
|
|
||||||
std::vector<int> results;
|
|
||||||
// Give up if we don't get all the elements back from the queue
|
|
||||||
// in 10 seconds.
|
|
||||||
int timeout = 10;
|
|
||||||
int r;
|
|
||||||
for (int i = 0; i != high; ++i) {
|
|
||||||
while (!queue->TryGet(&r)) {
|
|
||||||
if (!timeout--) {
|
|
||||||
LOG(FATAL) << "Can't find all elements in the queue";
|
|
||||||
}
|
|
||||||
VLOG(1) << "Sleeping for a second...";
|
|
||||||
sleep(1);
|
|
||||||
}
|
|
||||||
VLOG(2) << "Popped " << r;
|
|
||||||
results.push_back(r);
|
|
||||||
}
|
|
||||||
CHECK_EQ(queue->count(), 0);
|
|
||||||
CHECK(!queue->TryGet(&r));
|
|
||||||
std::sort(results.begin(), results.end());
|
|
||||||
for (int i = 0; i != high; ++i) {
|
|
||||||
CHECK_EQ(i, results[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const int kNumThreads = 15;
|
|
||||||
|
|
||||||
TEST(ProducerConsumerQueue, GetRange) {
|
|
||||||
IntQueue queue;
|
|
||||||
{
|
|
||||||
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
|
|
||||||
PushRanges(&queue, &pool);
|
|
||||||
}
|
|
||||||
GetRange(&queue, 1000);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ProducerConsumerQueue, TryGetRange) {
|
|
||||||
IntQueue queue;
|
|
||||||
{
|
|
||||||
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
|
|
||||||
PushRanges(&queue, &pool);
|
|
||||||
}
|
|
||||||
TryGetRange(&queue, 1000);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ProducerConsumerQueue, ParallelGetRange) {
|
|
||||||
IntQueue queue;
|
|
||||||
{
|
|
||||||
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
|
|
||||||
pool.Schedule([&queue] { GetRange(&queue, 1000); });
|
|
||||||
PushRanges(&queue, &pool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ProducerConsumerQueue, ParallelTryGetRange) {
|
|
||||||
IntQueue queue;
|
|
||||||
{
|
|
||||||
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
|
|
||||||
pool.Schedule([&queue] { TryGetRange(&queue, 1000); });
|
|
||||||
PushRanges(&queue, &pool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace tensorflow
|
|
Loading…
Reference in New Issue
Block a user