Allows FunctionLibraryRuntime::Run calls to be made without a runner for executing kernels. In that case, Run defaults to the threadpool provided by the device.

Also makes sure each device has a default threadpool to fall back on.

PiperOrigin-RevId: 188520648
This commit is contained in:
Rohan Jain 2018-03-09 12:20:32 -08:00 committed by TensorFlower Gardener
parent 61a744fffb
commit 20dfc25c37
12 changed files with 221 additions and 96 deletions

View File

@ -41,7 +41,7 @@ class TestEnv {
device_mgr_.reset(new DeviceMgr({device}));
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
device, TF_GRAPH_DEF_VERSION,
&flib_def_, {}, nullptr);
&flib_def_, nullptr, {}, nullptr);
}
FunctionLibraryRuntime* function_library_runtime() const {

View File

@ -1181,7 +1181,7 @@ Status DirectSession::GetOrCreateExecutors(
}
func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), options_.env, graph_def_version,
func_info->flib_def.get(), optimizer_opts));
func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first));
GraphOptimizer optimizer(optimizer_opts);
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -868,59 +869,14 @@ TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
TF_ASSERT_OK(session->Close());
}
class BlockingOpState {
public:
void AwaitState(int awaiting_state) {
mutex_lock ml(mu_);
while (state_ != awaiting_state) {
cv_.wait(ml);
}
}
void MoveToState(int expected_current, int next) {
mutex_lock ml(mu_);
CHECK_EQ(expected_current, state_);
state_ = next;
cv_.notify_all();
}
private:
mutex mu_;
condition_variable cv_;
int state_ = 0;
};
static BlockingOpState* blocking_op_state = nullptr;
// BlockingOp blocks on the global <blocking_op_state's> state,
// and also updates it when it is unblocked and finishing computation.
class BlockingOp : public OpKernel {
 public:
  explicit BlockingOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    // Handshake with the test thread via the global state machine:
    // announce entry (0 -> 1), block until the test releases us (state 2),
    // then acknowledge completion (2 -> 3). The order of these calls is the
    // protocol; do not reorder.
    blocking_op_state->MoveToState(0, 1);
    blocking_op_state->AwaitState(2);
    blocking_op_state->MoveToState(2, 3);
    // Once unblocked, act as an identity op: copy the float input through.
    Tensor* out = nullptr;
    const Tensor& in = ctx->input(0);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
    out->flat<float>() = in.flat<float>();
  }
};
REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc("");
static void TestSessionInterOpThreadsImpl(bool use_function_lib,
bool use_global_pools) {
using test::function::blocking_op_state;
using test::function::BlockingOpState;
FunctionDefLibrary library_graph_def;
if (use_function_lib) {
const string lib = R"proto(
signature: {
name: "BlockingOpFn" input_arg: { name: "x" type: DT_FLOAT }
output_arg: { name: "y" type: DT_FLOAT }}
node_def: { name: "y" op: "BlockingOp" input: "x" }
ret: { key: "y" value: "y:y:0" } )proto";
CHECK(protobuf::TextFormat::ParseFromString(
lib, library_graph_def.add_function()));
*library_graph_def.add_function() = test::function::BlockingOpFn();
}
FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
@ -141,6 +142,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
int graph_def_version,
const FunctionLibraryDefinition* lib_def,
thread::ThreadPool* default_thread_pool,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
ProcessFunctionLibraryRuntime* parent);
@ -194,6 +196,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const FunctionLibraryDefinition* const base_lib_def_;
GraphOptimizer optimizer_;
const CustomKernelCreator custom_kernel_creator_;
Executor::Args::Runner default_runner_;
const string device_name_;
std::function<Status(const string&, const OpDef**)> get_func_sig_;
@ -243,6 +246,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
thread::ThreadPool* default_thread_pool,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
ProcessFunctionLibraryRuntime* parent)
@ -253,6 +257,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
base_lib_def_(lib_def),
optimizer_(optimizer_options),
custom_kernel_creator_(std::move(custom_kernel_creator)),
default_runner_(nullptr),
device_name_(device_ == nullptr
? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
: device_->name()),
@ -264,6 +269,18 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
return CreateKernel(ndef, kernel);
};
thread::ThreadPool* pool = nullptr;
if (device_ != nullptr) {
pool = device_->tensorflow_device_thread_pool();
}
if (pool == nullptr) {
pool = default_thread_pool;
}
if (pool != nullptr) {
default_runner_ = [pool](Executor::Args::Closure c) {
pool->Schedule(std::move(c));
};
}
}
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
@ -768,6 +785,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
return;
}
if (run_opts.runner == nullptr) {
run_opts.runner = &default_runner_;
}
DCHECK(run_opts.runner != nullptr);
Executor::Args* exec_args = new Executor::Args;
@ -854,6 +874,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
done(s);
return;
}
if (run_opts.runner == nullptr) {
run_opts.runner = &default_runner_;
}
DCHECK(run_opts.runner != nullptr);
Executor::Args* exec_args = new Executor::Args;
@ -942,21 +965,21 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
ProcessFunctionLibraryRuntime* parent) {
return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
std::move(custom_kernel_creator), parent));
device_mgr, env, device, graph_def_version, lib_def, thread_pool,
optimizer_options, std::move(custom_kernel_creator), parent));
}
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
ProcessFunctionLibraryRuntime* parent) {
return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
lib_def, optimizer_options,
lib_def, thread_pool, optimizer_options,
GetCustomCreatorSingleton()->Get(), parent);
}

