From 70da1fe25d97b738dd22f8acecf4c329ab97610e Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Fri, 1 Mar 2019 18:22:18 -0800 Subject: [PATCH] [tf.data] Add an unbounded thread pool to iterator resources. The previous implementation of many core `tf.data` transformations (e.g. `Dataset.prefetch()`) would create one or more threads each time an iterator over those datasets is created (e.g. `ds.prefetch(N).repeat(100)` would create and destroy 100 threads). In addition to the overhead of thread creation, this interacts poorly with some malloc implementations, and can contribute to memory fragmentation. The new implementation maintains an unbounded pool of physical threads in each iterator (or `MultiDeviceIterator`) resource, and returns logical "threads" to that pool when their work is complete instead of exiting from them. PiperOrigin-RevId: 236413014 --- tensorflow/core/BUILD | 1 + tensorflow/core/framework/dataset.h | 23 ++- tensorflow/core/framework/thread_factory.h | 42 +++++ tensorflow/core/kernels/data/BUILD | 25 +++ .../experimental/choose_fastest_dataset_op.cc | 6 +- .../experimental/map_and_batch_dataset_op.cc | 6 +- .../numa_map_and_batch_dataset_op.cc | 10 +- .../parallel_interleave_dataset_op.cc | 8 +- tensorflow/core/kernels/data/iterator_ops.cc | 38 +++-- .../core/kernels/data/model_dataset_op.cc | 5 +- .../kernels/data/multi_device_iterator_ops.cc | 71 ++++---- .../data/parallel_interleave_dataset_op.cc | 14 +- .../kernels/data/parallel_map_iterator.cc | 6 +- .../core/kernels/data/prefetch_dataset_op.cc | 5 +- .../kernels/data/unbounded_thread_pool.cc | 156 ++++++++++++++++++ .../core/kernels/data/unbounded_thread_pool.h | 77 +++++++++ .../data/unbounded_thread_pool_test.cc | 143 ++++++++++++++++ 17 files changed, 556 insertions(+), 80 deletions(-) create mode 100644 tensorflow/core/framework/thread_factory.h create mode 100644 tensorflow/core/kernels/data/unbounded_thread_pool.cc create mode 100644 tensorflow/core/kernels/data/unbounded_thread_pool.h create mode 100644 tensorflow/core/kernels/data/unbounded_thread_pool_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 0e1ed0ec7b7..34f4a777446 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -925,6 +925,7 @@ tf_cuda_library( "framework/tensor_slice.h", "framework/tensor_types.h", "framework/tensor_util.h", + "framework/thread_factory.h", "framework/tracking_allocator.h", "framework/type_index.h", "framework/type_traits.h", diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index a93d5adb232..0c38801154e 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/thread_factory.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" @@ -287,7 +289,8 @@ class IteratorContext { model(ctx->model()), runner(*(ctx->runner())), runner_threadpool_size(ctx->runner_threadpool_size()), - stats_aggregator(ctx->stats_aggregator()) {} + stats_aggregator(ctx->stats_aggregator()), + thread_factory(ctx->thread_factory()) {} explicit Params(OpKernelContext* ctx) : env(ctx->env()), @@ -338,6 +341,10 @@ class IteratorContext { // The `StatsAggregator` object to record statistics about the iterator. std::shared_ptr stats_aggregator = nullptr; + + // A `ThreadFactory` for creating threads used by iterators to perform + // blocking work. + std::shared_ptr thread_factory = nullptr; }; explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {} @@ -374,6 +381,20 @@ class IteratorContext { return ¶ms_.runner; } + const std::shared_ptr& thread_factory() { + return params_.thread_factory; + } + + std::unique_ptr StartThread(const string& name, + std::function fn) { + if (params_.thread_factory) { + return params_.thread_factory->StartThread(name, std::move(fn)); + } else { + return absl::WrapUnique( + Env::Default()->StartThread({}, name, std::move(fn))); + } + } + int32 runner_threadpool_size() { return params_.runner_threadpool_size; } std::shared_ptr stats_aggregator() { diff --git a/tensorflow/core/framework/thread_factory.h b/tensorflow/core/framework/thread_factory.h new file mode 100644 index 00000000000..d5bb6dda66b --- /dev/null +++ b/tensorflow/core/framework/thread_factory.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_ + +#include +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Thread; + +// Virtual interface for an object that creates threads. +class ThreadFactory { + public: + virtual ~ThreadFactory() {} + + // Runs `fn` asynchronously in a different thread. `fn` may block. + // + // NOTE: The caller is responsible for ensuring that this `ThreadFactory` + // outlives the returned `Thread`. + virtual std::unique_ptr StartThread(const string& name, + std::function fn) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_ diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index ecfeb87bbe0..5a89d067fdc 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -129,6 +129,29 @@ tf_cc_test( ], ) +cc_library( + name = "unbounded_thread_pool", + srcs = ["unbounded_thread_pool.cc"], + hdrs = ["unbounded_thread_pool.h"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "unbounded_thread_pool_test", + srcs = ["unbounded_thread_pool_test.cc"], + deps = [ + ":unbounded_thread_pool", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "window_dataset", srcs = ["window_dataset.cc"], @@ -595,6 +618,7 @@ tf_kernel_library( deps = [ ":dataset_utils", ":optional_ops", + ":unbounded_thread_pool", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -612,6 +636,7 @@ tf_kernel_library( srcs = ["multi_device_iterator_ops.cc"], deps = [ ":dataset_utils", + ":unbounded_thread_pool", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc index e63208f26a9..bfa2bf6bc46 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc @@ -292,10 +292,10 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { for (size_t i = 0, num_inputs = dataset()->inputs_.size(); i < num_inputs; ++i) { threads[i].result = absl::make_unique(); - threads[i].thread.reset(ctx->env()->StartThread( - {}, strings::StrCat("tf_data_merge_", i), + threads[i].thread = ctx->StartThread( + strings::StrCat("tf_data_merge_", i), std::bind(&ChooseFastestIterator::RunnerThread, this, ctx, - threads[i].result.get(), i))); + threads[i].result.get(), i)); } return threads; } diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index 540afea92a5..fb7a6204a04 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -514,9 +514,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { auto ctx_copy = std::make_shared(*ctx); - runner_thread_.reset(ctx->env()->StartThread( - {}, "tf_data_map_and_batch", - std::bind(&Iterator::RunnerThread, this, ctx_copy))); + runner_thread_ = ctx->StartThread( + "tf_data_map_and_batch", + std::bind(&Iterator::RunnerThread, this, ctx_copy)); } } diff --git a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc index a45755fbbb6..ce8a20a783f 100644 --- a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc @@ -926,8 +926,8 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { if (!new_ctx) { new_ctx = std::make_shared(*ctx); } - workers_[i]->threads.emplace_back(ctx->env()->StartThread( - {}, strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j), + workers_[i]->threads.emplace_back(ctx->StartThread( + strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j), [this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); })); VLOG(3) << "Worker " << i << ", " << j << " successfully started."; } @@ -936,9 +936,9 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { if (!new_ctx) { new_ctx = std::make_shared(*ctx); } - runner_thread_.reset(ctx->env()->StartThread( - {}, "tf_data_numa_map_and_batch", - [this, new_ctx] { RunnerThread(new_ctx); })); + runner_thread_ = + ctx->StartThread("tf_data_numa_map_and_batch", + [this, new_ctx] { RunnerThread(new_ctx); }); } VLOG(3) << "All workers & runner thread started."; return Status::OK(); diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index f6d522078dd..54c1d839e60 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -493,8 +493,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { worker_threads_.reserve(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { std::shared_ptr new_ctx(new IteratorContext(*ctx)); - worker_threads_.emplace_back(ctx->env()->StartThread( - {}, strings::StrCat("tf_data_parallel_interleave_worker_", i), + worker_threads_.emplace_back(ctx->StartThread( + strings::StrCat("tf_data_parallel_interleave_worker_", i), [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); } } @@ -592,8 +592,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } workers_[i].SetInputs(s, std::move(args)); std::shared_ptr new_ctx(new IteratorContext(*ctx)); - worker_threads_.emplace_back(ctx->env()->StartThread( - {}, strings::StrCat("tf_data_parallel_interleave_worker_", i), + worker_threads_.push_back(ctx->StartThread( + strings::StrCat("tf_data_parallel_interleave_worker_", i), [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); if (i < dataset()->cycle_length_) { interleave_indices_.push_back(i); diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 7e23ca58ce7..14fb6624ad7 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/iterator_ops.h" +#include #include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/graph_runner.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/optional_ops.h" +#include "tensorflow/core/kernels/data/unbounded_thread_pool.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -51,14 +53,15 @@ const char kIteratorVariantTypeName[] = "tensorflow::Iterator"; class IteratorResource : public ResourceBase { public: - IteratorResource(const DataTypeVector& output_dtypes, + IteratorResource(Env* env, const DataTypeVector& output_dtypes, const std::vector& output_shapes, const int /*unused: graph_def_version*/, std::unique_ptr device_mgr, std::unique_ptr flib_def, std::unique_ptr pflr, FunctionLibraryRuntime* lib) - : device_mgr_(std::move(device_mgr)), + : unbounded_thread_pool_(env, "tf_data_iterator_resource"), + device_mgr_(std::move(device_mgr)), iterator_state_(std::make_shared( std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */)), output_dtypes_(output_dtypes), @@ -77,6 +80,7 @@ class IteratorResource : public ResourceBase { params.function_handle_cache = captured_state->function_handle_cache.get(); params.resource_mgr = &captured_state->resource_mgr; + params.thread_factory = unbounded_thread_pool_.get_thread_factory(); return captured_state->iterator->GetNext( IteratorContext(std::move(params)), out_tensors, end_of_sequence); } else { @@ -163,6 +167,8 @@ class IteratorResource : public ResourceBase { params.lib = new_state->lib; params.function_handle_cache = new_state->function_handle_cache.get(); params.resource_mgr = &new_state->resource_mgr; + params.thread_factory = unbounded_thread_pool_.get_thread_factory(); + TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)), "Iterator", &new_state->iterator)); TF_RETURN_IF_ERROR( @@ -179,6 +185,7 @@ class IteratorResource : public ResourceBase { params.allocator_getter = [device](AllocatorAttributes attrs) { return device->GetAllocator(attrs); }; + params.thread_factory = unbounded_thread_pool_.get_thread_factory(); IteratorContext iter_ctx(std::move(params)); TF_RETURN_IF_ERROR(new_state->iterator->Restore(&iter_ctx, reader)); } @@ -233,6 +240,7 @@ class IteratorResource : public ResourceBase { params.lib = new_state->lib; params.function_handle_cache = new_state->function_handle_cache.get(); params.resource_mgr = &new_state->resource_mgr; + params.thread_factory = unbounded_thread_pool_.get_thread_factory(); TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)), "Iterator", &iterator)); TF_RETURN_IF_ERROR( @@ -284,6 +292,7 @@ class IteratorResource : public ResourceBase { std::unique_ptr iterator; }; + UnboundedThreadPool unbounded_thread_pool_; mutex mu_; const std::unique_ptr device_mgr_ GUARDED_BY(mu_); std::shared_ptr iterator_state_ GUARDED_BY(mu_); @@ -432,14 +441,14 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) { context, mgr->LookupOrCreate( cinfo_.container(), cinfo_.name(), &resource, - [lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource( - output_dtypes_, output_shapes_, graph_def_version_, - std::move(device_mgr), std::move(flib_def), - std::move(pflr), lib); - return Status::OK(); - })); + [context, lib, &device_mgr, &flib_def, &pflr, + this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new IteratorResource( + context->env(), output_dtypes_, output_shapes_, + graph_def_version_, std::move(device_mgr), + std::move(flib_def), std::move(pflr), lib); + return Status::OK(); + })); Status s = VerifyResource(resource); if (TF_PREDICT_FALSE(!s.ok())) { @@ -522,7 +531,7 @@ void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) { existing_resource->Unref(); } IteratorResource* new_resource = new IteratorResource( - output_dtypes_, output_shapes_, graph_def_version_, + context->env(), output_dtypes_, output_shapes_, graph_def_version_, std::move(device_mgr), std::move(flib_def), std::move(pflr), lib); // Create the resource with our chosen name under the resource lookup // mutex to avoid another kernel racily creating a resource with this @@ -837,11 +846,12 @@ class OneShotIteratorOp : public AsyncOpKernel { TF_RETURN_IF_ERROR( ctx->resource_manager()->LookupOrCreate( cinfo->container(), cinfo->name(), iterator, - [lib, this, &flib_def, &pflr](IteratorResource** ret) + [ctx, lib, this, &flib_def, &pflr](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new IteratorResource( - output_dtypes_, output_shapes_, graph_def_version_, - nullptr, std::move(flib_def), std::move(pflr), lib); + ctx->env(), output_dtypes_, output_shapes_, + graph_def_version_, nullptr, std::move(flib_def), + std::move(pflr), lib); return Status::OK(); })); diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 20254234e9d..4b8b68f2a38 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -140,9 +140,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { if (!optimize_thread_) { std::shared_ptr new_ctx = std::make_shared(*ctx); - optimize_thread_.reset(ctx->env()->StartThread( - {}, "tf_data_model", - [this, new_ctx]() { OptimizeThread(new_ctx); })); + optimize_thread_ = ctx->StartThread( + "tf_data_model", [this, new_ctx]() { OptimizeThread(new_ctx); }); } return Status::OK(); } diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index d78ed6006c0..8dd7aba895a 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/unbounded_thread_pool.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/random/random.h" @@ -42,14 +43,15 @@ using MultiDeviceIteratorCallback = class MultiDeviceIterator : public ResourceBase { public: MultiDeviceIterator( - const DataTypeVector& output_types, + Env* env, const DataTypeVector& output_types, const std::vector& output_shapes, const std::vector& devices, std::unique_ptr flib_def, std::unique_ptr pflr, FunctionLibraryRuntime* lib, std::unique_ptr function_handle_cache) - : output_types_(output_types), + : unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"), + output_types_(output_types), output_shapes_(output_shapes), devices_(devices), flib_def_(std::move(flib_def)), @@ -82,27 +84,25 @@ class MultiDeviceIterator : public ResourceBase { *incarnation_id = incarnation_id_; multi_device_buffer_ = absl::make_unique( - devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator)); + devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator), + this); return Status::OK(); } - void GetNextFromShard(IteratorContext* ctx, int shard_num, + void GetNextFromShard(OpKernelContext* ctx, int shard_num, int64 incarnation_id, MultiDeviceIteratorCallback callback) { - if (ctx->lib() == lib_) { - tf_shared_lock l(mu_); - multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id, - std::move(callback)); - } else { - IteratorContext::Params params(ctx); - params.lib = lib_; - params.function_handle_cache = function_handle_cache_.get(); - params.resource_mgr = &resource_mgr_; - IteratorContext iter_ctx(std::move(params)); - tf_shared_lock l(mu_); - multi_device_buffer_->GetNextFromShard( - &iter_ctx, shard_num, incarnation_id, std::move(callback)); - } + tf_shared_lock l(mu_); + IteratorContext::Params params(ctx); + params.function_library = lib_def_; + params.lib = lib_; + params.function_handle_cache = function_handle_cache_.get(); + params.resource_mgr = &resource_mgr_; + params.thread_factory = unbounded_thread_pool_.get_thread_factory(); + + IteratorContext iter_ctx(std::move(params)); + multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, + std::move(callback)); } const DataTypeVector& output_types() const { return output_types_; } @@ -133,12 +133,14 @@ class MultiDeviceIterator : public ResourceBase { class MultiDeviceBuffer { public: MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id, - std::unique_ptr host_iterator) + std::unique_ptr host_iterator, + MultiDeviceIterator* parent) : buffer_(size), size_(size), max_buffer_size_(max_buffer_size), incarnation_id_(incarnation_id), - host_iterator_(std::move(host_iterator)) {} + host_iterator_(std::move(host_iterator)), + parent_(parent) {} ~MultiDeviceBuffer() { { @@ -217,10 +219,12 @@ class MultiDeviceIterator : public ResourceBase { EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!background_thread_) { auto ctx_copy = std::make_shared(*ctx); - background_thread_ = absl::WrapUnique(ctx->env()->StartThread( - {}, "tf_data_multi_device_iterator", - std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread, - this, std::move(ctx_copy)))); + background_thread_ = + parent_->unbounded_thread_pool_.get_thread_factory()->StartThread( + "tf_data_multi_device_iterator", + std::bind( + &MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread, + this, std::move(ctx_copy))); } } @@ -342,8 +346,10 @@ class MultiDeviceIterator : public ResourceBase { const int64 max_buffer_size_; const int64 incarnation_id_; const std::unique_ptr host_iterator_; + MultiDeviceIterator* const parent_; // Not owned. }; + UnboundedThreadPool unbounded_thread_pool_; mutex mu_; const DataTypeVector output_types_; const std::vector output_shapes_; @@ -413,8 +419,9 @@ class MultiDeviceIteratorHandleOp : public OpKernel { current_id_.fetch_add(1)); container_name = "AnonymousMultiDeviceIterator"; resource = new MultiDeviceIterator( - output_types_, output_shapes_, devices_, std::move(flib_def), - std::move(pflr), lib, std::move(function_handle_cache)); + context->env(), output_types_, output_shapes_, devices_, + std::move(flib_def), std::move(pflr), lib, + std::move(function_handle_cache)); // NOTE: `mgr->Create()` transfers the one reference on `resource` to // `mgr`. OP_REQUIRES_OK(context, mgr->Create( @@ -425,11 +432,12 @@ class MultiDeviceIteratorHandleOp : public OpKernel { OP_REQUIRES_OK(context, mgr->LookupOrCreate( container_name, unique_name, &resource, - [this, lib, &flib_def, &pflr, + [this, context, lib, &flib_def, &pflr, &function_handle_cache](MultiDeviceIterator** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new MultiDeviceIterator( - output_types_, output_shapes_, devices_, + context->env(), output_types_, + output_shapes_, devices_, std::move(flib_def), std::move(pflr), lib, std::move(function_handle_cache)); return Status::OK(); @@ -557,11 +565,8 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { }, std::placeholders::_1, std::move(done)); - IteratorContext::Params params(ctx); - params.function_library = iterator->function_library(); - IteratorContext iter_ctx(std::move(params)); - iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, - callback); + iterator->GetNextFromShard(ctx, shard_num, incarnation_id, + std::move(callback)); iterator->Unref(); }, std::move(done))); diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index a4b614289b1..4dd5c379c03 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -517,17 +517,15 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!current_elements_manager_) { auto new_ctx = std::make_shared(*ctx); - current_elements_manager_ = - absl::WrapUnique(ctx->env()->StartThread( - {}, "tf_data_parallel_interleave_current", - [this, new_ctx]() { CurrentElementsManager(new_ctx); })); + current_elements_manager_ = ctx->StartThread( + "tf_data_parallel_interleave_current", + [this, new_ctx]() { CurrentElementsManager(new_ctx); }); } if (!future_elements_manager_) { auto new_ctx = std::make_shared(*ctx); - future_elements_manager_ = - absl::WrapUnique(ctx->env()->StartThread( - {}, "tf_data_parallel_interleave_future", - [this, new_ctx]() { FutureElementsManager(new_ctx); })); + future_elements_manager_ = ctx->StartThread( + "tf_data_parallel_interleave_future", + [this, new_ctx]() { FutureElementsManager(new_ctx); }); } } diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 1f10804205f..3b0d6d7a449 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -191,9 +191,9 @@ class ParallelMapIterator : public DatasetBaseIterator { EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { auto ctx_copy = std::make_shared(*ctx); - runner_thread_.reset(ctx->env()->StartThread( - {}, "tf_data_parallel_map", - std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); + runner_thread_ = ctx->StartThread( + "tf_data_parallel_map", + std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)); } } diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index e479b9ff5e3..9773b492905 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -269,9 +269,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { if (!prefetch_thread_) { std::shared_ptr new_ctx = std::make_shared(*ctx); - prefetch_thread_ = absl::WrapUnique(ctx->env()->StartThread( - {}, "tf_data_prefetch", - [this, new_ctx]() { PrefetchThread(new_ctx); })); + prefetch_thread_ = ctx->StartThread( + "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); }); } return Status::OK(); } diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool.cc b/tensorflow/core/kernels/data/unbounded_thread_pool.cc new file mode 100644 index 00000000000..ac12197f1b8 --- /dev/null +++ b/tensorflow/core/kernels/data/unbounded_thread_pool.cc @@ -0,0 +1,156 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/unbounded_thread_pool.h" + +#include "absl/memory/memory.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace data { + +// A lightweight wrapper for creating logical threads in a `UnboundedThreadPool` +// that can be shared (e.g.) in an `IteratorContext`. +class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory { + public: + explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {} + + std::unique_ptr StartThread(const string& name, + std::function fn) override { + return pool_->RunOnPooledThread(std::move(fn)); + } + + private: + UnboundedThreadPool* const pool_; // Not owned. +}; + +// A logical implementation of the `tensorflow::Thread` interface that uses +// physical threads in an `UnboundedThreadPool` to perform the work. +// +// NOTE: This object represents a logical thread of control that may be mapped +// onto the same physical thread as other work items that are submitted to the +// same `UnboundedThreadPool`. +class UnboundedThreadPool::LogicalThreadWrapper : public Thread { + public: + explicit LogicalThreadWrapper(std::shared_ptr join_notification) + : join_notification_(std::move(join_notification)) {} + + ~LogicalThreadWrapper() override { + // NOTE: The `Thread` destructor is expected to "join" the created thread, + // but the physical thread may continue to execute after the work for this + // thread is complete. We simulate this by waiting on a notification that + // the `CachedThreadFunc` will notify when the thread's work function is + // complete. + join_notification_->WaitForNotification(); + } + + private: + std::shared_ptr join_notification_; +}; + +UnboundedThreadPool::~UnboundedThreadPool() { + { + mutex_lock l(work_queue_mu_); + // Wake up all `CachedThreadFunc` threads and cause them to terminate before + // joining them when `threads_` is cleared. + cancelled_ = true; + work_queue_cv_.notify_all(); + if (!work_queue_.empty()) { + LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was " + << "deleted with pending work in its queue. This may indicate " + << "a potential use-after-free bug."; + } + } + + { + mutex_lock l(thread_pool_mu_); + // Clear the list of pooled threads, which will eventually terminate due to + // the previous notification. + // + // NOTE: It is safe to do this while holding `pooled_threads_mu_`, because + // no subsequent calls to `this->StartThread()` should be issued after the + // destructor starts. + thread_pool_.clear(); + } +} + +std::shared_ptr UnboundedThreadPool::get_thread_factory() { + return std::make_shared(this); +} + +size_t UnboundedThreadPool::size() { + tf_shared_lock l(thread_pool_mu_); + return thread_pool_.size(); +} + +std::unique_ptr UnboundedThreadPool::RunOnPooledThread( + std::function fn) { + auto join_notification = std::make_shared(); + bool all_threads_busy; + { + // Enqueue a work item for the new thread's function, and wake up a + // cached thread to process it. + mutex_lock l(work_queue_mu_); + work_queue_.push_back({std::move(fn), join_notification}); + work_queue_cv_.notify_one(); + // NOTE: The queue may be non-empty, so we must account for queued work when + // considering how many threads are free. + all_threads_busy = work_queue_.size() > num_idle_threads_; + } + + if (all_threads_busy) { + // Spawn a new physical thread to process the given function. + // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_` + // at the beginning of its work loop. + Thread* new_thread = env_->StartThread( + {}, thread_name_, + std::bind(&UnboundedThreadPool::PooledThreadFunc, this)); + + mutex_lock l(thread_pool_mu_); + thread_pool_.emplace_back(new_thread); + } + + return absl::make_unique(std::move(join_notification)); +} + +void UnboundedThreadPool::PooledThreadFunc() { + while (true) { + WorkItem work_item; + { + mutex_lock l(work_queue_mu_); + ++num_idle_threads_; + while (!cancelled_ && work_queue_.empty()) { + // Wait for a new work function to be submitted, or the cache to be + // destroyed. + work_queue_cv_.wait(l); + } + if (cancelled_) { + return; + } + work_item = std::move(work_queue_.front()); + work_queue_.pop_front(); + --num_idle_threads_; + } + + work_item.work_function(); + + // Notify any thread that has "joined" the cached thread for this work item. + work_item.done_notification->Notify(); + } +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool.h b/tensorflow/core/kernels/data/unbounded_thread_pool.h new file mode 100644 index 00000000000..c84d495b296 --- /dev/null +++ b/tensorflow/core/kernels/data/unbounded_thread_pool.h @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/thread_factory.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace data { + +// An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a +// potentially large number of "logical" threads onto a smaller number of +// "physical" threads. The multiplexing is achieved by maintaining an internal +// pool of long-running "physical" threads that are used to execute the +// "logical" threads. Like a regular thread, a "logical" thread may block on +// other threads, and the size of the pool will increase to ensure that progress +// is made. This mechanism is recommended in situations where short-lived +// threads are created repeatedly, to avoid the overhead and memory +// fragmentation that can result from excessive thread creation. +class UnboundedThreadPool { + public: + UnboundedThreadPool(Env* env, const string& thread_name) + : env_(env), thread_name_(thread_name) {} + ~UnboundedThreadPool(); + + // Returns an implementation of `ThreadFactory` that can be used to create + // logical threads in this pool. + std::shared_ptr get_thread_factory(); + + // Returns the current number of threads in this pool. + size_t size(); + + private: + class LogicalThreadFactory; + class LogicalThreadWrapper; + struct WorkItem { + std::function work_function; + std::shared_ptr done_notification; + }; + + std::unique_ptr RunOnPooledThread(std::function fn); + void PooledThreadFunc(); + + Env* const env_; // Not owned. + const string thread_name_; + mutex work_queue_mu_; + condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_); + size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0; + bool cancelled_ GUARDED_BY(work_queue_mu_) = false; + std::deque work_queue_ GUARDED_BY(work_queue_mu_); + mutex thread_pool_mu_; + std::vector> thread_pool_ GUARDED_BY(thread_pool_mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_ diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc b/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc new file mode 100644 index 00000000000..f996b4f931b --- /dev/null +++ b/tensorflow/core/kernels/data/unbounded_thread_pool_test.cc @@ -0,0 +1,143 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/unbounded_thread_pool.h" + +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace { + +TEST(UnboundedThreadPool, SingleThread) { + UnboundedThreadPool pool(Env::Default(), "test"); + auto thread_factory = pool.get_thread_factory(); + + // Create a thread that updates a variable, and ensure that it runs to + // completion. + std::atomic i(0); + auto thread = thread_factory->StartThread("", [&i]() { ++i; }); + thread.reset(); + + EXPECT_GE(pool.size(), 1); + EXPECT_EQ(1, i); +} + +TEST(UnboundedThreadPool, MultipleThreads) { + UnboundedThreadPool pool(Env::Default(), "test"); + auto thread_factory = pool.get_thread_factory(); + + // Create ten threads that update a variable, and ensure that they all run + // to completion. + std::vector> threads; + const int kNumThreadsToCreate = 10; + std::atomic i(0); + for (int j = 0; j < kNumThreadsToCreate; ++j) { + threads.push_back(thread_factory->StartThread("", [&i]() { ++i; })); + } + threads.clear(); + + EXPECT_GE(pool.size(), 1); + EXPECT_EQ(i, kNumThreadsToCreate); +} + +TEST(UnboundedThreadPool, MultipleThreadsSleepingRandomly) { + UnboundedThreadPool pool(Env::Default(), "test"); + auto thread_factory = pool.get_thread_factory(); + + // Create 1000 threads that sleep for a random period of time then update a + // variable, and ensure that they all run to completion. + std::vector> threads; + const int kNumThreadsToCreate = 1000; + std::atomic i(0); + for (int j = 0; j < kNumThreadsToCreate; ++j) { + threads.push_back(thread_factory->StartThread("", [&i]() { + Env::Default()->SleepForMicroseconds(random::New64() % 10); + ++i; + })); + } + threads.clear(); + + EXPECT_GE(pool.size(), 1); + EXPECT_EQ(i, kNumThreadsToCreate); +} + +TEST(UnboundedThreadPool, ConcurrentThreadCreation) { + UnboundedThreadPool pool(Env::Default(), "test"); + auto thread_factory = pool.get_thread_factory(); + + // Create ten threads that each create ten threads that update a variable, and + // ensure that they all run to completion. + std::vector> threads; + const int kNumThreadsToCreate = 10; + std::atomic i(0); + for (int j = 0; j < kNumThreadsToCreate; ++j) { + threads.push_back(thread_factory->StartThread("", [&i, thread_factory]() { + std::vector> nested_threads; + for (int k = 0; k < kNumThreadsToCreate; ++k) { + nested_threads.push_back( + thread_factory->StartThread("", [&i]() { ++i; })); + } + nested_threads.clear(); + })); + } + threads.clear(); + + EXPECT_GE(pool.size(), 1); + EXPECT_EQ(i, kNumThreadsToCreate * kNumThreadsToCreate); +} + +TEST(UnboundedThreadPool, MultipleBlockingThreads) { + UnboundedThreadPool pool(Env::Default(), "test"); + auto thread_factory = pool.get_thread_factory(); + + std::vector> threads; + + // Create multiple waves (with increasing sizes) of threads that all block + // before returning, and + // ensure that we create the appropriate number of threads and terminate + // correctly. + std::vector round_sizes = {5, 10, 15, 20}; + + for (const int round_size : round_sizes) { + Notification n; + BlockingCounter bc(round_size); + for (int j = 0; j < round_size; ++j) { + threads.push_back(thread_factory->StartThread("", [&bc, &n]() { + bc.DecrementCount(); + // Block until `n` is notified, so that all ten threads must been + // created before the first one completes. + n.WaitForNotification(); + })); + } + + // Wait until all threads have started. Since the number of threads in each + // wave is increasing, we should have at least that number of threads in the + // pool. + bc.Wait(); + // NOTE: There is a benign race between a new round starting and the + // physical threads from the previous round returning to the pool, so we may + // create more threads than the round_size. + EXPECT_GE(pool.size(), round_size); + n.Notify(); + threads.clear(); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow