Change NumInterOpThreadsFromSessionOptions() to return the value defined in the TF_NUM_INTEROP_THREADS environment variable when options.config.inter_op_parallelism_threads is negative, instead of crashing the process.

Allow ConfigProto.inter_op_parallelism_threads to be negative; a negative value means all ops run in the caller's thread.

PiperOrigin-RevId: 246897227
This commit is contained in:
A. Unique TensorFlower 2019-05-06 14:26:16 -07:00 committed by TensorFlower Gardener
parent ae53fb3375
commit 3b66ec3a38
6 changed files with 74 additions and 29 deletions

View File

@ -266,6 +266,14 @@ DirectSession::DirectSession(const SessionOptions& options,
true /* owned */);
} else {
thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
// Run in the caller's thread if config.inter_op_parallelism_threads is
// negative, or if it is unspecified (0) and the environment variable
// TF_NUM_INTEROP_THREADS is negative.
static const int env_num_threads = NumInterOpThreadsFromEnvironment();
if (options_.config.inter_op_parallelism_threads() < 0 ||
(options_.config.inter_op_parallelism_threads() == 0 &&
env_num_threads < 0)) {
run_in_caller_thread_ = true;
}
}
// The default value of sync_on_finish will be flipped soon and this
// environment variable will be removed as well.
@ -566,6 +574,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
run_options.inter_op_thread_pool() >= 0
? thread_pools_[run_options.inter_op_thread_pool()].first
: nullptr;
if (run_in_caller_thread_) {
pool = nullptr;
}
if (pool == nullptr) {
// We allow using the caller thread only when having a single executor

View File

@ -401,6 +401,18 @@ class DirectSession : public Session {
mutex collective_graph_key_lock_;
int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
// Run in the caller's thread if RunOptions.inter_op_thread_pool is negative
// or if all of the following conditions are met:
// 1. This session doesn't own any thread pool.
// 2. RunOptions.inter_op_thread_pool is unspecified or 0.
// 3. This session has a single executor.
// 4. config.inter_op_parallelism_threads is set to a negative value, either
// explicitly or through the environment variable TF_NUM_INTEROP_THREADS.
// 5. RunOptions.experimental.use_run_handler_pool is unspecified or false.
// Otherwise run in global thread pool, session owned thread pool or handler
// pool according to other specifications of RunOptions and ConfigProto.
bool run_in_caller_thread_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
// EXPERIMENTAL: debugger (tfdbg) related

View File

@ -109,7 +109,7 @@ class DirectSessionMinusAXTest : public ::testing::Test {
z_ = z->name();
z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
test::graph::ToGraphDef(&graph, &def_);
graph.ToGraphDef(&def_);
}
string a_;
@ -540,7 +540,7 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
Node* y = test::graph::Matmul(&graph, a, x, false, false);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2");
test::graph::ToGraphDef(&graph, &def);
graph.ToGraphDef(&def);
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
@ -552,7 +552,7 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
// Fix placement and run again
def.Clear();
y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
test::graph::ToGraphDef(&graph, &def);
graph.ToGraphDef(&def);
session.reset(NewSession(options));
TF_ASSERT_OK(session->Create(def));
std::vector<Tensor> outputs;
@ -671,7 +671,7 @@ TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
Node* init = test::graph::Assign(&g, var, twenty_node);
init->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -706,7 +706,7 @@ TEST(DirectSessionTest, MultipleFeedTest) {
Node* second_const = test::graph::Constant(&g, second_value);
Node* second_identity = test::graph::Identity(&g, second_const);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -779,7 +779,7 @@ TEST(DirectSessionTest, MultipleFeedTest_Callable) {
Node* second_const = test::graph::Constant(&g, second_value);
Node* second_identity = test::graph::Identity(&g, second_const);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -865,7 +865,7 @@ TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
Node* y = test::graph::Add(&graph, left, right);
GraphDef def;
test::graph::ToGraphDef(&graph, &def);
graph.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -904,7 +904,7 @@ TEST(DirectSessionTest, FetchMultipleTimes) {
Node* seven_node = test::graph::Constant(&g, seven_tensor);
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -941,7 +941,7 @@ TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) {
Node* second_const = test::graph::Constant(&g, second_value);
Node* second_identity = test::graph::Identity(&g, second_const);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -1031,7 +1031,7 @@ TEST(DirectSessionTest, SessionSyncRun) {
Node* x = test::graph::Constant(&g, vx);
Node* y = test::graph::Unary(&g, "ThreadID", x);
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto sess = CreateSession();
TF_ASSERT_OK(sess->Create(def));
std::vector<Tensor> outputs;
@ -1044,6 +1044,27 @@ TEST(DirectSessionTest, SessionSyncRun) {
static_cast<int64>(outputs[0].scalar<int64>()()));
}
// Checks that a session created with a negative
// inter_op_parallelism_threads value runs the graph in the caller's thread:
// the value fetched from the "ThreadID" op is expected to equal the hash of
// the id of the thread that calls Run().
TEST(DirectSessionTest, SyncSession) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  // A negative value requests that all ops run in the caller's thread.
  options.config.set_inter_op_parallelism_threads(-1);
  std::unique_ptr<Session> sess(NewSession(options));
  TF_ASSERT_OK(sess->Create(def));

  std::vector<Tensor> outputs;
  RunOptions run_opts;
  // Assert on the returned status so that a failed Run() is reported as a test
  // failure instead of being silently dropped.
  TF_ASSERT_OK(
      sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr));
  // Exactly one tensor was fetched; guard the outputs[0] access below.
  ASSERT_EQ(outputs.size(), 1u);

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64>(hasher(std::this_thread::get_id())),
            static_cast<int64>(outputs[0].scalar<int64>()()));
}
REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
Darth promises one return value.
@ -1066,7 +1087,7 @@ TEST(DirectSessionTest, DarthKernel) {
Node* x = test::graph::Constant(&g, vx);
Node* y = test::graph::Unary(&g, "Darth", x);
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto sess = CreateSession();
TF_ASSERT_OK(sess->Create(def));
std::vector<Tensor> outputs;
@ -1084,7 +1105,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
Node* y = test::graph::Unary(&g, "Darth", x);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
// By default, we place the entire graph, so we should fail the
// call to Create.
@ -1102,7 +1123,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
Node* y = test::graph::Unary(&g, "Darth", x);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
SessionOptions options;
// Set the option to place pruned graphs, we should expect this
@ -1133,7 +1154,7 @@ TEST(DirectSessionTest, PartialRunTest) {
Node* third = test::graph::Add(&g, first_identity, second_identity);
Node* third_identity = test::graph::Identity(&g, third);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -1189,7 +1210,7 @@ TEST(DirectSessionTest, PartialRunMissingFeed) {
Node* third = test::graph::Add(&g, first_identity, second_identity);
Node* third_identity = test::graph::Identity(&g, third);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -1222,7 +1243,7 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
Node* switch_node = test::graph::Switch(&g, bool_const, bool_const);
Node* fourth_identity = test::graph::Identity(&g, switch_node, 1);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -1271,7 +1292,7 @@ TEST(DirectSessionTest, RunHandleTest) {
Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -1324,7 +1345,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) {
Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -1526,7 +1547,7 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
y = test::graph::Unary(&g, "BlockingOp", x);
}
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
*def.mutable_library() = library_graph_def;
// Create session with two inter-op thread pools.
@ -1677,7 +1698,7 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
t.scalar<float>()() = {1.2f};
Node* x = test::graph::Constant(&g, t);
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
SessionOptions options;
options.config.mutable_graph_options()
@ -1736,7 +1757,7 @@ TEST(DirectSessionTest, TestDirectSessionRunClose) {
Node* var = test::graph::Var(&g, DT_FLOAT, {});
Node* var_assign = test::graph::Assign(&g, var, var_val);
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
@ -1790,7 +1811,7 @@ TEST(DirectSessionTest, TestDirectSessionPRunClose) {
Node* third = test::graph::Add(&g, first_identity, second_identity);
Node* third_identity = test::graph::Identity(&g, third);
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
@ -1829,7 +1850,7 @@ TEST(DirectSessionTest, TestDirectSessionReset) {
Node* var = test::graph::Var(&g, DT_FLOAT, {});
Node* var_assign = test::graph::Assign(&g, var, var_val);
GraphDef def;
test::graph::ToGraphDef(&g, &def);
g.ToGraphDef(&def);
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;

View File

@ -99,7 +99,7 @@ int32 NumIntraOpThreadsFromEnvironment() {
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
const int32 inter_op = options.config.inter_op_parallelism_threads();
if (inter_op != 0) return inter_op;
if (inter_op > 0) return inter_op;
#ifdef INTEL_MKL
if (!DisableMKL()) {
// MKL library executes ops in parallel using OMP threads

View File

@ -37,11 +37,11 @@ int32 NumInterOpThreadsFromEnvironment();
int32 NumIntraOpThreadsFromEnvironment();
// Returns the number of inter op threads specified in `options` or a default.
// If no value is specified in the provided options, then the function returns
// the value defined in the TF_NUM_INTEROP_THREADS environment variable.
// If neither a value is specified in the options or in the environment,
// this function will return a reasonable default value based on the number
// of schedulable CPUs, and any MKL and OpenMP configurations.
// If no value or a negative value is specified in the provided options, then
// the function returns the value defined in the TF_NUM_INTEROP_THREADS
// environment variable. If no value is specified in either the options or
// the environment, this function will return a reasonable default value based
// on the number of schedulable CPUs, and any MKL and OpenMP configurations.
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options);
// Creates a thread pool with number of inter op threads.

View File

@ -347,6 +347,7 @@ message ConfigProto {
// inter_op_parallelism_threads available in each process.
//
// 0 means the system picks an appropriate number.
// Negative means all operations are performed in the caller's thread.
//
// Note that the first Session created in the process sets the
// number of threads for all future sessions unless use_per_session_threads is