View File

@ -55,7 +55,7 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
ProcessFunctionLibraryRuntime* parent);
@ -65,7 +65,7 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
ProcessFunctionLibraryRuntime* parent);
// FunctionLibraryRuntime::GetFunctionBody returns a description of an

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@ -135,7 +136,8 @@ TEST_F(FunctionTest, WXPlusB) {
class FunctionLibraryRuntimeTest : public ::testing::Test {
protected:
void Init(const std::vector<FunctionDef>& flib) {
void Init(const std::vector<FunctionDef>& flib,
thread::ThreadPool* default_thread_pool = nullptr) {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3});
@ -149,7 +151,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
device_mgr_.reset(new DeviceMgr(devices_));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, nullptr /* cluster_flr */));
opts, default_thread_pool, nullptr /* cluster_flr */));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
@ -158,16 +160,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
bool add_runner = true) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
[&call_count](std::function<void()> fn) {
++call_count;
test::function::FunctionTestSchedClosure(fn);
};
if (add_runner) {
opts.runner = &runner;
} else {
opts.runner = nullptr;
}
Notification done;
opts.runner = &runner;
std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
@ -183,7 +189,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
*rets[i] = out[i];
}
EXPECT_GE(call_count, 1); // Test runner is used.
if (add_runner) {
EXPECT_GE(call_count, 1); // Test runner is used.
}
return Status::OK();
}
@ -204,24 +212,25 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const std::vector<Tensor>& args,
std::vector<Tensor*> rets) {
std::vector<Tensor*> rets, bool add_runner = true) {
return InstantiateAndRun(flr, name, attrs,
FunctionLibraryRuntime::InstantiateOptions(), args,
std::move(rets));
std::move(rets), add_runner);
}
Status InstantiateAndRun(
FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
bool add_runner = true) {
FunctionLibraryRuntime::Handle handle;
Status status = flr->Instantiate(name, attrs, options, &handle);
if (!status.ok()) {
return status;
}
FunctionLibraryRuntime::Options opts;
status = Run(flr, handle, opts, args, rets);
status = Run(flr, handle, opts, args, rets, add_runner);
if (!status.ok()) return status;
// Release the handle and try running again. It should not succeed.
@ -237,16 +246,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts, CallFrameInterface* frame) {
FunctionLibraryRuntime::Options opts, CallFrameInterface* frame,
bool add_runner = true) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
[&call_count](std::function<void()> fn) {
++call_count;
test::function::FunctionTestSchedClosure(fn);
};
if (add_runner) {
opts.runner = &runner;
} else {
opts.runner = nullptr;
}
Notification done;
opts.runner = &runner;
std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
@ -258,7 +271,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
EXPECT_GE(call_count, 1); // Test runner is used.
if (add_runner) {
EXPECT_GE(call_count, 1); // Test runner is used.
}
return Status::OK();
}
@ -447,7 +462,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
{
// Simple case: instantiating with no state_handle.
for (int32 expected : {6, 4}) {
TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}));
TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@ -460,7 +475,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
Instantiate(flr0_, "RandomUniformWrapper", {}, &handle_non_isolated));
EXPECT_EQ(handle, handle_non_isolated);
for (int32 expected : {0, 1}) {
TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}));
TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@ -475,7 +490,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@ -490,7 +505,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@ -507,7 +522,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated));
@ -515,6 +530,59 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
}
}
// Verifies that when no runner is supplied to Run(), the FLR falls back to
// the default threadpool passed at construction: a single-threaded default
// pool occupied by one blocked kernel must stall every subsequent Run().
TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) {
  using test::function::blocking_op_state;
  using test::function::BlockingOpState;

  // Single-threaded default pool: one blocked kernel occupies the only
  // worker thread.
  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "FLRTest", 1);
  Init({test::function::BlockingOpFn(), test::function::XTimesTwo()}, tp);

  auto x = test::AsScalar<float>(1.3);
  Tensor y;
  blocking_op_state = new BlockingOpState();

  thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);
  // Written on a tp1 thread and read below before any join, so it must be
  // atomic to avoid a data race.
  std::atomic<bool> finished_running(false);
  tp1->Schedule([&x, &y, &finished_running, this]() {
    TF_CHECK_OK(InstantiateAndRun(flr0_, "BlockingOpFn", {}, {x}, {&y},
                                  false /* add_runner */));
    finished_running = true;
  });

  // InstantiateAndRun shouldn't finish because BlockingOpFn should be blocked.
  EXPECT_FALSE(finished_running);

  FunctionLibraryRuntime::Handle h;
  TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, &h));

  auto x1 = test::AsTensor<float>({1, 2, 3, 4});
  // NOTE(review): y1 is a shared output target for 4 concurrent Run() calls;
  // this presumably relies on the single default-pool thread serializing
  // them — confirm before relying on y1's final contents.
  Tensor y1;
  std::atomic<int32> num_done(0);
  FunctionLibraryRuntime::Options opts;
  for (int i = 0; i < 4; ++i) {
    tp1->Schedule([&h, &x1, &y1, &opts, &num_done, this]() {
      TF_CHECK_OK(Run(flr0_, h, opts, {x1}, {&y1}, false /* add_runner */));
      num_done.fetch_add(1);
    });
  }
  // All the 4 Run() calls should be blocked because the runner is occupied.
  EXPECT_EQ(0, num_done.load());

  // Release BlockingOp (state 1 -> 2) and wait for it to acknowledge
  // completion (state 3), then reset the state machine.
  blocking_op_state->AwaitState(1);
  blocking_op_state->MoveToState(1, 2);
  blocking_op_state->AwaitState(3);
  blocking_op_state->MoveToState(3, 0);

  delete tp1;  // Joins all tp1 threads; every scheduled closure is done now.
  EXPECT_TRUE(finished_running);
  EXPECT_EQ(4, num_done.load());

  delete blocking_op_state;
  blocking_op_state = nullptr;
  delete tp;
}
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
@ -787,7 +855,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
s.WithOpName("x4/x2/scale/_12__cf__6")
s.WithOpName("x4/x2/scale/_12__cf__10")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@ -993,13 +1061,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
s.WithOpName("scale/_6__cf__11")
s.WithOpName("scale/_6__cf__15")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
s.WithOpName("Func/_1/sy/_5__cf__10")
s.WithOpName("Func/_1/sy/_5__cf__14")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
@ -1247,14 +1315,14 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
opts.source_device = "/device:CPU:1";
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<string>(
y,
test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"},
TensorShape({})));
opts.remote_execution = true;
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<string>(
y,
test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"},

View File

@ -58,6 +58,59 @@ FunctionDef FindDevice() {
{{{"device_name"}, "FindDeviceOp", {}, {}}});
}
void BlockingOpState::AwaitState(int awaiting_state) {
mutex_lock ml(mu_);
while (state_ != awaiting_state) {
cv_.wait(ml);
}
}
// CHECKs that the current state is expected_current, then transitions to
// next and wakes every waiting thread.
void BlockingOpState::MoveToState(int expected_current, int next) {
  mutex_lock lock(mu_);
  CHECK_EQ(expected_current, state_);
  state_ = next;
  cv_.notify_all();
}
BlockingOpState* blocking_op_state = nullptr;
// BlockingOp blocks on the global <blocking_op_state's> state,
// and also updates it when it is unblocked and finishing computation.
class BlockingOp : public OpKernel {
 public:
  explicit BlockingOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    // Handshake with the test thread via the global blocking_op_state:
    // announce entry (0 -> 1), block until released (state 2), then
    // acknowledge (2 -> 3). The call order is the protocol; do not reorder.
    blocking_op_state->MoveToState(0, 1);
    blocking_op_state->AwaitState(2);
    blocking_op_state->MoveToState(2, 3);
    // Once unblocked, act as an identity op: copy the float input through.
    Tensor* out = nullptr;
    const Tensor& in = ctx->input(0);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
    out->flat<float>() = in.flat<float>();
  }
};
// Registers the CPU kernel and the op schema for BlockingOp
// (float -> float, output shape unknown at graph-construction time).
REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
REGISTER_OP("BlockingOp")
    .Input("x: float")
    .Output("y: float")
    .Doc("")
    .SetShapeFn(shape_inference::UnknownShape);
