Getting rid of the threadpool from FunctionBufferingResource. It wasn't really serving much purpose apart from moving all the function logic execution onto those threads and making the destruction of the resource quite complicated.

PiperOrigin-RevId: 191017836
This commit is contained in:
Rohan Jain 2018-03-29 18:18:28 -07:00 committed by TensorFlower Gardener
parent ac39aec50f
commit 36c2d9a280
14 changed files with 92 additions and 57 deletions

View File

@ -40,8 +40,7 @@ class FunctionBufferingResource : public ResourceBase {
const NameAttrList& func, int64 buffer_size,
const string& source_device,
const string& target_device,
const std::vector<Tensor>& func_args,
int64 thread_pool_size)
const std::vector<Tensor>& func_args)
: lib_(lib),
pflr_(std::move(pflr)),
func_(func),
@ -52,22 +51,10 @@ class FunctionBufferingResource : public ResourceBase {
handle_(kInvalidHandle),
is_buffering_(false),
end_of_sequence_(false),
cancelled_(false) {
if (thread_pool_size > 0) {
thread_pool_ = new thread::ThreadPool(Env::Default(), ThreadOptions(),
"buffer_resource", thread_pool_size,
false /* low_latency_hint */);
runner_ = [this](std::function<void()> c) {
thread_pool_->Schedule(std::move(c));
};
}
}
cancelled_(false) {}
~FunctionBufferingResource() override {
Cancel();
if (thread_pool_ != nullptr) {
delete thread_pool_;
}
}
string DebugString() override {
@ -179,17 +166,12 @@ class FunctionBufferingResource : public ResourceBase {
for (int i = 0; i < cancellation_callbacks.size(); ++i) {
cancellation_callbacks[i](cancellation_buffer_elements[i]);
}
// We only wait on cond_var_ in the destructor, so there would atmost be
// one waiter to notify.
cond_var_.notify_one();
cond_var_.notify_all();
return;
}
FunctionLibraryRuntime::Options opts;
// Copied from CapturedFunction::generate_step_id();
opts.step_id = -std::abs(static_cast<int64>(random::New64()));
if (runner_ != nullptr) {
opts.runner = &runner_;
}
opts.source_device = source_device_;
AllocatorAttributes arg_alloc_attr;
arg_alloc_attr.set_on_host(true);
@ -251,11 +233,9 @@ class FunctionBufferingResource : public ResourceBase {
const string source_device_;
const string target_device_;
const std::vector<Tensor> func_args_;
thread::ThreadPool* thread_pool_ = nullptr;
FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
std::function<void(std::function<void()>)> runner_ = nullptr;
bool is_buffering_ GUARDED_BY(mu_);
bool end_of_sequence_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_);
@ -270,7 +250,6 @@ class FunctionBufferResourceHandleOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("thread_pool_size", &thread_pool_size_));
}
~FunctionBufferResourceHandleOp() override {
@ -318,7 +297,7 @@ class FunctionBufferResourceHandleOp : public OpKernel {
this](FunctionBufferingResource** ptr) {
*ptr = new FunctionBufferingResource(
clone_lib, std::move(pflr), func_, buffer_size_,
source_device, target_device, func_args, thread_pool_size_);
source_device, target_device, func_args);
return Status::OK();
}));
core::ScopedUnref s(buffer);
@ -340,7 +319,6 @@ class FunctionBufferResourceHandleOp : public OpKernel {
int64 buffer_size_;
string container_;
string name_;
int64 thread_pool_size_;
};
REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")

View File

@ -53,7 +53,6 @@ REGISTER_OP("FunctionBufferingResource")
.Attr("container: string")
.Attr("f: func")
.Attr("buffer_size: int")
.Attr("thread_pool_size: int")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Creates a resource that fills up a buffer by making function calls.
@ -63,7 +62,6 @@ target_device: Target device to execute the function on.
resource: Handle to the resource created.
f: Function to be executed.
buffer_size: Size of the buffer.
thread_pool_size: Size of the threadpool doing the prefetching.
container: If non-empty, this resource is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this resource will be shared under the given name

View File

@ -70,7 +70,6 @@ class StagingAreaOpsTest(test.TestCase):
target_device=target,
string_arg=ds_iterator_handle,
buffer_size=3,
thread_pool_size=2,
shared_name=buffer_name)
with ops.device(device1):

View File

@ -36,7 +36,6 @@ def function_buffering_resource(string_arg,
target_device,
f,
buffer_size,
thread_pool_size=0,
container="",
shared_name=None,
name=None):
@ -48,7 +47,6 @@ def function_buffering_resource(string_arg,
shared_name=shared_name,
f=f,
buffer_size=buffer_size,
thread_pool_size=thread_pool_size,
container=container,
name=name)
@ -90,8 +88,7 @@ class _PrefetchToDeviceIterator(object):
target_device=gen_dataset_ops.iterator_get_device(
input_iterator._iterator_resource),
string_arg=input_iterator_handle,
buffer_size=buffer_size,
thread_pool_size=0)
buffer_size=buffer_size)
def get_next(self, name=None):
"""See @{tf.data.Iterator.get_next}."""

