Merge pull request from trentlo:no_mem_opt_if_jit_on

PiperOrigin-RevId: 270071858
This commit is contained in:
TensorFlower Gardener 2019-09-19 10:46:18 -07:00
commit caa5a8ea6b
7 changed files with 171 additions and 12 deletions

View File

@ -629,6 +629,7 @@ cc_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
@ -637,7 +638,6 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/xla_config_registry.h"
namespace tensorflow {
@ -218,19 +219,12 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
namespace {
struct XlaGlobalJitLevel {
OptimizerOptions::GlobalJitLevel single_gpu;
OptimizerOptions::GlobalJitLevel general;
};
typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;
XlaGlobalJitLevel GetXlaGlobalJitLevel(
const GraphOptimizationPassOptions& options) {
const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
XlaGlobalJitLevel result;
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
options.session_options->config.graph_options()
.optimizer_options()
.global_jit_level();
if (jit_level_in_session_opts == OptimizerOptions::DEFAULT) {
// To set compilation to be on by default, change the following line.
result.single_gpu = result.general = OptimizerOptions::OFF;
@ -289,7 +283,12 @@ bool IsSingleGpuGraph(const Graph& g) {
OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
const GraphOptimizationPassOptions& options) {
XlaGlobalJitLevel xla_global_jit_level = GetXlaGlobalJitLevel(options);
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
options.session_options->config.graph_options()
.optimizer_options()
.global_jit_level();
XlaGlobalJitLevel xla_global_jit_level =
GetXlaGlobalJitLevel(jit_level_in_session_opts);
if (xla_global_jit_level.single_gpu == xla_global_jit_level.general) {
VLOG(4) << "GetGlobalJitLevelForGraph returning "
<< xla_global_jit_level.single_gpu;
@ -386,4 +385,8 @@ XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
return result;
}
// Register a callback for querying XlaGlobalJitLevel.
REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);
} // namespace tensorflow

View File

@ -2810,6 +2810,7 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"util/presized_cuckoo_map.h",
"util/tensor_slice_set.h",
"util/tensor_slice_util.h",
"util/xla_config_registry.h",
]
tf_cuda_library(

View File

@ -605,6 +605,7 @@ cc_library(
":shape_optimizer",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@ -617,6 +618,7 @@ cc_library(
"//tensorflow/core/grappler/utils:tpu",
"//tensorflow/core/grappler/verifiers:graph_verifier",
"//tensorflow/core/grappler/verifiers:structure_verifier",
"//tensorflow/core/lib/gtl:map_util",
"@com_google_absl//absl/strings",
],
)

View File

@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/xla_config_registry.h"
namespace tensorflow {
namespace grappler {
@ -126,6 +127,38 @@ bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
return false;
}
bool IsXlaGlobalJitOn(
const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
// Return true only if XLA JIT is ON for both single-gpu and multi-gpu
// graphs. This is a conservative approach that turns off the memory optimizer
// when we are sure that all graphs will be processed by XLA JIT.
bool is_on = (xla_global_jit_level.single_gpu == OptimizerOptions::ON_1 ||
xla_global_jit_level.single_gpu == OptimizerOptions::ON_2) &&
(xla_global_jit_level.general == OptimizerOptions::ON_1 ||
xla_global_jit_level.general == OptimizerOptions::ON_2);
return is_on;
}
// A helper function to decide whether to enable the memory optimizer.
bool MemoryOptimizerEnabled(
RewriterConfig::MemOptType mem_opt_type,
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
// Disable the default memory optimizer when XLA JIT is ON as it hurts the
// XLA JIT performance. The (current) XLA clustering can result in loss of
// concurrency between kernel compute and memory copies. As such, it usually
// loses the concurrency needed to hide the latencies of the inserted swap-ins
// and swap-outs and incurs great performance overhead. Remove this check when
// the XLA JIT can better deal with the concurrency.
if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
IsXlaGlobalJitOn(jit_level_in_session_opts)) {
return false;
}
return mem_opt_type != RewriterConfig::NO_MEM_OPT;
}
} // namespace
#define MK_OPT(NAME, VALUE) \
@ -216,7 +249,9 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(
MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
}
if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
auto global_jit_level =
config_proto_.graph_options().optimizer_options().global_jit_level();
if (MemoryOptimizerEnabled(cfg_.memory_optimization(), global_jit_level)) {
if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
optimizers->push_back(
// Use the default target node name prefix "gradients/"

View File

@ -0,0 +1,55 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/xla_config_registry.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace xla_config_registry {
namespace {
struct GlobalJitLevelState {
mutex mu;
GlobalJitLevelGetterTy getter GUARDED_BY(mu);
};
GlobalJitLevelState* GetSingletonState() {
static GlobalJitLevelState* state = new GlobalJitLevelState;
return state;
}
} // namespace
void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter) {
GlobalJitLevelState* state = GetSingletonState();
mutex_lock l(state->mu);
CHECK(!state->getter);
state->getter = std::move(getter);
}
XlaGlobalJitLevel GetGlobalJitLevel(
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
GlobalJitLevelState* state = GetSingletonState();
mutex_lock l(state->mu);
if (!state->getter) {
return {jit_level_in_session_opts, jit_level_in_session_opts};
}
return state->getter(jit_level_in_session_opts);
}
} // namespace xla_config_registry
} // namespace tensorflow

View File

@ -0,0 +1,63 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_
#define TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_
#include <functional>
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
namespace xla_config_registry {
// XlaGlobalJitLevel is used by XLA to expose its JIT level for processing
// single gpu and general (multi-gpu) graphs.
struct XlaGlobalJitLevel {
OptimizerOptions::GlobalJitLevel single_gpu;
OptimizerOptions::GlobalJitLevel general;
};
// Input is the jit_level in session config, and return value is the jit_level
// from XLA, reflecting the effect of the environment variable flags.
typedef std::function<XlaGlobalJitLevel(
const OptimizerOptions::GlobalJitLevel&)>
GlobalJitLevelGetterTy;
void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter);
XlaGlobalJitLevel GetGlobalJitLevel(
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts);
#define REGISTER_XLA_CONFIG_GETTER(getter) \
REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(__COUNTER__, getter)
#define REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(ctr, getter) \
REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter)
#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \
static bool xla_config_registry_registration_##ctr = \
(::tensorflow::xla_config_registry::RegisterGlobalJitLevelGetter( \
getter), \
true)
} // namespace xla_config_registry
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_