// Returns a FunctionDef wrapping BlockingOp: a single-node function
// y = BlockingOpFn(x) that tests use to stall execution on a worker thread.
FunctionDef BlockingOpFn() {
  return FDH::Define(
      // Name
      "BlockingOpFn",
      // Args
      {"x: float"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {{{"y"}, "BlockingOp", {"x"}, {}}});
}
// TODO(phawkins): replace with C++ API for calling functions, when that exists.
Output Call(Scope* scope, const string& op_name, const string& fn_name,
gtl::ArraySlice<Input> inputs) {

View File

@ -25,6 +25,22 @@ namespace function {
// {} -> y:DT_STRING (device where this op runs).
FunctionDef FindDevice();
// Synchronization helper for BlockingOp: an integer state protected by a
// mutex/condition-variable pair. Threads wait for a given state value and
// advance it in lock-step with the op's Compute().
class BlockingOpState {
 public:
  // Blocks until the internal state equals awaiting_state.
  void AwaitState(int awaiting_state);
  // CHECKs the state is expected_current, sets it to next, wakes all waiters.
  void MoveToState(int expected_current, int next);

 private:
  mutex mu_;
  condition_variable cv_;
  int state_ = 0;  // Handshake state; starts at 0.
};

// Global state shared between BlockingOp kernels and the test driving them.
extern BlockingOpState* blocking_op_state;

// Returns a single-node function y = BlockingOp(x) (float -> float).
FunctionDef BlockingOpFn();
// Adds a function call to the given scope and returns the output for the node.
// TODO(phawkins): replace with C++ API for calling functions, when that exists.
Output Call(Scope* scope, const string& op_name, const string& fn_name,

View File

@ -42,21 +42,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* default_thread_pool,
DistributedFunctionLibraryRuntime* parent)
: device_mgr_(device_mgr),
lib_def_(lib_def),
default_thread_pool_(default_thread_pool),
next_handle_(0),
parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[nullptr] =
NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
lib_def, optimizer_options, this);
flr_map_[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, nullptr, graph_def_version, lib_def, default_thread_pool,
optimizer_options, this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
flr_map_[d] =
NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version,
lib_def, optimizer_options, this);
flr_map_[d] = NewFunctionLibraryRuntime(
device_mgr, env, d, graph_def_version, lib_def, default_thread_pool,
optimizer_options, this);
}
}
@ -65,21 +67,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
thread::ThreadPool* default_thread_pool,
DistributedFunctionLibraryRuntime* parent)
: device_mgr_(device_mgr),
lib_def_(lib_def),
default_thread_pool_(default_thread_pool),
next_handle_(0),
parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
std::move(custom_kernel_creator), this);
nullptr, env, nullptr, graph_def_version, lib_def, default_thread_pool,
optimizer_options, std::move(custom_kernel_creator), this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
flr_map_[d] = NewFunctionLibraryRuntime(
device_mgr, env, d, graph_def_version, lib_def, optimizer_options,
custom_kernel_creator, this);
device_mgr, env, d, graph_def_version, lib_def, default_thread_pool,
optimizer_options, custom_kernel_creator, this);
}
}
@ -370,7 +374,8 @@ Status ProcessFunctionLibraryRuntime::Clone(
out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
out_pflr->reset(new ProcessFunctionLibraryRuntime(
device_mgr_, env, graph_def_version, out_lib_def->get(),
optimizer_options, std::move(custom_kernel_creator), parent_));
optimizer_options, std::move(custom_kernel_creator), default_thread_pool_,
parent_));
return Status::OK();
}

