Change NumInterOpThreadsFromSessionOptions() to return the value defined in the TF_NUM_INTEROP_THREADS environment variable when options.config.inter_op_parallelism_threads is negative, instead of crashing the process.
Allow ConfigProto.inter_op_parallelism_threads to be negative; a negative value means all ops run in the caller's thread. PiperOrigin-RevId: 246897227
This commit is contained in:
parent
ae53fb3375
commit
3b66ec3a38
@ -266,6 +266,14 @@ DirectSession::DirectSession(const SessionOptions& options,
|
||||
true /* owned */);
|
||||
} else {
|
||||
thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
|
||||
// Run in the caller's thread if config.inter_op_parallelism_threads is
|
||||
// negative, or if it is unspecified (0) and TF_NUM_INTEROP_THREADS is negative.
|
||||
static const int env_num_threads = NumInterOpThreadsFromEnvironment();
|
||||
if (options_.config.inter_op_parallelism_threads() < 0 ||
|
||||
(options_.config.inter_op_parallelism_threads() == 0 &&
|
||||
env_num_threads < 0)) {
|
||||
run_in_caller_thread_ = true;
|
||||
}
|
||||
}
|
||||
// The default value of sync_on_finish will be flipped soon and this
|
||||
// environment variable will be removed as well.
|
||||
@ -566,6 +574,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
|
||||
run_options.inter_op_thread_pool() >= 0
|
||||
? thread_pools_[run_options.inter_op_thread_pool()].first
|
||||
: nullptr;
|
||||
if (run_in_caller_thread_) {
|
||||
pool = nullptr;
|
||||
}
|
||||
|
||||
if (pool == nullptr) {
|
||||
// We allow using the caller thread only when having a single executor
|
||||
|
@ -401,6 +401,18 @@ class DirectSession : public Session {
|
||||
mutex collective_graph_key_lock_;
|
||||
int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
|
||||
|
||||
// Run in caller's thread if RunOptions.inter_op_thread_pool is negative or
|
||||
// all of following conditions are met:
|
||||
// 1. This session doesn't own any thread pool.
|
||||
// 2. RunOptions.inter_op_thread_pool is unspecified or 0.
|
||||
// 3. This session has a single executor.
|
||||
// 4. config.inter_op_parallelism_threads is set to a negative value, either
|
||||
// explicitly or through the environment variable TF_NUM_INTEROP_THREADS.
|
||||
// 5. RunOptions.experimental.use_run_handler_pool is unspecified or false.
|
||||
// Otherwise run in global thread pool, session owned thread pool or handler
|
||||
// pool according to other specifications of RunOptions and ConfigProto.
|
||||
bool run_in_caller_thread_ = false;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
|
||||
|
||||
// EXPERIMENTAL: debugger (tfdbg) related
|
||||
|
@ -109,7 +109,7 @@ class DirectSessionMinusAXTest : public ::testing::Test {
|
||||
z_ = z->name();
|
||||
z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
|
||||
|
||||
test::graph::ToGraphDef(&graph, &def_);
|
||||
graph.ToGraphDef(&def_);
|
||||
}
|
||||
|
||||
string a_;
|
||||
@ -540,7 +540,7 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
|
||||
Node* y = test::graph::Matmul(&graph, a, x, false, false);
|
||||
y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2");
|
||||
|
||||
test::graph::ToGraphDef(&graph, &def);
|
||||
graph.ToGraphDef(&def);
|
||||
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
@ -552,7 +552,7 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
|
||||
// Fix placement and run again
|
||||
def.Clear();
|
||||
y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
|
||||
test::graph::ToGraphDef(&graph, &def);
|
||||
graph.ToGraphDef(&def);
|
||||
session.reset(NewSession(options));
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
std::vector<Tensor> outputs;
|
||||
@ -671,7 +671,7 @@ TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
|
||||
Node* init = test::graph::Assign(&g, var, twenty_node);
|
||||
init->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -706,7 +706,7 @@ TEST(DirectSessionTest, MultipleFeedTest) {
|
||||
Node* second_const = test::graph::Constant(&g, second_value);
|
||||
Node* second_identity = test::graph::Identity(&g, second_const);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -779,7 +779,7 @@ TEST(DirectSessionTest, MultipleFeedTest_Callable) {
|
||||
Node* second_const = test::graph::Constant(&g, second_value);
|
||||
Node* second_identity = test::graph::Identity(&g, second_const);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -865,7 +865,7 @@ TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
|
||||
Node* y = test::graph::Add(&graph, left, right);
|
||||
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&graph, &def);
|
||||
graph.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -904,7 +904,7 @@ TEST(DirectSessionTest, FetchMultipleTimes) {
|
||||
Node* seven_node = test::graph::Constant(&g, seven_tensor);
|
||||
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -941,7 +941,7 @@ TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) {
|
||||
Node* second_const = test::graph::Constant(&g, second_value);
|
||||
Node* second_identity = test::graph::Identity(&g, second_const);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -1031,7 +1031,7 @@ TEST(DirectSessionTest, SessionSyncRun) {
|
||||
Node* x = test::graph::Constant(&g, vx);
|
||||
Node* y = test::graph::Unary(&g, "ThreadID", x);
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
auto sess = CreateSession();
|
||||
TF_ASSERT_OK(sess->Create(def));
|
||||
std::vector<Tensor> outputs;
|
||||
@ -1044,6 +1044,27 @@ TEST(DirectSessionTest, SessionSyncRun) {
|
||||
static_cast<int64>(outputs[0].scalar<int64>()()));
|
||||
}
|
||||
|
||||
// Checks that a session created with a negative
// inter_op_parallelism_threads runs its ops in the caller's thread: the value
// fetched from the "ThreadID" node is expected to equal the hash of the
// calling thread's id.
TEST(DirectSessionTest, SyncSession) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  // A negative value requests that all ops run in the caller's thread.
  options.config.set_inter_op_parallelism_threads(-1);
  std::unique_ptr<Session> sess(NewSession(options));
  ASSERT_TRUE(sess != nullptr);
  TF_ASSERT_OK(sess->Create(def));

  std::vector<Tensor> outputs;
  RunOptions run_opts;
  // Abort the test on a failed Run instead of indexing an empty `outputs`.
  TF_ASSERT_OK(
      sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr));
  ASSERT_EQ(1u, outputs.size());

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64>(hasher(std::this_thread::get_id())),
            static_cast<int64>(outputs[0].scalar<int64>()()));
}
|
||||
|
||||
REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
|
||||
Darth promises one return value.
|
||||
|
||||
@ -1066,7 +1087,7 @@ TEST(DirectSessionTest, DarthKernel) {
|
||||
Node* x = test::graph::Constant(&g, vx);
|
||||
Node* y = test::graph::Unary(&g, "Darth", x);
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
auto sess = CreateSession();
|
||||
TF_ASSERT_OK(sess->Create(def));
|
||||
std::vector<Tensor> outputs;
|
||||
@ -1084,7 +1105,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
|
||||
Node* y = test::graph::Unary(&g, "Darth", x);
|
||||
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
// By default, we place the entire graph, so we should fail the
|
||||
// call to Create.
|
||||
@ -1102,7 +1123,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
|
||||
Node* y = test::graph::Unary(&g, "Darth", x);
|
||||
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
SessionOptions options;
|
||||
// Set the option to place pruned graphs, we should expect this
|
||||
@ -1133,7 +1154,7 @@ TEST(DirectSessionTest, PartialRunTest) {
|
||||
Node* third = test::graph::Add(&g, first_identity, second_identity);
|
||||
Node* third_identity = test::graph::Identity(&g, third);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -1189,7 +1210,7 @@ TEST(DirectSessionTest, PartialRunMissingFeed) {
|
||||
Node* third = test::graph::Add(&g, first_identity, second_identity);
|
||||
Node* third_identity = test::graph::Identity(&g, third);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -1222,7 +1243,7 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
|
||||
Node* switch_node = test::graph::Switch(&g, bool_const, bool_const);
|
||||
Node* fourth_identity = test::graph::Identity(&g, switch_node, 1);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -1271,7 +1292,7 @@ TEST(DirectSessionTest, RunHandleTest) {
|
||||
|
||||
Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -1324,7 +1345,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) {
|
||||
|
||||
Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -1526,7 +1547,7 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
|
||||
y = test::graph::Unary(&g, "BlockingOp", x);
|
||||
}
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
*def.mutable_library() = library_graph_def;
|
||||
|
||||
// Create session with two inter-op thread pools.
|
||||
@ -1677,7 +1698,7 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
|
||||
t.scalar<float>()() = {1.2f};
|
||||
Node* x = test::graph::Constant(&g, t);
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
SessionOptions options;
|
||||
options.config.mutable_graph_options()
|
||||
@ -1736,7 +1757,7 @@ TEST(DirectSessionTest, TestDirectSessionRunClose) {
|
||||
Node* var = test::graph::Var(&g, DT_FLOAT, {});
|
||||
Node* var_assign = test::graph::Assign(&g, var, var_val);
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
@ -1790,7 +1811,7 @@ TEST(DirectSessionTest, TestDirectSessionPRunClose) {
|
||||
Node* third = test::graph::Add(&g, first_identity, second_identity);
|
||||
Node* third_identity = test::graph::Identity(&g, third);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -1829,7 +1850,7 @@ TEST(DirectSessionTest, TestDirectSessionReset) {
|
||||
Node* var = test::graph::Var(&g, DT_FLOAT, {});
|
||||
Node* var_assign = test::graph::Assign(&g, var, var_val);
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
g.ToGraphDef(&def);
|
||||
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
|
@ -99,7 +99,7 @@ int32 NumIntraOpThreadsFromEnvironment() {
|
||||
|
||||
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
||||
const int32 inter_op = options.config.inter_op_parallelism_threads();
|
||||
if (inter_op != 0) return inter_op;
|
||||
if (inter_op > 0) return inter_op;
|
||||
#ifdef INTEL_MKL
|
||||
if (!DisableMKL()) {
|
||||
// MKL library executes ops in parallel using OMP threads
|
||||
|
@ -37,11 +37,11 @@ int32 NumInterOpThreadsFromEnvironment();
|
||||
int32 NumIntraOpThreadsFromEnvironment();
|
||||
|
||||
// Returns the number of inter op threads specified in `options` or a default.
|
||||
// If no value is specified in the provided options, then the function returns
|
||||
// the value defined in the TF_NUM_INTEROP_THREADS environment variable.
|
||||
// If neither a value is specified in the options or in the environment,
|
||||
// this function will return a reasonable default value based on the number
|
||||
// of schedulable CPUs, and any MKL and OpenMP configurations.
|
||||
// If no value or a negative value is specified in the provided options, then
|
||||
// the function returns the value defined in the TF_NUM_INTEROP_THREADS
|
||||
// environment variable. If neither a value is specified in the options or in
|
||||
// the environment, this function will return a reasonable default value based
|
||||
// on the number of schedulable CPUs, and any MKL and OpenMP configurations.
|
||||
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options);
|
||||
|
||||
// Creates a thread pool with number of inter op threads.
|
||||
|
@ -347,6 +347,7 @@ message ConfigProto {
|
||||
// inter_op_parallelism_threads available in each process.
|
||||
//
|
||||
// 0 means the system picks an appropriate number.
|
||||
// Negative means all operations are performed in caller's thread.
|
||||
//
|
||||
// Note that the first Session created in the process sets the
|
||||
// number of threads for all future sessions unless use_per_session_threads is
|
||||
|
Loading…
Reference in New Issue
Block a user