[tf.data] Add an unbounded thread pool to iterator resources.

The previous implementation of many core `tf.data` transformations
(e.g. `Dataset.prefetch()`) would create one or more threads each time
an iterator over those datasets is created
(e.g. `ds.prefetch(N).repeat(100)` would create and destroy 100
threads). In addition to the overhead of thread creation, this
interacts poorly with some malloc implementations, and can contribute
to memory fragmentation.

The new implementation maintains an unbounded pool of physical threads
in each iterator (or `MultiDeviceIterator`) resource, and returns logical
"threads" to that pool when their work is complete instead of exiting
from them.

PiperOrigin-RevId: 236413014
This commit is contained in:
Derek Murray 2019-03-01 18:22:18 -08:00 committed by TensorFlower Gardener
parent 4006467c5f
commit 70da1fe25d
17 changed files with 556 additions and 80 deletions

View File

@ -925,6 +925,7 @@ tf_cuda_library(
"framework/tensor_slice.h",
"framework/tensor_types.h",
"framework/tensor_util.h",
"framework/thread_factory.h",
"framework/tracking_allocator.h",
"framework/type_index.h",
"framework/type_traits.h",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <unordered_map>
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
@ -287,7 +289,8 @@ class IteratorContext {
model(ctx->model()),
runner(*(ctx->runner())),
runner_threadpool_size(ctx->runner_threadpool_size()),
stats_aggregator(ctx->stats_aggregator()) {}
stats_aggregator(ctx->stats_aggregator()),
thread_factory(ctx->thread_factory()) {}
explicit Params(OpKernelContext* ctx)
: env(ctx->env()),
@ -338,6 +341,10 @@ class IteratorContext {
// The `StatsAggregator` object to record statistics about the iterator.
std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
// A `ThreadFactory` for creating threads used by iterators to perform
// blocking work.
std::shared_ptr<ThreadFactory> thread_factory = nullptr;
};
explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
@ -374,6 +381,20 @@ class IteratorContext {
return &params_.runner;
}
const std::shared_ptr<ThreadFactory>& thread_factory() {
return params_.thread_factory;
}
std::unique_ptr<Thread> StartThread(const string& name,
std::function<void()> fn) {
if (params_.thread_factory) {
return params_.thread_factory->StartThread(name, std::move(fn));
} else {
return absl::WrapUnique(
Env::Default()->StartThread({}, name, std::move(fn)));
}
}
int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
std::shared_ptr<StatsAggregator> stats_aggregator() {

View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
#define TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
#include <functional>
#include <memory>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class Thread;
// Virtual interface for an object that creates threads.
class ThreadFactory {
public:
virtual ~ThreadFactory() {}
// Runs `fn` asynchronously in a different thread. `fn` may block.
//
// NOTE: The caller is responsible for ensuring that this `ThreadFactory`
// outlives the returned `Thread`.
virtual std::unique_ptr<Thread> StartThread(const string& name,
std::function<void()> fn) = 0;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_

View File

@ -129,6 +129,29 @@ tf_cc_test(
],
)
cc_library(
name = "unbounded_thread_pool",
srcs = ["unbounded_thread_pool.cc"],
hdrs = ["unbounded_thread_pool.h"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
],
)
tf_cc_test(
name = "unbounded_thread_pool_test",
srcs = ["unbounded_thread_pool_test.cc"],
deps = [
":unbounded_thread_pool",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "window_dataset",
srcs = ["window_dataset.cc"],
@ -595,6 +618,7 @@ tf_kernel_library(
deps = [
":dataset_utils",
":optional_ops",
":unbounded_thread_pool",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@ -612,6 +636,7 @@ tf_kernel_library(
srcs = ["multi_device_iterator_ops.cc"],
deps = [
":dataset_utils",
":unbounded_thread_pool",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",

View File

@ -292,10 +292,10 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
for (size_t i = 0, num_inputs = dataset()->inputs_.size();
i < num_inputs; ++i) {
threads[i].result = absl::make_unique<InvocationResult>();
threads[i].thread.reset(ctx->env()->StartThread(
{}, strings::StrCat("tf_data_merge_", i),
threads[i].thread = ctx->StartThread(
strings::StrCat("tf_data_merge_", i),
std::bind(&ChooseFastestIterator::RunnerThread, this, ctx,
threads[i].result.get(), i)));
threads[i].result.get(), i));
}
return threads;
}

View File

@ -514,9 +514,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_map_and_batch",
std::bind(&Iterator::RunnerThread, this, ctx_copy)));
runner_thread_ = ctx->StartThread(
"tf_data_map_and_batch",
std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}

View File

@ -926,8 +926,8 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
if (!new_ctx) {
new_ctx = std::make_shared<IteratorContext>(*ctx);
}
workers_[i]->threads.emplace_back(ctx->env()->StartThread(
{}, strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
workers_[i]->threads.emplace_back(ctx->StartThread(
strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
[this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); }));
VLOG(3) << "Worker " << i << ", " << j << " successfully started.";
}
@ -936,9 +936,9 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
if (!new_ctx) {
new_ctx = std::make_shared<IteratorContext>(*ctx);
}
runner_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_numa_map_and_batch",
[this, new_ctx] { RunnerThread(new_ctx); }));
runner_thread_ =
ctx->StartThread("tf_data_numa_map_and_batch",
[this, new_ctx] { RunnerThread(new_ctx); });
}
VLOG(3) << "All workers & runner thread started.";
return Status::OK();

View File

@ -493,8 +493,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, strings::StrCat("tf_data_parallel_interleave_worker_", i),
worker_threads_.emplace_back(ctx->StartThread(
strings::StrCat("tf_data_parallel_interleave_worker_", i),
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
@ -592,8 +592,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
workers_[i].SetInputs(s, std::move(args));
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, strings::StrCat("tf_data_parallel_interleave_worker_", i),
worker_threads_.push_back(ctx->StartThread(
strings::StrCat("tf_data_parallel_interleave_worker_", i),
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@ -51,14 +53,15 @@ const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
class IteratorResource : public ResourceBase {
public:
IteratorResource(const DataTypeVector& output_dtypes,
IteratorResource(Env* env, const DataTypeVector& output_dtypes,
const std::vector<PartialTensorShape>& output_shapes,
const int /*unused: graph_def_version*/,
std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib)
: device_mgr_(std::move(device_mgr)),
: unbounded_thread_pool_(env, "tf_data_iterator_resource"),
device_mgr_(std::move(device_mgr)),
iterator_state_(std::make_shared<State>(
std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */)),
output_dtypes_(output_dtypes),
@ -77,6 +80,7 @@ class IteratorResource : public ResourceBase {
params.function_handle_cache =
captured_state->function_handle_cache.get();
params.resource_mgr = &captured_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
return captured_state->iterator->GetNext(
IteratorContext(std::move(params)), out_tensors, end_of_sequence);
} else {
@ -163,6 +167,8 @@ class IteratorResource : public ResourceBase {
params.lib = new_state->lib;
params.function_handle_cache = new_state->function_handle_cache.get();
params.resource_mgr = &new_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
"Iterator", &new_state->iterator));
TF_RETURN_IF_ERROR(
@ -179,6 +185,7 @@ class IteratorResource : public ResourceBase {
params.allocator_getter = [device](AllocatorAttributes attrs) {
return device->GetAllocator(attrs);
};
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
IteratorContext iter_ctx(std::move(params));
TF_RETURN_IF_ERROR(new_state->iterator->Restore(&iter_ctx, reader));
}
@ -233,6 +240,7 @@ class IteratorResource : public ResourceBase {
params.lib = new_state->lib;
params.function_handle_cache = new_state->function_handle_cache.get();
params.resource_mgr = &new_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
"Iterator", &iterator));
TF_RETURN_IF_ERROR(
@ -284,6 +292,7 @@ class IteratorResource : public ResourceBase {
std::unique_ptr<IteratorBase> iterator;
};
UnboundedThreadPool unbounded_thread_pool_;
mutex mu_;
const std::unique_ptr<DeviceMgr> device_mgr_ GUARDED_BY(mu_);
std::shared_ptr<State> iterator_state_ GUARDED_BY(mu_);
@ -432,14 +441,14 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
context,
mgr->LookupOrCreate<IteratorResource>(
cinfo_.container(), cinfo_.name(), &resource,
[lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
output_dtypes_, output_shapes_, graph_def_version_,
std::move(device_mgr), std::move(flib_def),
std::move(pflr), lib);
return Status::OK();
}));
[context, lib, &device_mgr, &flib_def, &pflr,
this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
context->env(), output_dtypes_, output_shapes_,
graph_def_version_, std::move(device_mgr),
std::move(flib_def), std::move(pflr), lib);
return Status::OK();
}));
Status s = VerifyResource(resource);
if (TF_PREDICT_FALSE(!s.ok())) {
@ -522,7 +531,7 @@ void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
existing_resource->Unref();
}
IteratorResource* new_resource = new IteratorResource(
output_dtypes_, output_shapes_, graph_def_version_,
context->env(), output_dtypes_, output_shapes_, graph_def_version_,
std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
// Create the resource with our chosen name under the resource lookup
// mutex to avoid another kernel racily creating a resource with this
@ -837,11 +846,12 @@ class OneShotIteratorOp : public AsyncOpKernel {
TF_RETURN_IF_ERROR(
ctx->resource_manager()->LookupOrCreate<IteratorResource>(
cinfo->container(), cinfo->name(), iterator,
[lib, this, &flib_def, &pflr](IteratorResource** ret)
[ctx, lib, this, &flib_def, &pflr](IteratorResource** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
output_dtypes_, output_shapes_, graph_def_version_,
nullptr, std::move(flib_def), std::move(pflr), lib);
ctx->env(), output_dtypes_, output_shapes_,
graph_def_version_, nullptr, std::move(flib_def),
std::move(pflr), lib);
return Status::OK();
}));

View File

@ -140,9 +140,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
if (!optimize_thread_) {
std::shared_ptr<IteratorContext> new_ctx =
std::make_shared<IteratorContext>(*ctx);
optimize_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_model",
[this, new_ctx]() { OptimizeThread(new_ctx); }));
optimize_thread_ = ctx->StartThread(
"tf_data_model", [this, new_ctx]() { OptimizeThread(new_ctx); });
}
return Status::OK();
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/random.h"
@ -42,14 +43,15 @@ using MultiDeviceIteratorCallback =
class MultiDeviceIterator : public ResourceBase {
public:
MultiDeviceIterator(
const DataTypeVector& output_types,
Env* env, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
const std::vector<string>& devices,
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib,
std::unique_ptr<FunctionHandleCache> function_handle_cache)
: output_types_(output_types),
: unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"),
output_types_(output_types),
output_shapes_(output_shapes),
devices_(devices),
flib_def_(std::move(flib_def)),
@ -82,27 +84,25 @@ class MultiDeviceIterator : public ResourceBase {
*incarnation_id = incarnation_id_;
multi_device_buffer_ = absl::make_unique<MultiDeviceBuffer>(
devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator));
devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator),
this);
return Status::OK();
}
void GetNextFromShard(IteratorContext* ctx, int shard_num,
void GetNextFromShard(OpKernelContext* ctx, int shard_num,
int64 incarnation_id,
MultiDeviceIteratorCallback callback) {
if (ctx->lib() == lib_) {
tf_shared_lock l(mu_);
multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
std::move(callback));
} else {
IteratorContext::Params params(ctx);
params.lib = lib_;
params.function_handle_cache = function_handle_cache_.get();
params.resource_mgr = &resource_mgr_;
IteratorContext iter_ctx(std::move(params));
tf_shared_lock l(mu_);
multi_device_buffer_->GetNextFromShard(
&iter_ctx, shard_num, incarnation_id, std::move(callback));
}
tf_shared_lock l(mu_);
IteratorContext::Params params(ctx);
params.function_library = lib_def_;
params.lib = lib_;
params.function_handle_cache = function_handle_cache_.get();
params.resource_mgr = &resource_mgr_;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
IteratorContext iter_ctx(std::move(params));
multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
std::move(callback));
}
const DataTypeVector& output_types() const { return output_types_; }
@ -133,12 +133,14 @@ class MultiDeviceIterator : public ResourceBase {
class MultiDeviceBuffer {
public:
MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
std::unique_ptr<IteratorBase> host_iterator)
std::unique_ptr<IteratorBase> host_iterator,
MultiDeviceIterator* parent)
: buffer_(size),
size_(size),
max_buffer_size_(max_buffer_size),
incarnation_id_(incarnation_id),
host_iterator_(std::move(host_iterator)) {}
host_iterator_(std::move(host_iterator)),
parent_(parent) {}
~MultiDeviceBuffer() {
{
@ -217,10 +219,12 @@ class MultiDeviceIterator : public ResourceBase {
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!background_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
background_thread_ = absl::WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_multi_device_iterator",
std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
this, std::move(ctx_copy))));
background_thread_ =
parent_->unbounded_thread_pool_.get_thread_factory()->StartThread(
"tf_data_multi_device_iterator",
std::bind(
&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
this, std::move(ctx_copy)));
}
}
@ -342,8 +346,10 @@ class MultiDeviceIterator : public ResourceBase {
const int64 max_buffer_size_;
const int64 incarnation_id_;
const std::unique_ptr<IteratorBase> host_iterator_;
MultiDeviceIterator* const parent_; // Not owned.
};
UnboundedThreadPool unbounded_thread_pool_;
mutex mu_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
@ -413,8 +419,9 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
current_id_.fetch_add(1));
container_name = "AnonymousMultiDeviceIterator";
resource = new MultiDeviceIterator(
output_types_, output_shapes_, devices_, std::move(flib_def),
std::move(pflr), lib, std::move(function_handle_cache));
context->env(), output_types_, output_shapes_, devices_,
std::move(flib_def), std::move(pflr), lib,
std::move(function_handle_cache));
// NOTE: `mgr->Create()` transfers the one reference on `resource` to
// `mgr`.
OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
@ -425,11 +432,12 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
OP_REQUIRES_OK(context,
mgr->LookupOrCreate<MultiDeviceIterator>(
container_name, unique_name, &resource,
[this, lib, &flib_def, &pflr,
[this, context, lib, &flib_def, &pflr,
&function_handle_cache](MultiDeviceIterator** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new MultiDeviceIterator(
output_types_, output_shapes_, devices_,
context->env(), output_types_,
output_shapes_, devices_,
std::move(flib_def), std::move(pflr),
lib, std::move(function_handle_cache));
return Status::OK();
@ -557,11 +565,8 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
},
std::placeholders::_1, std::move(done));
IteratorContext::Params params(ctx);
params.function_library = iterator->function_library();
IteratorContext iter_ctx(std::move(params));
iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
callback);
iterator->GetNextFromShard(ctx, shard_num, incarnation_id,
std::move(callback));
iterator->Unref();
},
std::move(done)));