View File

@ -59,8 +59,7 @@ class _PrefetchToDeviceIterator(object):
f=_prefetch_fn,
target_device=target_device,
string_arg=input_iterator_handle,
buffer_size=buffer_size,
thread_pool_size=0)
buffer_size=buffer_size)
self._buffering_resources.append(buffer_resource_handle)
def get_next(self, name=None):

View File

@ -98,7 +98,6 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
f=remote_fn,
target_device=target,
buffer_size=10,
thread_pool_size=1,
container="",
shared_name=_generate_shared_name("function_buffer_resource"))
self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long

View File

@ -3273,6 +3273,18 @@ tf_cc_test(
],
)
tf_cc_test(
name = "common_runtime_process_util_test",
size = "small",
srcs = ["common_runtime/process_util_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":core_cpu_internal",
":test",
":test_main",
],
)
tf_cc_test(
name = "common_runtime_rendezvous_util_test",
size = "small",

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb_text.h"
@ -69,20 +70,6 @@ auto* direct_session_runs = monitoring::Counter<0>::New(
"/tensorflow/core/direct_session_runs",
"The number of times DirectSession::Run() has been called.");
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
const int32 t = options.config.inter_op_parallelism_threads();
if (t != 0) return t;
// Default to using the number of cores available in the process.
return port::NumSchedulableCPUs();
}
thread::ThreadPool* NewThreadPoolFromSessionOptions(
const SessionOptions& options) {
const int32 num_threads = NumInterOpThreadsFromSessionOptions(options);
VLOG(1) << "Direct session inter op parallelism threads: " << num_threads;
return new thread::ThreadPool(options.env, "Compute", num_threads);
}
Status NewThreadPoolFromThreadPoolOptions(
const SessionOptions& options,
const ThreadPoolOptionProto& thread_pool_options, int pool_number,

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/process_util.h"
namespace tensorflow {
EagerContext::EagerContext(const SessionOptions& opts,
@ -25,9 +27,10 @@ EagerContext::EagerContext(const SessionOptions& opts,
device_manager_(std::move(device_mgr)),
devices_(device_manager_->ListDevices()),
rendezvous_(rendezvous),
pflr_(new ProcessFunctionLibraryRuntime(device_manager_.get(), opts.env,
TF_GRAPH_DEF_VERSION,
&func_lib_def_, {})),
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_,
{}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
async_default_(async) {
if (async_default_) {

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@ -160,6 +161,8 @@ class EagerContext {
FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
OpRegistry::Global(), {}};
std::unique_ptr<thread::ThreadPool> thread_pool_;
// One FunctionLibraryRuntime per device.
// func_libs[i] is the FunctionLibraryRuntime corresponding to
// session->devices[i].

View File

@ -796,16 +796,17 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
done(status);
};
}
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
parent_->Run(run_opts, handle, args, rets, done);
return;
}
if (run_opts.runner == nullptr) {
run_opts.runner = &default_runner_;
}
DCHECK(run_opts.runner != nullptr);
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
parent_->Run(run_opts, handle, args, rets, done);
return;
}
Executor::Args* exec_args = new Executor::Args;
// Inherit the step_id from the caller.
exec_args->step_id = run_opts.step_id;

View File

@ -46,6 +46,20 @@ thread::ThreadPool* ComputePool(const SessionOptions& options) {
return compute_pool;
}
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
const int32 t = options.config.inter_op_parallelism_threads();
if (t != 0) return t;
// Default to using the number of cores available in the process.
return port::NumSchedulableCPUs();
}
thread::ThreadPool* NewThreadPoolFromSessionOptions(
const SessionOptions& options) {
const int32 num_threads = NumInterOpThreadsFromSessionOptions(options);
VLOG(1) << "Direct session inter op parallelism threads: " << num_threads;
return new thread::ThreadPool(options.env, "Compute", num_threads);
}
void SchedClosure(std::function<void()> closure) {
if (port::Tracing::IsActive()) {
const uint64 id = port::Tracing::UniqueId();

View File

@ -30,6 +30,13 @@ namespace tensorflow {
// using 'options'. Caller does not take ownership over threadpool.
thread::ThreadPool* ComputePool(const SessionOptions& options);
// Returns number of inter op threads.
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options);
// Creates a thread pool with number of inter op threads.
thread::ThreadPool* NewThreadPoolFromSessionOptions(
const SessionOptions& options);
// Schedule "closure" in the default thread queue.
void SchedClosure(std::function<void()> closure);

View File

@ -0,0 +1,38 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(ProcessUtilTest, NumThreads) {
SessionOptions opts;
opts.config.set_inter_op_parallelism_threads(10);
EXPECT_EQ(10, NumInterOpThreadsFromSessionOptions(opts));
}
TEST(ProcessUtilTest, ThreadPool) {
SessionOptions opts;
opts.config.set_inter_op_parallelism_threads(10);
thread::ThreadPool* pool = NewThreadPoolFromSessionOptions(opts);
EXPECT_EQ(10, pool->NumThreads());
delete pool;
}
} // anonymous namespace
} // namespace tensorflow