Allow FunctionLibraryRuntime::Run calls to omit the runner used to execute kernels. In that case, the runtime defaults to the thread pool provided by the device.
Also makes sure each device has a default thread pool to fall back on.

PiperOrigin-RevId: 188520648
parent 61a744fffb
commit 20dfc25c37
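The runner-selection fallback this change introduces, condensed into a standalone sketch. The ThreadPool struct and MakeDefaultRunner helper below are simplified stand-ins for illustration, not the real TensorFlow types; the actual logic lives in the FunctionLibraryRuntimeImpl constructor further down in the diff.

#include <functional>
#include <utility>

// Stand-in for thread::ThreadPool; closures run inline here for simplicity.
struct ThreadPool {
  void Schedule(std::function<void()> fn) { fn(); }
};
using Runner = std::function<void(std::function<void()>)>;

// Mirrors the fallback order: prefer the device's own thread pool, then the
// default pool handed to the runtime at construction. If neither exists, the
// returned runner is empty and callers of Run() must still supply opts.runner.
Runner MakeDefaultRunner(ThreadPool* device_pool, ThreadPool* default_pool) {
  ThreadPool* pool = device_pool != nullptr ? device_pool : default_pool;
  if (pool == nullptr) return nullptr;
  return [pool](std::function<void()> c) { pool->Schedule(std::move(c)); };
}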
@@ -41,7 +41,7 @@ class TestEnv {
     device_mgr_.reset(new DeviceMgr({device}));
     flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
                                               device, TF_GRAPH_DEF_VERSION,
-                                              &flib_def_, {}, nullptr);
+                                              &flib_def_, nullptr, {}, nullptr);
   }

   FunctionLibraryRuntime* function_library_runtime() const {
@@ -1181,7 +1181,7 @@ Status DirectSession::GetOrCreateExecutors(
   }
   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
       device_mgr_.get(), options_.env, graph_def_version,
-      func_info->flib_def.get(), optimizer_opts));
+      func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first));

   GraphOptimizer optimizer(optimizer_opts);
   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
@@ -23,6 +23,7 @@ limitations under the License.

 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/function_testlib.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -868,59 +869,14 @@ TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
   TF_ASSERT_OK(session->Close());
 }

-class BlockingOpState {
- public:
-  void AwaitState(int awaiting_state) {
-    mutex_lock ml(mu_);
-    while (state_ != awaiting_state) {
-      cv_.wait(ml);
-    }
-  }
-  void MoveToState(int expected_current, int next) {
-    mutex_lock ml(mu_);
-    CHECK_EQ(expected_current, state_);
-    state_ = next;
-    cv_.notify_all();
-  }
-
- private:
-  mutex mu_;
-  condition_variable cv_;
-  int state_ = 0;
-};
-static BlockingOpState* blocking_op_state = nullptr;
-
-// BlockingOp blocks on the global <blocking_op_state's> state,
-// and also updates it when it is unblocked and finishing computation.
-class BlockingOp : public OpKernel {
- public:
-  explicit BlockingOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-  void Compute(OpKernelContext* ctx) override {
-    blocking_op_state->MoveToState(0, 1);
-    blocking_op_state->AwaitState(2);
-    blocking_op_state->MoveToState(2, 3);
-
-    Tensor* out = nullptr;
-    const Tensor& in = ctx->input(0);
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
-    out->flat<float>() = in.flat<float>();
-  }
-};
-REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
-REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc("");
-
 static void TestSessionInterOpThreadsImpl(bool use_function_lib,
                                           bool use_global_pools) {
+  using test::function::blocking_op_state;
+  using test::function::BlockingOpState;
+
   FunctionDefLibrary library_graph_def;
   if (use_function_lib) {
-    const string lib = R"proto(
-      signature: {
-        name: "BlockingOpFn" input_arg: { name: "x" type: DT_FLOAT }
-        output_arg: { name: "y" type: DT_FLOAT }}
-      node_def: { name: "y" op: "BlockingOp" input: "x" }
-      ret: { key: "y" value: "y:y:0" } )proto";
-    CHECK(protobuf::TextFormat::ParseFromString(
-        lib, library_graph_def.add_function()));
+    *library_graph_def.add_function() = test::function::BlockingOpFn();
   }

   FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/graph/gradients.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/optimizer_cse.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/macros.h"

@@ -141,6 +142,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
                              int graph_def_version,
                              const FunctionLibraryDefinition* lib_def,
+                             thread::ThreadPool* default_thread_pool,
                              const OptimizerOptions& optimizer_options,
                              CustomKernelCreator custom_kernel_creator,
                              ProcessFunctionLibraryRuntime* parent);
@@ -194,6 +196,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   const FunctionLibraryDefinition* const base_lib_def_;
   GraphOptimizer optimizer_;
   const CustomKernelCreator custom_kernel_creator_;
+  Executor::Args::Runner default_runner_;
   const string device_name_;

   std::function<Status(const string&, const OpDef**)> get_func_sig_;
@@ -243,6 +246,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
     const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
     const FunctionLibraryDefinition* lib_def,
+    thread::ThreadPool* default_thread_pool,
     const OptimizerOptions& optimizer_options,
     CustomKernelCreator custom_kernel_creator,
     ProcessFunctionLibraryRuntime* parent)
@@ -253,6 +257,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
       base_lib_def_(lib_def),
       optimizer_(optimizer_options),
       custom_kernel_creator_(std::move(custom_kernel_creator)),
+      default_runner_(nullptr),
       device_name_(device_ == nullptr
                        ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                        : device_->name()),
@@ -264,6 +269,18 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
   create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
     return CreateKernel(ndef, kernel);
   };
+  thread::ThreadPool* pool = nullptr;
+  if (device_ != nullptr) {
+    pool = device_->tensorflow_device_thread_pool();
+  }
+  if (pool == nullptr) {
+    pool = default_thread_pool;
+  }
+  if (pool != nullptr) {
+    default_runner_ = [pool](Executor::Args::Closure c) {
+      pool->Schedule(std::move(c));
+    };
+  }
 }

 FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
@@ -768,6 +785,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
     return;
   }

+  if (run_opts.runner == nullptr) {
+    run_opts.runner = &default_runner_;
+  }
   DCHECK(run_opts.runner != nullptr);

   Executor::Args* exec_args = new Executor::Args;
@@ -854,6 +874,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
     done(s);
     return;
   }
+  if (run_opts.runner == nullptr) {
+    run_opts.runner = &default_runner_;
+  }
   DCHECK(run_opts.runner != nullptr);

   Executor::Args* exec_args = new Executor::Args;
@@ -942,21 +965,21 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, Device* device,
     int graph_def_version, const FunctionLibraryDefinition* lib_def,
-    const OptimizerOptions& optimizer_options,
+    thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
     CustomKernelCreator custom_kernel_creator,
     ProcessFunctionLibraryRuntime* parent) {
   return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
-      device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
-      std::move(custom_kernel_creator), parent));
+      device_mgr, env, device, graph_def_version, lib_def, thread_pool,
+      optimizer_options, std::move(custom_kernel_creator), parent));
 }

 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, Device* device,
     int graph_def_version, const FunctionLibraryDefinition* lib_def,
-    const OptimizerOptions& optimizer_options,
+    thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
     ProcessFunctionLibraryRuntime* parent) {
   return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
-                                   lib_def, optimizer_options,
+                                   lib_def, thread_pool, optimizer_options,
                                    GetCustomCreatorSingleton()->Get(), parent);
 }

@@ -55,7 +55,7 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, Device* device,
     int graph_def_version, const FunctionLibraryDefinition* lib_def,
-    const OptimizerOptions& optimizer_options,
+    thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
     CustomKernelCreator custom_kernel_creator,
     ProcessFunctionLibraryRuntime* parent);

@@ -65,7 +65,7 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, Device* device,
     int graph_def_version, const FunctionLibraryDefinition* lib_def,
-    const OptimizerOptions& optimizer_options,
+    thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
    ProcessFunctionLibraryRuntime* parent);

 // FunctionLibraryRuntime::GetFunctionBody returns a description of an
@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
@@ -135,7 +136,8 @@ TEST_F(FunctionTest, WXPlusB) {

 class FunctionLibraryRuntimeTest : public ::testing::Test {
  protected:
-  void Init(const std::vector<FunctionDef>& flib) {
+  void Init(const std::vector<FunctionDef>& flib,
+            thread::ThreadPool* default_thread_pool = nullptr) {
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", 3});
@@ -149,7 +151,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
     device_mgr_.reset(new DeviceMgr(devices_));
     pflr_.reset(new ProcessFunctionLibraryRuntime(
         device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
-        opts, nullptr /* cluster_flr */));
+        opts, default_thread_pool, nullptr /* cluster_flr */));
     flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
     flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
     flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
@@ -158,16 +160,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {

   Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
              FunctionLibraryRuntime::Options opts,
-             const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+             const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+             bool add_runner = true) {
     std::atomic<int32> call_count(0);
     std::function<void(std::function<void()>)> runner =
         [&call_count](std::function<void()> fn) {
          ++call_count;
          test::function::FunctionTestSchedClosure(fn);
        };

+    if (add_runner) {
+      opts.runner = &runner;
+    } else {
+      opts.runner = nullptr;
+    }
     Notification done;
-    opts.runner = &runner;
     std::vector<Tensor> out;
     Status status;
     flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
@@ -183,7 +189,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
       *rets[i] = out[i];
     }

-    EXPECT_GE(call_count, 1);  // Test runner is used.
+    if (add_runner) {
+      EXPECT_GE(call_count, 1);  // Test runner is used.
+    }

     return Status::OK();
   }
@@ -204,24 +212,25 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
   Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
                            test::function::Attrs attrs,
                            const std::vector<Tensor>& args,
-                           std::vector<Tensor*> rets) {
+                           std::vector<Tensor*> rets, bool add_runner = true) {
     return InstantiateAndRun(flr, name, attrs,
                              FunctionLibraryRuntime::InstantiateOptions(), args,
-                             std::move(rets));
+                             std::move(rets), add_runner);
   }

   Status InstantiateAndRun(
       FunctionLibraryRuntime* flr, const string& name,
       test::function::Attrs attrs,
       const FunctionLibraryRuntime::InstantiateOptions& options,
-      const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+      const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+      bool add_runner = true) {
     FunctionLibraryRuntime::Handle handle;
     Status status = flr->Instantiate(name, attrs, options, &handle);
     if (!status.ok()) {
       return status;
     }
     FunctionLibraryRuntime::Options opts;
-    status = Run(flr, handle, opts, args, rets);
+    status = Run(flr, handle, opts, args, rets, add_runner);
     if (!status.ok()) return status;

     // Release the handle and try running again. It should not succeed.
@@ -237,16 +246,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
   }

   Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
-             FunctionLibraryRuntime::Options opts, CallFrameInterface* frame) {
+             FunctionLibraryRuntime::Options opts, CallFrameInterface* frame,
+             bool add_runner = true) {
     std::atomic<int32> call_count(0);
     std::function<void(std::function<void()>)> runner =
         [&call_count](std::function<void()> fn) {
          ++call_count;
          test::function::FunctionTestSchedClosure(fn);
        };

+    if (add_runner) {
+      opts.runner = &runner;
+    } else {
+      opts.runner = nullptr;
+    }
     Notification done;
-    opts.runner = &runner;
     std::vector<Tensor> out;
     Status status;
     flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
@@ -258,7 +271,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
       return status;
     }

-    EXPECT_GE(call_count, 1);  // Test runner is used.
+    if (add_runner) {
+      EXPECT_GE(call_count, 1);  // Test runner is used.
+    }

     return Status::OK();
   }
@@ -447,7 +462,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
   {
     // Simple case: instantiating with no state_handle.
     for (int32 expected : {6, 4}) {
-      TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}));
+      TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}, true));
       test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
     }
   }
@@ -460,7 +475,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
        Instantiate(flr0_, "RandomUniformWrapper", {}, &handle_non_isolated));
     EXPECT_EQ(handle, handle_non_isolated);
     for (int32 expected : {0, 1}) {
-      TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}));
+      TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}, true));
       test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
     }
   }
@@ -475,7 +490,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
                            &handle_isolated));
     EXPECT_NE(handle, handle_isolated);
     for (int32 expected : {6, 4, 0, 1}) {
-      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
       test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
     }
   }
@@ -490,7 +505,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
                            &handle_isolated));
     EXPECT_NE(handle, handle_isolated);
     for (int32 expected : {6, 4, 0, 1}) {
-      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
       test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
     }
   }
@@ -507,7 +522,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
                            &handle_isolated));
     EXPECT_NE(handle, handle_isolated);
     for (int32 expected : {6, 4, 0, 1}) {
-      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
       test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
     }
     TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated));
@@ -515,6 +530,59 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
   }
 }

+TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) {
+  using test::function::blocking_op_state;
+  using test::function::BlockingOpState;
+
+  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "FLRTest", 1);
+  Init({test::function::BlockingOpFn(), test::function::XTimesTwo()}, tp);
+
+  auto x = test::AsScalar<float>(1.3);
+  Tensor y;
+  blocking_op_state = new BlockingOpState();
+
+  thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);
+  bool finished_running = false;
+  tp1->Schedule([&x, &y, &finished_running, this]() {
+    TF_CHECK_OK(InstantiateAndRun(flr0_, "BlockingOpFn", {}, {x}, {&y},
+                                  false /* add_runner */));
+    finished_running = true;
+  });
+
+  // InstantiateAndRun shouldn't finish because BlockingOpFn should be blocked.
+  EXPECT_FALSE(finished_running);
+
+  FunctionLibraryRuntime::Handle h;
+  TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, &h));
+
+  auto x1 = test::AsTensor<float>({1, 2, 3, 4});
+  Tensor y1;
+  std::atomic<int32> num_done(0);
+  FunctionLibraryRuntime::Options opts;
+  for (int i = 0; i < 4; ++i) {
+    tp1->Schedule([&h, &x1, &y1, &opts, &num_done, this]() {
+      TF_CHECK_OK(Run(flr0_, h, opts, {x1}, {&y1}, false /* add_runner */));
+      num_done.fetch_add(1);
+    });
+  }
+  // All the 4 Run() calls should be blocked because the runner is occupied.
+  EXPECT_EQ(0, num_done.load());
+
+  blocking_op_state->AwaitState(1);
+  blocking_op_state->MoveToState(1, 2);
+  // Now the runner should be unblocked and all the other Run() calls should
+  // proceed.
+  blocking_op_state->AwaitState(3);
+  blocking_op_state->MoveToState(3, 0);
+  delete tp1;
+  EXPECT_TRUE(finished_running);
+  EXPECT_EQ(4, num_done.load());
+
+  delete blocking_op_state;
+  blocking_op_state = nullptr;
+  delete tp;
+}
+
 TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
   Init({test::function::XTimesTwo(), test::function::XTimesFour(),
         test::function::XTimes16()});
@@ -787,7 +855,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
     Scope s = Scope::NewRootScope();
     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
     auto x4_x2_scale = ops::Const<float>(
-        s.WithOpName("x4/x2/scale/_12__cf__6")
+        s.WithOpName("x4/x2/scale/_12__cf__10")
            .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
        2.0f);
     auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -993,13 +1061,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
     auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
     auto scale = ops::Const(
-        s.WithOpName("scale/_6__cf__11")
+        s.WithOpName("scale/_6__cf__15")
            .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
        2.0f);
     auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
     auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
     auto const0 = ops::Const(
-        s.WithOpName("Func/_1/sy/_5__cf__10")
+        s.WithOpName("Func/_1/sy/_5__cf__14")
            .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
        0, {0});
     auto func1_rx = ops::internal::BroadcastGradientArgs(
@@ -1247,14 +1315,14 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
   opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
   opts.source_device = "/device:CPU:1";
   // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
-  TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
+  TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
   test::ExpectTensorEqual<string>(
       y,
       test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"},
                              TensorShape({})));
   opts.remote_execution = true;
   opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
-  TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
+  TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}, true));
   test::ExpectTensorEqual<string>(
       y,
       test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"},
@@ -58,6 +58,59 @@ FunctionDef FindDevice() {
       {{{"device_name"}, "FindDeviceOp", {}, {}}});
 }

+void BlockingOpState::AwaitState(int awaiting_state) {
+  mutex_lock ml(mu_);
+  while (state_ != awaiting_state) {
+    cv_.wait(ml);
+  }
+}
+
+void BlockingOpState::MoveToState(int expected_current, int next) {
+  mutex_lock ml(mu_);
+  CHECK_EQ(expected_current, state_);
+  state_ = next;
+  cv_.notify_all();
+}
+
+BlockingOpState* blocking_op_state = nullptr;
+
+// BlockingOp blocks on the global <blocking_op_state's> state,
+// and also updates it when it is unblocked and finishing computation.
+class BlockingOp : public OpKernel {
+ public:
+  explicit BlockingOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) override {
+    blocking_op_state->MoveToState(0, 1);
+    blocking_op_state->AwaitState(2);
+    blocking_op_state->MoveToState(2, 3);
+
+    Tensor* out = nullptr;
+    const Tensor& in = ctx->input(0);
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
+    out->flat<float>() = in.flat<float>();
+  }
+};
+REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
+REGISTER_OP("BlockingOp")
+    .Input("x: float")
+    .Output("y: float")
+    .Doc("")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+FunctionDef BlockingOpFn() {
+  return FDH::Define(
+      // Name
+      "BlockingOpFn",
+      // Args
+      {"x: float"},
+      // Return values
+      {"y: float"},
+      // Attr def
+      {},
+      // Nodes
+      {{{"y"}, "BlockingOp", {"x"}, {}}});
+}
+
 // TODO(phawkins): replace with C++ API for calling functions, when that exists.
 Output Call(Scope* scope, const string& op_name, const string& fn_name,
             gtl::ArraySlice<Input> inputs) {
@@ -25,6 +25,22 @@ namespace function {
 // {} -> y:DT_STRING (device where this op runs).
 FunctionDef FindDevice();

+class BlockingOpState {
+ public:
+  void AwaitState(int awaiting_state);
+
+  void MoveToState(int expected_current, int next);
+
+ private:
+  mutex mu_;
+  condition_variable cv_;
+  int state_ = 0;
+};
+
+extern BlockingOpState* blocking_op_state;
+
+FunctionDef BlockingOpFn();
+
 // Adds a function call to the given scope and returns the output for the node.
 // TODO(phawkins): replace with C++ API for calling functions, when that exists.
 Output Call(Scope* scope, const string& op_name, const string& fn_name,
@@ -42,21 +42,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, int graph_def_version,
     const FunctionLibraryDefinition* lib_def,
     const OptimizerOptions& optimizer_options,
+    thread::ThreadPool* default_thread_pool,
     DistributedFunctionLibraryRuntime* parent)
     : device_mgr_(device_mgr),
       lib_def_(lib_def),
+      default_thread_pool_(default_thread_pool),
       next_handle_(0),
       parent_(parent) {
   if (device_mgr == nullptr) {
-    flr_map_[nullptr] =
-        NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
-                                  lib_def, optimizer_options, this);
+    flr_map_[nullptr] = NewFunctionLibraryRuntime(
+        nullptr, env, nullptr, graph_def_version, lib_def, default_thread_pool,
+        optimizer_options, this);
     return;
   }
   for (Device* d : device_mgr->ListDevices()) {
-    flr_map_[d] =
-        NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version,
-                                  lib_def, optimizer_options, this);
+    flr_map_[d] = NewFunctionLibraryRuntime(
+        device_mgr, env, d, graph_def_version, lib_def, default_thread_pool,
+        optimizer_options, this);
   }
 }

@@ -65,21 +67,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
     const FunctionLibraryDefinition* lib_def,
     const OptimizerOptions& optimizer_options,
     CustomKernelCreator custom_kernel_creator,
+    thread::ThreadPool* default_thread_pool,
     DistributedFunctionLibraryRuntime* parent)
     : device_mgr_(device_mgr),
       lib_def_(lib_def),
+      default_thread_pool_(default_thread_pool),
       next_handle_(0),
       parent_(parent) {
   if (device_mgr == nullptr) {
     flr_map_[nullptr] = NewFunctionLibraryRuntime(
-        nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
-        std::move(custom_kernel_creator), this);
+        nullptr, env, nullptr, graph_def_version, lib_def, default_thread_pool,
+        optimizer_options, std::move(custom_kernel_creator), this);
     return;
   }
   for (Device* d : device_mgr->ListDevices()) {
     flr_map_[d] = NewFunctionLibraryRuntime(
-        device_mgr, env, d, graph_def_version, lib_def, optimizer_options,
-        custom_kernel_creator, this);
+        device_mgr, env, d, graph_def_version, lib_def, default_thread_pool,
+        optimizer_options, custom_kernel_creator, this);
   }
 }

@@ -370,7 +374,8 @@ Status ProcessFunctionLibraryRuntime::Clone(
   out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
   out_pflr->reset(new ProcessFunctionLibraryRuntime(
       device_mgr_, env, graph_def_version, out_lib_def->get(),
-      optimizer_options, std::move(custom_kernel_creator), parent_));
+      optimizer_options, std::move(custom_kernel_creator), default_thread_pool_,
+      parent_));
   return Status::OK();
 }

@@ -33,6 +33,7 @@ class ProcessFunctionLibraryRuntime {
       const DeviceMgr* device_mgr, Env* env, int graph_def_version,
       const FunctionLibraryDefinition* lib_def,
       const OptimizerOptions& optimizer_options,
+      thread::ThreadPool* thread_pool = nullptr,
       DistributedFunctionLibraryRuntime* parent = nullptr);

   // With `custom_kernel_creator`.
@@ -41,6 +42,7 @@ class ProcessFunctionLibraryRuntime {
       const FunctionLibraryDefinition* lib_def,
       const OptimizerOptions& optimizer_options,
       CustomKernelCreator custom_kernel_creator,
+      thread::ThreadPool* thread_pool,
       DistributedFunctionLibraryRuntime* parent);

   // Sends `tensors_to_send` from `source_device` to `target_device` using
@@ -174,6 +176,7 @@ class ProcessFunctionLibraryRuntime {

   const DeviceMgr* const device_mgr_;
   const FunctionLibraryDefinition* lib_def_;
+  thread::ThreadPool* default_thread_pool_;
   // Holds all the function invocations here.
   std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
       GUARDED_BY(mu_);
@@ -71,7 +71,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
     cluster_flr_.reset(new TestClusterFLR());
     proc_flr_.reset(new ProcessFunctionLibraryRuntime(
         device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
-        opts, cluster_flr_.get()));
+        opts, nullptr, cluster_flr_.get()));
     rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
   }

@@ -153,7 +153,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
   std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr(
       new ProcessFunctionLibraryRuntime(
           nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION,
-          lib_def.get(), opts, nullptr /* cluster_flr */));
+          lib_def.get(), opts, nullptr, nullptr /* cluster_flr */));
   FunctionLibraryRuntime* flr =
       proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
   EXPECT_NE(flr, nullptr);
@@ -134,7 +134,8 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,

   item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
       device_mgr_, worker_env_->env, gdef.versions().producer(),
-      item->lib_def.get(), graph_options.optimizer_options(), cluster_flr));
+      item->lib_def.get(), graph_options.optimizer_options(),
+      worker_env_->compute_pool, cluster_flr));

   // Constructs the graph out of "gdef".
   Graph graph(OpRegistry::Global());
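For reference, a condensed usage fragment following the pattern the new DefaultThreadpool test exercises. It assumes surrounding setup that provides device_mgr, lib_def, and optimizer_opts (as a test fixture would); the pool name and size are illustrative.

// Build the process-level runtime with an explicit default thread pool.
thread::ThreadPool pool(Env::Default(), "default_pool", 1);
ProcessFunctionLibraryRuntime pflr(
    device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
    optimizer_opts, &pool, nullptr /* cluster_flr */);
FunctionLibraryRuntime* flr =
    pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
// Runs issued through flr with Options::runner left unset now schedule their
// kernels on `pool` (or on the device's own thread pool when it has one).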