View File

@ -517,17 +517,15 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!current_elements_manager_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
current_elements_manager_ =
absl::WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_parallel_interleave_current",
[this, new_ctx]() { CurrentElementsManager(new_ctx); }));
current_elements_manager_ = ctx->StartThread(
"tf_data_parallel_interleave_current",
[this, new_ctx]() { CurrentElementsManager(new_ctx); });
}
if (!future_elements_manager_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
future_elements_manager_ =
absl::WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_parallel_interleave_future",
[this, new_ctx]() { FutureElementsManager(new_ctx); }));
future_elements_manager_ = ctx->StartThread(
"tf_data_parallel_interleave_future",
[this, new_ctx]() { FutureElementsManager(new_ctx); });
}
}

View File

@ -191,9 +191,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_parallel_map",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
}
}

View File

@ -269,9 +269,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
if (!prefetch_thread_) {
std::shared_ptr<IteratorContext> new_ctx =
std::make_shared<IteratorContext>(*ctx);
prefetch_thread_ = absl::WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_prefetch",
[this, new_ctx]() { PrefetchThread(new_ctx); }));
prefetch_thread_ = ctx->StartThread(
"tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
}
return Status::OK();
}

View File

@ -0,0 +1,156 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace data {
// A lightweight wrapper for creating logical threads in a `UnboundedThreadPool`
// that can be shared (e.g.) in an `IteratorContext`.
class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
public:
explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}
std::unique_ptr<Thread> StartThread(const string& name,
std::function<void()> fn) override {
return pool_->RunOnPooledThread(std::move(fn));
}
private:
UnboundedThreadPool* const pool_; // Not owned.
};
// A logical implementation of the `tensorflow::Thread` interface that uses
// physical threads in an `UnboundedThreadPool` to perform the work.
//
// NOTE: This object represents a logical thread of control that may be mapped
// onto the same physical thread as other work items that are submitted to the
// same `UnboundedThreadPool`.
class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
public:
explicit LogicalThreadWrapper(std::shared_ptr<Notification> join_notification)
: join_notification_(std::move(join_notification)) {}
~LogicalThreadWrapper() override {
// NOTE: The `Thread` destructor is expected to "join" the created thread,
// but the physical thread may continue to execute after the work for this
// thread is complete. We simulate this by waiting on a notification that
// the `CachedThreadFunc` will notify when the thread's work function is
// complete.
join_notification_->WaitForNotification();
}
private:
std::shared_ptr<Notification> join_notification_;
};
UnboundedThreadPool::~UnboundedThreadPool() {
{
mutex_lock l(work_queue_mu_);
// Wake up all `CachedThreadFunc` threads and cause them to terminate before
// joining them when `threads_` is cleared.
cancelled_ = true;
work_queue_cv_.notify_all();
if (!work_queue_.empty()) {
LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was "
<< "deleted with pending work in its queue. This may indicate "
<< "a potential use-after-free bug.";
}
}
{
mutex_lock l(thread_pool_mu_);
// Clear the list of pooled threads, which will eventually terminate due to
// the previous notification.
//
// NOTE: It is safe to do this while holding `pooled_threads_mu_`, because
// no subsequent calls to `this->StartThread()` should be issued after the
// destructor starts.
thread_pool_.clear();
}
}
std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
return std::make_shared<LogicalThreadFactory>(this);
}
size_t UnboundedThreadPool::size() {
tf_shared_lock l(thread_pool_mu_);
return thread_pool_.size();
}
std::unique_ptr<Thread> UnboundedThreadPool::RunOnPooledThread(
std::function<void()> fn) {
auto join_notification = std::make_shared<Notification>();
bool all_threads_busy;
{
// Enqueue a work item for the new thread's function, and wake up a
// cached thread to process it.
mutex_lock l(work_queue_mu_);
work_queue_.push_back({std::move(fn), join_notification});
work_queue_cv_.notify_one();
// NOTE: The queue may be non-empty, so we must account for queued work when
// considering how many threads are free.
all_threads_busy = work_queue_.size() > num_idle_threads_;
}
if (all_threads_busy) {
// Spawn a new physical thread to process the given function.
// NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
// at the beginning of its work loop.
Thread* new_thread = env_->StartThread(
{}, thread_name_,
std::bind(&UnboundedThreadPool::PooledThreadFunc, this));
mutex_lock l(thread_pool_mu_);
thread_pool_.emplace_back(new_thread);
}
return absl::make_unique<LogicalThreadWrapper>(std::move(join_notification));
}
void UnboundedThreadPool::PooledThreadFunc() {
while (true) {
WorkItem work_item;
{
mutex_lock l(work_queue_mu_);
++num_idle_threads_;
while (!cancelled_ && work_queue_.empty()) {
// Wait for a new work function to be submitted, or the cache to be
// destroyed.
work_queue_cv_.wait(l);
}
if (cancelled_) {
return;
}
work_item = std::move(work_queue_.front());
work_queue_.pop_front();
--num_idle_threads_;
}
work_item.work_function();
// Notify any thread that has "joined" the cached thread for this work item.
work_item.done_notification->Notify();
}
}
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
#define TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
#include <deque>
#include <memory>
#include <vector>
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace data {
// An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a
// potentially large number of "logical" threads onto a smaller number of
// "physical" threads. The multiplexing is achieved by maintaining an internal
// pool of long-running "physical" threads that are used to execute the
// "logical" threads. Like a regular thread, a "logical" thread may block on
// other threads, and the size of the pool will increase to ensure that progress
// is made. This mechanism is recommended in situations where short-lived
// threads are created repeatedly, to avoid the overhead and memory
// fragmentation that can result from excessive thread creation.
class UnboundedThreadPool {
public:
UnboundedThreadPool(Env* env, const string& thread_name)
: env_(env), thread_name_(thread_name) {}
~UnboundedThreadPool();
// Returns an implementation of `ThreadFactory` that can be used to create
// logical threads in this pool.
std::shared_ptr<ThreadFactory> get_thread_factory();
// Returns the current number of threads in this pool.
size_t size();
private:
class LogicalThreadFactory;
class LogicalThreadWrapper;
struct WorkItem {
std::function<void()> work_function;
std::shared_ptr<Notification> done_notification;
};
std::unique_ptr<Thread> RunOnPooledThread(std::function<void()> fn);
void PooledThreadFunc();
Env* const env_; // Not owned.
const string thread_name_;
mutex work_queue_mu_;
condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_);
size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0;
bool cancelled_ GUARDED_BY(work_queue_mu_) = false;
std::deque<WorkItem> work_queue_ GUARDED_BY(work_queue_mu_);
mutex thread_pool_mu_;
std::vector<std::unique_ptr<Thread>> thread_pool_ GUARDED_BY(thread_pool_mu_);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_

View File

@ -0,0 +1,143 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace data {
namespace {
TEST(UnboundedThreadPool, SingleThread) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
// Create a thread that updates a variable, and ensure that it runs to
// completion.
std::atomic<int> i(0);
auto thread = thread_factory->StartThread("", [&i]() { ++i; });
thread.reset();
EXPECT_GE(pool.size(), 1);
EXPECT_EQ(1, i);
}
TEST(UnboundedThreadPool, MultipleThreads) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
// Create ten threads that update a variable, and ensure that they all run
// to completion.
std::vector<std::unique_ptr<Thread>> threads;
const int kNumThreadsToCreate = 10;
std::atomic<int> i(0);
for (int j = 0; j < kNumThreadsToCreate; ++j) {
threads.push_back(thread_factory->StartThread("", [&i]() { ++i; }));
}
threads.clear();
EXPECT_GE(pool.size(), 1);
EXPECT_EQ(i, kNumThreadsToCreate);
}
TEST(UnboundedThreadPool, MultipleThreadsSleepingRandomly) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
// Create 1000 threads that sleep for a random period of time then update a
// variable, and ensure that they all run to completion.
std::vector<std::unique_ptr<Thread>> threads;
const int kNumThreadsToCreate = 1000;
std::atomic<int> i(0);
for (int j = 0; j < kNumThreadsToCreate; ++j) {
threads.push_back(thread_factory->StartThread("", [&i]() {
Env::Default()->SleepForMicroseconds(random::New64() % 10);
++i;
}));
}
threads.clear();
EXPECT_GE(pool.size(), 1);
EXPECT_EQ(i, kNumThreadsToCreate);
}
TEST(UnboundedThreadPool, ConcurrentThreadCreation) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
// Create ten threads that each create ten threads that update a variable, and
// ensure that they all run to completion.
std::vector<std::unique_ptr<Thread>> threads;
const int kNumThreadsToCreate = 10;
std::atomic<int> i(0);
for (int j = 0; j < kNumThreadsToCreate; ++j) {
threads.push_back(thread_factory->StartThread("", [&i, thread_factory]() {
std::vector<std::unique_ptr<Thread>> nested_threads;
for (int k = 0; k < kNumThreadsToCreate; ++k) {
nested_threads.push_back(
thread_factory->StartThread("", [&i]() { ++i; }));
}
nested_threads.clear();
}));
}
threads.clear();
EXPECT_GE(pool.size(), 1);
EXPECT_EQ(i, kNumThreadsToCreate * kNumThreadsToCreate);
}
TEST(UnboundedThreadPool, MultipleBlockingThreads) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
std::vector<std::unique_ptr<Thread>> threads;
// Create multiple waves (with increasing sizes) of threads that all block
// before returning, and
// ensure that we create the appropriate number of threads and terminate
// correctly.
std::vector<int> round_sizes = {5, 10, 15, 20};
for (const int round_size : round_sizes) {
Notification n;
BlockingCounter bc(round_size);
for (int j = 0; j < round_size; ++j) {
threads.push_back(thread_factory->StartThread("", [&bc, &n]() {
bc.DecrementCount();
// Block until `n` is notified, so that all ten threads must been
// created before the first one completes.
n.WaitForNotification();
}));
}
// Wait until all threads have started. Since the number of threads in each
// wave is increasing, we should have at least that number of threads in the
// pool.
bc.Wait();
// NOTE: There is a benign race between a new round starting and the
// physical threads from the previous round returning to the pool, so we may
// create more threads than the round_size.
EXPECT_GE(pool.size(), round_size);
n.Notify();
threads.clear();
}
}
} // namespace
} // namespace data
} // namespace tensorflow