View File

@ -33,6 +33,7 @@ class ProcessFunctionLibraryRuntime {
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool = nullptr,
DistributedFunctionLibraryRuntime* parent = nullptr);
// With `custom_kernel_creator`.
@ -41,6 +42,7 @@ class ProcessFunctionLibraryRuntime {
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
thread::ThreadPool* thread_pool,
DistributedFunctionLibraryRuntime* parent);
// Sends `tensors_to_send` from `source_device` to `target_device` using
@ -174,6 +176,7 @@ class ProcessFunctionLibraryRuntime {
const DeviceMgr* const device_mgr_;
const FunctionLibraryDefinition* lib_def_;
thread::ThreadPool* default_thread_pool_;
// Holds all the function invocations here.
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
GUARDED_BY(mu_);

View File

@ -71,7 +71,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
cluster_flr_.reset(new TestClusterFLR());
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, cluster_flr_.get()));
opts, nullptr, cluster_flr_.get()));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
}
@ -153,7 +153,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr(
new ProcessFunctionLibraryRuntime(
nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION,
lib_def.get(), opts, nullptr /* cluster_flr */));
lib_def.get(), opts, nullptr, nullptr /* cluster_flr */));
FunctionLibraryRuntime* flr =
proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
EXPECT_NE(flr, nullptr);

View File

@ -134,7 +134,8 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_, worker_env_->env, gdef.versions().producer(),
item->lib_def.get(), graph_options.optimizer_options(), cluster_flr));
item->lib_def.get(), graph_options.optimizer_options(),
worker_env_->compute_pool, cluster_flr));
// Constructs the graph out of "gdef".
Graph graph(OpRegistry::Global());