diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 0e1ed0ec7b7..34f4a777446 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -925,6 +925,7 @@ tf_cuda_library(
         "framework/tensor_slice.h",
         "framework/tensor_types.h",
         "framework/tensor_util.h",
+        "framework/thread_factory.h",
         "framework/tracking_allocator.h",
         "framework/type_index.h",
         "framework/type_traits.h",
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index a93d5adb232..0c38801154e 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <deque>
 #include <memory>
 
+#include "absl/memory/memory.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
@@ -28,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/thread_factory.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
@@ -287,7 +289,8 @@ class IteratorContext {
           model(ctx->model()),
           runner(*(ctx->runner())),
           runner_threadpool_size(ctx->runner_threadpool_size()),
-          stats_aggregator(ctx->stats_aggregator()) {}
+          stats_aggregator(ctx->stats_aggregator()),
+          thread_factory(ctx->thread_factory()) {}
 
     explicit Params(OpKernelContext* ctx)
         : env(ctx->env()),
@@ -338,6 +341,10 @@ class IteratorContext {
 
     // The `StatsAggregator` object to record statistics about the iterator.
     std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
+
+    // A `ThreadFactory` for creating threads used by iterators to perform
+    // blocking work.
+    std::shared_ptr<ThreadFactory> thread_factory = nullptr;
   };
 
   explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
@@ -374,6 +381,20 @@ class IteratorContext {
     return &params_.runner;
   }
 
+  const std::shared_ptr<ThreadFactory>& thread_factory() {
+    return params_.thread_factory;
+  }
+
+  std::unique_ptr<Thread> StartThread(const string& name,
+                                      std::function<void()> fn) {
+    if (params_.thread_factory) {
+      return params_.thread_factory->StartThread(name, std::move(fn));
+    } else {
+      return absl::WrapUnique(
+          Env::Default()->StartThread({}, name, std::move(fn)));
+    }
+  }
+
   int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
 
   std::shared_ptr<StatsAggregator> stats_aggregator() {
diff --git a/tensorflow/core/framework/thread_factory.h b/tensorflow/core/framework/thread_factory.h
new file mode 100644
index 00000000000..d5bb6dda66b
--- /dev/null
+++ b/tensorflow/core/framework/thread_factory.h
@@ -0,0 +1,42 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
+#define TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
+
+#include <functional>
+#include <memory>
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class Thread;
+
+// Virtual interface for an object that creates threads.
+class ThreadFactory {
+ public:
+  virtual ~ThreadFactory() {}
+
+  // Runs `fn` asynchronously in a different thread. `fn` may block.
+  //
+  // NOTE: The caller is responsible for ensuring that this `ThreadFactory`
+  // outlives the returned `Thread`.
+  virtual std::unique_ptr<Thread> StartThread(const string& name,
+                                              std::function<void()> fn) = 0;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
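For context, the `ThreadFactory` contract above is small enough to sketch: an implementation just needs to run `fn` on some thread and return a `Thread` whose destructor joins it. A minimal sketch (the class name `EnvThreadFactory` is hypothetical, not part of this patch; it simply delegates to `Env::StartThread`, which is the same fallback that `IteratorContext::StartThread` uses when no factory is set):

    #include "tensorflow/core/framework/thread_factory.h"
    #include "tensorflow/core/platform/env.h"

    namespace tensorflow {

    // Hypothetical example implementation, not part of this patch.
    class EnvThreadFactory : public ThreadFactory {
     public:
      explicit EnvThreadFactory(Env* env) : env_(env) {}

      std::unique_ptr<Thread> StartThread(const string& name,
                                          std::function<void()> fn) override {
        // Default thread options; `fn` may block, so each call gets its own
        // physical thread here.
        return std::unique_ptr<Thread>(
            env_->StartThread({}, name, std::move(fn)));
      }

     private:
      Env* const env_;  // Not owned. Must outlive the factory.
    };

    }  // namespace tensorflow
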
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index ecfeb87bbe0..5a89d067fdc 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -129,6 +129,29 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "unbounded_thread_pool",
+    srcs = ["unbounded_thread_pool.cc"],
+    hdrs = ["unbounded_thread_pool.h"],
+    deps = [
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
+tf_cc_test(
+    name = "unbounded_thread_pool_test",
+    srcs = ["unbounded_thread_pool_test.cc"],
+    deps = [
+        ":unbounded_thread_pool",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 cc_library(
     name = "window_dataset",
     srcs = ["window_dataset.cc"],
@@ -595,6 +618,7 @@ tf_kernel_library(
     deps = [
         ":dataset_utils",
         ":optional_ops",
+        ":unbounded_thread_pool",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
@@ -612,6 +636,7 @@ tf_kernel_library(
     srcs = ["multi_device_iterator_ops.cc"],
     deps = [
         ":dataset_utils",
+        ":unbounded_thread_pool",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc
index e63208f26a9..bfa2bf6bc46 100644
--- a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc
@@ -292,10 +292,10 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
       for (size_t i = 0, num_inputs = dataset()->inputs_.size();
            i < num_inputs; ++i) {
         threads[i].result = absl::make_unique<InvocationResult>();
-        threads[i].thread.reset(ctx->env()->StartThread(
-            {}, strings::StrCat("tf_data_merge_", i),
+        threads[i].thread = ctx->StartThread(
+            strings::StrCat("tf_data_merge_", i),
             std::bind(&ChooseFastestIterator::RunnerThread, this, ctx,
-                      threads[i].result.get(), i)));
+                      threads[i].result.get(), i));
       }
       return threads;
     }
diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
index 540afea92a5..fb7a6204a04 100644
--- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
@@ -514,9 +514,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
-        runner_thread_.reset(ctx->env()->StartThread(
-            {}, "tf_data_map_and_batch",
-            std::bind(&Iterator::RunnerThread, this, ctx_copy)));
+        runner_thread_ = ctx->StartThread(
+            "tf_data_map_and_batch",
+            std::bind(&Iterator::RunnerThread, this, ctx_copy));
       }
     }
diff --git a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc
index a45755fbbb6..ce8a20a783f 100644
--- a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc
@@ -926,8 +926,8 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
           if (!new_ctx) {
             new_ctx = std::make_shared<IteratorContext>(*ctx);
           }
-          workers_[i]->threads.emplace_back(ctx->env()->StartThread(
-              {}, strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
+          workers_[i]->threads.emplace_back(ctx->StartThread(
+              strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
               [this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); }));
           VLOG(3) << "Worker " << i << ", " << j << " successfully started.";
         }
@@ -936,9 +936,9 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
         if (!new_ctx) {
           new_ctx = std::make_shared<IteratorContext>(*ctx);
         }
-        runner_thread_.reset(ctx->env()->StartThread(
-            {}, "tf_data_numa_map_and_batch",
-            [this, new_ctx] { RunnerThread(new_ctx); }));
+        runner_thread_ =
+            ctx->StartThread("tf_data_numa_map_and_batch",
+                             [this, new_ctx] { RunnerThread(new_ctx); });
       }
       VLOG(3) << "All workers & runner thread started.";
       return Status::OK();
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
index f6d522078dd..54c1d839e60 100644
--- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
@@ -493,8 +493,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
         worker_threads_.reserve(dataset()->num_threads());
         for (size_t i = 0; i < dataset()->num_threads(); ++i) {
           std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-          worker_threads_.emplace_back(ctx->env()->StartThread(
-              {}, strings::StrCat("tf_data_parallel_interleave_worker_", i),
+          worker_threads_.emplace_back(ctx->StartThread(
+              strings::StrCat("tf_data_parallel_interleave_worker_", i),
               [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
         }
       }
@@ -592,8 +592,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
           }
           workers_[i].SetInputs(s, std::move(args));
           std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-          worker_threads_.emplace_back(ctx->env()->StartThread(
-              {}, strings::StrCat("tf_data_parallel_interleave_worker_", i),
+          worker_threads_.push_back(ctx->StartThread(
+              strings::StrCat("tf_data_parallel_interleave_worker_", i),
               [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
           if (i < dataset()->cycle_length_) {
             interleave_indices_.push_back(i);
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 7e23ca58ce7..14fb6624ad7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/iterator_ops.h"
 
+#include <memory>
 
 #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/graph_runner.h"
@@ -28,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/optional_ops.h"
+#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
@@ -51,14 +53,15 @@ const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
 
 class IteratorResource : public ResourceBase {
  public:
-  IteratorResource(const DataTypeVector& output_dtypes,
+  IteratorResource(Env* env, const DataTypeVector& output_dtypes,
                    const std::vector<PartialTensorShape>& output_shapes,
                    const int /*unused: graph_def_version*/,
                    std::unique_ptr<DeviceMgr> device_mgr,
                    std::unique_ptr<FunctionLibraryDefinition> flib_def,
                    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                    FunctionLibraryRuntime* lib)
-      : device_mgr_(std::move(device_mgr)),
+      : unbounded_thread_pool_(env, "tf_data_iterator_resource"),
+        device_mgr_(std::move(device_mgr)),
         iterator_state_(std::make_shared<State>(
            std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */)),
        output_dtypes_(output_dtypes),
@@ -77,6 +80,7 @@ class IteratorResource : public ResourceBase {
       params.function_handle_cache =
           captured_state->function_handle_cache.get();
       params.resource_mgr = &captured_state->resource_mgr;
+      params.thread_factory = unbounded_thread_pool_.get_thread_factory();
       return captured_state->iterator->GetNext(
           IteratorContext(std::move(params)), out_tensors, end_of_sequence);
     } else {
@@ -163,6 +167,8 @@ class IteratorResource : public ResourceBase {
       params.lib = new_state->lib;
       params.function_handle_cache = new_state->function_handle_cache.get();
       params.resource_mgr = &new_state->resource_mgr;
+      params.thread_factory = unbounded_thread_pool_.get_thread_factory();
+
       TF_RETURN_IF_ERROR(dataset->MakeIterator(
           IteratorContext(std::move(params)), "Iterator",
           &new_state->iterator));
       TF_RETURN_IF_ERROR(
@@ -179,6 +185,7 @@ class IteratorResource : public ResourceBase {
       params.allocator_getter = [device](AllocatorAttributes attrs) {
         return device->GetAllocator(attrs);
       };
+      params.thread_factory = unbounded_thread_pool_.get_thread_factory();
       IteratorContext iter_ctx(std::move(params));
       TF_RETURN_IF_ERROR(new_state->iterator->Restore(&iter_ctx, reader));
     }
@@ -233,6 +240,7 @@ class IteratorResource : public ResourceBase {
     params.lib = new_state->lib;
     params.function_handle_cache = new_state->function_handle_cache.get();
     params.resource_mgr = &new_state->resource_mgr;
+    params.thread_factory = unbounded_thread_pool_.get_thread_factory();
     TF_RETURN_IF_ERROR(dataset->MakeIterator(
         IteratorContext(std::move(params)), "Iterator", &iterator));
     TF_RETURN_IF_ERROR(
@@ -284,6 +292,7 @@ class IteratorResource : public ResourceBase {
     std::unique_ptr<IteratorBase> iterator;
   };
 
+  UnboundedThreadPool unbounded_thread_pool_;
   mutex mu_;
   const std::unique_ptr<DeviceMgr> device_mgr_ GUARDED_BY(mu_);
   std::shared_ptr<State> iterator_state_ GUARDED_BY(mu_);
@@ -432,14 +441,14 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
           context,
           mgr->LookupOrCreate<IteratorResource>(
               cinfo_.container(), cinfo_.name(), &resource,
-              [lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret)
-                  EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-                    *ret = new IteratorResource(
-                        output_dtypes_, output_shapes_, graph_def_version_,
-                        std::move(device_mgr), std::move(flib_def),
-                        std::move(pflr), lib);
-                    return Status::OK();
-                  }));
+              [context, lib, &device_mgr, &flib_def, &pflr,
+               this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                *ret = new IteratorResource(
+                    context->env(), output_dtypes_, output_shapes_,
+                    graph_def_version_, std::move(device_mgr),
+                    std::move(flib_def), std::move(pflr), lib);
+                return Status::OK();
+              }));
 
   Status s = VerifyResource(resource);
   if (TF_PREDICT_FALSE(!s.ok())) {
@@ -522,7 +531,7 @@ void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
       existing_resource->Unref();
     }
     IteratorResource* new_resource = new IteratorResource(
-        output_dtypes_, output_shapes_, graph_def_version_,
+        context->env(), output_dtypes_, output_shapes_, graph_def_version_,
         std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
     // Create the resource with our chosen name under the resource lookup
     // mutex to avoid another kernel racily creating a resource with this
@@ -837,11 +846,12 @@ class OneShotIteratorOp : public AsyncOpKernel {
     TF_RETURN_IF_ERROR(
         ctx->resource_manager()->LookupOrCreate<IteratorResource>(
             cinfo->container(), cinfo->name(), iterator,
-            [lib, this, &flib_def, &pflr](IteratorResource** ret)
+            [ctx, lib, this, &flib_def, &pflr](IteratorResource** ret)
                 EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                   *ret = new IteratorResource(
-                      output_dtypes_, output_shapes_, graph_def_version_,
-                      nullptr, std::move(flib_def), std::move(pflr), lib);
+                      ctx->env(), output_dtypes_, output_shapes_,
+                      graph_def_version_, nullptr, std::move(flib_def),
+                      std::move(pflr), lib);
                   return Status::OK();
                 }));
 
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 20254234e9d..4b8b68f2a38 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -140,9 +140,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
       if (!optimize_thread_) {
         std::shared_ptr<IteratorContext> new_ctx =
             std::make_shared<IteratorContext>(*ctx);
-        optimize_thread_.reset(ctx->env()->StartThread(
-            {}, "tf_data_model",
-            [this, new_ctx]() { OptimizeThread(new_ctx); }));
+        optimize_thread_ = ctx->StartThread(
+            "tf_data_model", [this, new_ctx]() { OptimizeThread(new_ctx); });
       }
       return Status::OK();
     }
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/unbounded_thread_pool.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/random/random.h" @@ -42,14 +43,15 @@ using MultiDeviceIteratorCallback = class MultiDeviceIterator : public ResourceBase { public: MultiDeviceIterator( - const DataTypeVector& output_types, + Env* env, const DataTypeVector& output_types, const std::vector& output_shapes, const std::vector& devices, std::unique_ptr flib_def, std::unique_ptr pflr, FunctionLibraryRuntime* lib, std::unique_ptr function_handle_cache) - : output_types_(output_types), + : unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"), + output_types_(output_types), output_shapes_(output_shapes), devices_(devices), flib_def_(std::move(flib_def)), @@ -82,27 +84,25 @@ class MultiDeviceIterator : public ResourceBase { *incarnation_id = incarnation_id_; multi_device_buffer_ = absl::make_unique( - devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator)); + devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator), + this); return Status::OK(); } - void GetNextFromShard(IteratorContext* ctx, int shard_num, + void GetNextFromShard(OpKernelContext* ctx, int shard_num, int64 incarnation_id, MultiDeviceIteratorCallback callback) { - if (ctx->lib() == lib_) { - tf_shared_lock l(mu_); - multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id, - std::move(callback)); - } else { - IteratorContext::Params params(ctx); - params.lib = lib_; - params.function_handle_cache = function_handle_cache_.get(); - params.resource_mgr = &resource_mgr_; - IteratorContext iter_ctx(std::move(params)); - tf_shared_lock l(mu_); - multi_device_buffer_->GetNextFromShard( - &iter_ctx, shard_num, incarnation_id, std::move(callback)); - } + tf_shared_lock l(mu_); + IteratorContext::Params params(ctx); + params.function_library = lib_def_; + params.lib = lib_; + params.function_handle_cache = function_handle_cache_.get(); + params.resource_mgr = &resource_mgr_; + params.thread_factory = unbounded_thread_pool_.get_thread_factory(); + + IteratorContext iter_ctx(std::move(params)); + multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, + std::move(callback)); } const DataTypeVector& output_types() const { return output_types_; } @@ -133,12 +133,14 @@ class MultiDeviceIterator : public ResourceBase { class MultiDeviceBuffer { public: MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id, - std::unique_ptr host_iterator) + std::unique_ptr host_iterator, + MultiDeviceIterator* parent) : buffer_(size), size_(size), max_buffer_size_(max_buffer_size), incarnation_id_(incarnation_id), - host_iterator_(std::move(host_iterator)) {} + host_iterator_(std::move(host_iterator)), + parent_(parent) {} ~MultiDeviceBuffer() { { @@ -217,10 +219,12 @@ class MultiDeviceIterator : public ResourceBase { EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!background_thread_) { auto ctx_copy = std::make_shared(*ctx); - background_thread_ = absl::WrapUnique(ctx->env()->StartThread( - {}, "tf_data_multi_device_iterator", - std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread, - this, std::move(ctx_copy)))); + background_thread_ = + parent_->unbounded_thread_pool_.get_thread_factory()->StartThread( + "tf_data_multi_device_iterator", + 
@@ -342,8 +346,10 @@ class MultiDeviceIterator : public ResourceBase {
     const int64 max_buffer_size_;
     const int64 incarnation_id_;
     const std::unique_ptr<IteratorBase> host_iterator_;
+    MultiDeviceIterator* const parent_;  // Not owned.
   };
 
+  UnboundedThreadPool unbounded_thread_pool_;
   mutex mu_;
   const DataTypeVector output_types_;
   const std::vector<PartialTensorShape> output_shapes_;
@@ -413,8 +419,9 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
                                            current_id_.fetch_add(1));
         container_name = "AnonymousMultiDeviceIterator";
         resource = new MultiDeviceIterator(
-            output_types_, output_shapes_, devices_, std::move(flib_def),
-            std::move(pflr), lib, std::move(function_handle_cache));
+            context->env(), output_types_, output_shapes_, devices_,
+            std::move(flib_def), std::move(pflr), lib,
+            std::move(function_handle_cache));
         // NOTE: `mgr->Create()` transfers the one reference on `resource` to
         // `mgr`.
         OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
@@ -425,11 +432,12 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
         OP_REQUIRES_OK(context,
                        mgr->LookupOrCreate<MultiDeviceIterator>(
                            container_name, unique_name, &resource,
-                           [this, lib, &flib_def, &pflr,
+                           [this, context, lib, &flib_def, &pflr,
                             &function_handle_cache](MultiDeviceIterator** ret)
                                EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                                  *ret = new MultiDeviceIterator(
-                                     output_types_, output_shapes_, devices_,
+                                     context->env(), output_types_,
+                                     output_shapes_, devices_,
                                      std::move(flib_def), std::move(pflr), lib,
                                      std::move(function_handle_cache));
                                  return Status::OK();
@@ -557,11 +565,8 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
               },
               std::placeholders::_1, std::move(done));
 
-          IteratorContext::Params params(ctx);
-          params.function_library = iterator->function_library();
-          IteratorContext iter_ctx(std::move(params));
-          iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
-                                     callback);
+          iterator->GetNextFromShard(ctx, shard_num, incarnation_id,
+                                     std::move(callback));
           iterator->Unref();
         },
         std::move(done)));
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index a4b614289b1..4dd5c379c03 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -517,17 +517,15 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!current_elements_manager_) {
         auto new_ctx = std::make_shared<IteratorContext>(*ctx);
-        current_elements_manager_ =
-            absl::WrapUnique(ctx->env()->StartThread(
-                {}, "tf_data_parallel_interleave_current",
-                [this, new_ctx]() { CurrentElementsManager(new_ctx); }));
+        current_elements_manager_ = ctx->StartThread(
+            "tf_data_parallel_interleave_current",
+            [this, new_ctx]() { CurrentElementsManager(new_ctx); });
       }
       if (!future_elements_manager_) {
         auto new_ctx = std::make_shared<IteratorContext>(*ctx);
-        future_elements_manager_ =
-            absl::WrapUnique(ctx->env()->StartThread(
-                {}, "tf_data_parallel_interleave_future",
-                [this, new_ctx]() { FutureElementsManager(new_ctx); }));
+        future_elements_manager_ = ctx->StartThread(
+            "tf_data_parallel_interleave_future",
+            [this, new_ctx]() { FutureElementsManager(new_ctx); });
       }
     }
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 1f10804205f..3b0d6d7a449 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -191,9 +191,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
-        runner_thread_.reset(ctx->env()->StartThread(
-            {}, "tf_data_parallel_map",
-            std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
+        runner_thread_ = ctx->StartThread(
+            "tf_data_parallel_map",
+            std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
       }
     }
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index e479b9ff5e3..9773b492905 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -269,9 +269,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
       if (!prefetch_thread_) {
         std::shared_ptr<IteratorContext> new_ctx =
            std::make_shared<IteratorContext>(*ctx);
-        prefetch_thread_ = absl::WrapUnique(ctx->env()->StartThread(
-            {}, "tf_data_prefetch",
-            [this, new_ctx]() { PrefetchThread(new_ctx); }));
+        prefetch_thread_ = ctx->StartThread(
+            "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
       }
       return Status::OK();
     }
diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool.cc b/tensorflow/core/kernels/data/unbounded_thread_pool.cc
new file mode 100644
index 00000000000..ac12197f1b8
--- /dev/null
+++ b/tensorflow/core/kernels/data/unbounded_thread_pool.cc
@@ -0,0 +1,156 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace data {
+
+// A lightweight wrapper for creating logical threads in an
+// `UnboundedThreadPool` that can be shared (e.g.) in an `IteratorContext`.
+class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
+ public:
+  explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}
+
+  std::unique_ptr<Thread> StartThread(const string& name,
+                                      std::function<void()> fn) override {
+    return pool_->RunOnPooledThread(std::move(fn));
+  }
+
+ private:
+  UnboundedThreadPool* const pool_;  // Not owned.
+};
+
+// A logical implementation of the `tensorflow::Thread` interface that uses
+// physical threads in an `UnboundedThreadPool` to perform the work.
+//
+// NOTE: This object represents a logical thread of control that may be mapped
+// onto the same physical thread as other work items that are submitted to the
+// same `UnboundedThreadPool`.
+class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
+ public:
+  explicit LogicalThreadWrapper(
+      std::shared_ptr<Notification> join_notification)
+      : join_notification_(std::move(join_notification)) {}
+
+  ~LogicalThreadWrapper() override {
+    // NOTE: The `Thread` destructor is expected to "join" the created thread,
+    // but the physical thread may continue to execute after the work for this
+    // thread is complete. We simulate this by waiting on a notification that
+    // the `PooledThreadFunc` will notify when the thread's work function is
+    // complete.
+    join_notification_->WaitForNotification();
+  }
+
+ private:
+  std::shared_ptr<Notification> join_notification_;
+};
+
+UnboundedThreadPool::~UnboundedThreadPool() {
+  {
+    mutex_lock l(work_queue_mu_);
+    // Wake up all `PooledThreadFunc` threads and cause them to terminate
+    // before joining them when `thread_pool_` is cleared.
+    cancelled_ = true;
+    work_queue_cv_.notify_all();
+    if (!work_queue_.empty()) {
+      LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was "
+                 << "deleted with pending work in its queue. This may indicate "
+                 << "a potential use-after-free bug.";
+    }
+  }
+
+  {
+    mutex_lock l(thread_pool_mu_);
+    // Clear the list of pooled threads, which will eventually terminate due to
+    // the previous notification.
+    //
+    // NOTE: It is safe to do this while holding `thread_pool_mu_`, because
+    // no subsequent calls to `this->StartThread()` should be issued after the
+    // destructor starts.
+    thread_pool_.clear();
+  }
+}
+
+std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
+  return std::make_shared<LogicalThreadFactory>(this);
+}
+
+size_t UnboundedThreadPool::size() {
+  tf_shared_lock l(thread_pool_mu_);
+  return thread_pool_.size();
+}
+
+std::unique_ptr<Thread> UnboundedThreadPool::RunOnPooledThread(
+    std::function<void()> fn) {
+  auto join_notification = std::make_shared<Notification>();
+  bool all_threads_busy;
+  {
+    // Enqueue a work item for the new thread's function, and wake up a
+    // pooled thread to process it.
+    mutex_lock l(work_queue_mu_);
+    work_queue_.push_back({std::move(fn), join_notification});
+    work_queue_cv_.notify_one();
+    // NOTE: The queue may be non-empty, so we must account for queued work
+    // when considering how many threads are free.
+    all_threads_busy = work_queue_.size() > num_idle_threads_;
+  }
+
+  if (all_threads_busy) {
+    // Spawn a new physical thread to process the given function.
+    // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
+    // at the beginning of its work loop.
+    Thread* new_thread = env_->StartThread(
+        {}, thread_name_,
+        std::bind(&UnboundedThreadPool::PooledThreadFunc, this));
+
+    mutex_lock l(thread_pool_mu_);
+    thread_pool_.emplace_back(new_thread);
+  }
+
+  return absl::make_unique<LogicalThreadWrapper>(std::move(join_notification));
+}
+
+void UnboundedThreadPool::PooledThreadFunc() {
+  while (true) {
+    WorkItem work_item;
+    {
+      mutex_lock l(work_queue_mu_);
+      ++num_idle_threads_;
+      while (!cancelled_ && work_queue_.empty()) {
+        // Wait for a new work function to be submitted, or the pool to be
+        // destroyed.
+        work_queue_cv_.wait(l);
+      }
+      if (cancelled_) {
+        return;
+      }
+      work_item = std::move(work_queue_.front());
+      work_queue_.pop_front();
+      --num_idle_threads_;
+    }
+
+    work_item.work_function();
+
+    // Notify any thread that has "joined" the logical thread for this work
+    // item.
+    work_item.done_notification->Notify();
+  }
+}
+
+}  // namespace data
+}  // namespace tensorflow
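One subtlety in `RunOnPooledThread` above is the growth test `work_queue_.size() > num_idle_threads_`, which is evaluated after the new item has been enqueued. A worked example with hypothetical numbers (not from the patch):

    // Suppose 3 pooled threads are idle and 2 work items are already queued.
    // Enqueueing a new logical thread gives:
    //   work_queue_.size() == 3   (2 pending + the new item)
    //   num_idle_threads_  == 3
    //   all_threads_busy   == (3 > 3) == false  -> no new physical thread.
    // If a further item is enqueued before any idle thread wakes and dequeues:
    //   all_threads_busy   == (4 > 3) == true   -> spawn one more physical thread.
    // Over-spawning is benign: the extra thread simply parks in
    // PooledThreadFunc's idle loop until more work arrives or the pool dies.
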
diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool.h b/tensorflow/core/kernels/data/unbounded_thread_pool.h
new file mode 100644
index 00000000000..c84d495b296
--- /dev/null
+++ b/tensorflow/core/kernels/data/unbounded_thread_pool.h
@@ -0,0 +1,77 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
+
+#include <deque>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/thread_factory.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace data {
+
+// An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a
+// potentially large number of "logical" threads onto a smaller number of
+// "physical" threads. The multiplexing is achieved by maintaining an internal
+// pool of long-running "physical" threads that are used to execute the
+// "logical" threads. Like a regular thread, a "logical" thread may block on
+// other threads, and the size of the pool will increase to ensure that
+// progress is made. This mechanism is recommended in situations where
+// short-lived threads are created repeatedly, to avoid the overhead and
+// memory fragmentation that can result from excessive thread creation.
+class UnboundedThreadPool {
+ public:
+  UnboundedThreadPool(Env* env, const string& thread_name)
+      : env_(env), thread_name_(thread_name) {}
+  ~UnboundedThreadPool();
+
+  // Returns an implementation of `ThreadFactory` that can be used to create
+  // logical threads in this pool.
+  std::shared_ptr<ThreadFactory> get_thread_factory();
+
+  // Returns the current number of threads in this pool.
+  size_t size();
+
+ private:
+  class LogicalThreadFactory;
+  class LogicalThreadWrapper;
+  struct WorkItem {
+    std::function<void()> work_function;
+    std::shared_ptr<Notification> done_notification;
+  };
+
+  std::unique_ptr<Thread> RunOnPooledThread(std::function<void()> fn);
+  void PooledThreadFunc();
+
+  Env* const env_;  // Not owned.
+  const string thread_name_;
+  mutex work_queue_mu_;
+  condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_);
+  size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0;
+  bool cancelled_ GUARDED_BY(work_queue_mu_) = false;
+  std::deque<WorkItem> work_queue_ GUARDED_BY(work_queue_mu_);
+  mutex thread_pool_mu_;
+  std::vector<std::unique_ptr<Thread>> thread_pool_
+      GUARDED_BY(thread_pool_mu_);
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
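A minimal usage sketch of the pool on its own, assuming only the header above (`ExampleUsage` and the names inside it are illustrative, not part of the patch):

    #include <atomic>

    #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"

    namespace tensorflow {
    namespace data {

    void ExampleUsage() {
      UnboundedThreadPool pool(Env::Default(), "tf_data_example");
      std::shared_ptr<ThreadFactory> factory = pool.get_thread_factory();

      std::atomic<int> counter(0);
      std::unique_ptr<Thread> thread =
          factory->StartThread("example_worker", [&counter]() { ++counter; });

      // Destroying the logical `Thread` joins the work item: this blocks until
      // the lambda has run, but the pooled physical thread survives for reuse.
      thread.reset();
      // counter == 1 here, and pool.size() >= 1 until `pool` is destroyed.
    }

    }  // namespace data
    }  // namespace tensorflow
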
diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc b/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc
new file mode 100644
index 00000000000..f996b4f931b
--- /dev/null
+++ b/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc
@@ -0,0 +1,143 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
+
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+TEST(UnboundedThreadPool, SingleThread) {
+  UnboundedThreadPool pool(Env::Default(), "test");
+  auto thread_factory = pool.get_thread_factory();
+
+  // Create a thread that updates a variable, and ensure that it runs to
+  // completion.
+  std::atomic<int32> i(0);
+  auto thread = thread_factory->StartThread("", [&i]() { ++i; });
+  thread.reset();
+
+  EXPECT_GE(pool.size(), 1);
+  EXPECT_EQ(1, i);
+}
+
+TEST(UnboundedThreadPool, MultipleThreads) {
+  UnboundedThreadPool pool(Env::Default(), "test");
+  auto thread_factory = pool.get_thread_factory();
+
+  // Create ten threads that update a variable, and ensure that they all run
+  // to completion.
+  std::vector<std::unique_ptr<Thread>> threads;
+  const int kNumThreadsToCreate = 10;
+  std::atomic<int32> i(0);
+  for (int j = 0; j < kNumThreadsToCreate; ++j) {
+    threads.push_back(thread_factory->StartThread("", [&i]() { ++i; }));
+  }
+  threads.clear();
+
+  EXPECT_GE(pool.size(), 1);
+  EXPECT_EQ(i, kNumThreadsToCreate);
+}
+
+TEST(UnboundedThreadPool, MultipleThreadsSleepingRandomly) {
+  UnboundedThreadPool pool(Env::Default(), "test");
+  auto thread_factory = pool.get_thread_factory();
+
+  // Create 1000 threads that sleep for a random period of time then update a
+  // variable, and ensure that they all run to completion.
+  std::vector<std::unique_ptr<Thread>> threads;
+  const int kNumThreadsToCreate = 1000;
+  std::atomic<int32> i(0);
+  for (int j = 0; j < kNumThreadsToCreate; ++j) {
+    threads.push_back(thread_factory->StartThread("", [&i]() {
+      Env::Default()->SleepForMicroseconds(random::New64() % 10);
+      ++i;
+    }));
+  }
+  threads.clear();
+
+  EXPECT_GE(pool.size(), 1);
+  EXPECT_EQ(i, kNumThreadsToCreate);
+}
+
+TEST(UnboundedThreadPool, ConcurrentThreadCreation) {
+  UnboundedThreadPool pool(Env::Default(), "test");
+  auto thread_factory = pool.get_thread_factory();
+
+  // Create ten threads that each create ten threads that update a variable,
+  // and ensure that they all run to completion.
+  std::vector<std::unique_ptr<Thread>> threads;
+  const int kNumThreadsToCreate = 10;
+  std::atomic<int32> i(0);
+  for (int j = 0; j < kNumThreadsToCreate; ++j) {
+    threads.push_back(thread_factory->StartThread("", [&i, thread_factory]() {
+      std::vector<std::unique_ptr<Thread>> nested_threads;
+      for (int k = 0; k < kNumThreadsToCreate; ++k) {
+        nested_threads.push_back(
+            thread_factory->StartThread("", [&i]() { ++i; }));
+      }
+      nested_threads.clear();
+    }));
+  }
+  threads.clear();
+
+  EXPECT_GE(pool.size(), 1);
+  EXPECT_EQ(i, kNumThreadsToCreate * kNumThreadsToCreate);
+}
+
+TEST(UnboundedThreadPool, MultipleBlockingThreads) {
+  UnboundedThreadPool pool(Env::Default(), "test");
+  auto thread_factory = pool.get_thread_factory();
+
+  std::vector<std::unique_ptr<Thread>> threads;
+
+  // Create multiple waves (with increasing sizes) of threads that all block
+  // before returning, and ensure that we create the appropriate number of
+  // threads and terminate correctly.
+  std::vector<int> round_sizes = {5, 10, 15, 20};
+
+  for (const int round_size : round_sizes) {
+    Notification n;
+    BlockingCounter bc(round_size);
+    for (int j = 0; j < round_size; ++j) {
+      threads.push_back(thread_factory->StartThread("", [&bc, &n]() {
+        bc.DecrementCount();
+        // Block until `n` is notified, so that all `round_size` threads must
+        // have been created before the first one completes.
+        n.WaitForNotification();
+      }));
+    }
+
+    // Wait until all threads have started. Since the number of threads in
+    // each wave is increasing, we should have at least that number of
+    // threads in the pool.
+    bc.Wait();
+    // NOTE: There is a benign race between a new round starting and the
+    // physical threads from the previous round returning to the pool, so we
+    // may create more threads than `round_size`.
+    EXPECT_GE(pool.size(), round_size);
+    n.Notify();
+    threads.clear();
+  }
+}
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
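Finally, tying the pieces together: the pattern the iterator kernels above rely on is that whoever owns an `UnboundedThreadPool` plumbs its factory into `IteratorContext::Params`, after which every `ctx->StartThread(...)` call site transparently uses pooled threads. A sketch of that wiring, mirroring `IteratorResource::GetNext` in this patch (`params` is assumed to be an `IteratorContext::Params` already populated from the kernel context, and `unbounded_thread_pool_` a member as added above):

    // Sketch, not a new API: both names on the right-hand side appear in the
    // patch; only the surrounding scope is assumed.
    params.thread_factory = unbounded_thread_pool_.get_thread_factory();
    IteratorContext iter_ctx(std::move(params));
    // Any iterator running under `iter_ctx` that calls
    // `ctx->StartThread(name, fn)` now gets a logical thread from the pool;
    // if `thread_factory` were left null, it would fall back to
    // `Env::Default()->StartThread({}, name, fn)`.
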