Merge pull request #32245 from trentlo:no_mem_opt_if_jit_on
PiperOrigin-RevId: 270071858
This commit is contained in:
commit
caa5a8ea6b
tensorflow
compiler/jit
core
@ -629,6 +629,7 @@ cc_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_bounds_check",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
@ -637,7 +638,6 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#include "tensorflow/core/util/xla_config_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -218,19 +219,12 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
|
||||
// Clears the kXlaClusterAttr attribute on `node`.
void RemoveFromXlaCluster(Node* node) {
  node->ClearAttr(kXlaClusterAttr);
}
|
||||
|
||||
namespace {
|
||||
struct XlaGlobalJitLevel {
|
||||
OptimizerOptions::GlobalJitLevel single_gpu;
|
||||
OptimizerOptions::GlobalJitLevel general;
|
||||
};
|
||||
typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;
|
||||
|
||||
XlaGlobalJitLevel GetXlaGlobalJitLevel(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
|
||||
XlaGlobalJitLevel result;
|
||||
|
||||
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
|
||||
options.session_options->config.graph_options()
|
||||
.optimizer_options()
|
||||
.global_jit_level();
|
||||
if (jit_level_in_session_opts == OptimizerOptions::DEFAULT) {
|
||||
// To set compilation to be on by default, change the following line.
|
||||
result.single_gpu = result.general = OptimizerOptions::OFF;
|
||||
@ -289,7 +283,12 @@ bool IsSingleGpuGraph(const Graph& g) {
|
||||
|
||||
OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
XlaGlobalJitLevel xla_global_jit_level = GetXlaGlobalJitLevel(options);
|
||||
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
|
||||
options.session_options->config.graph_options()
|
||||
.optimizer_options()
|
||||
.global_jit_level();
|
||||
XlaGlobalJitLevel xla_global_jit_level =
|
||||
GetXlaGlobalJitLevel(jit_level_in_session_opts);
|
||||
if (xla_global_jit_level.single_gpu == xla_global_jit_level.general) {
|
||||
VLOG(4) << "GetGlobalJitLevelForGraph returning "
|
||||
<< xla_global_jit_level.single_gpu;
|
||||
@ -386,4 +385,8 @@ XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Register GetXlaGlobalJitLevel as the callback for querying
// XlaGlobalJitLevel through xla_config_registry. The macro performs the
// registration from the initializer of a static variable.
|
||||
REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -2810,6 +2810,7 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
|
||||
"util/presized_cuckoo_map.h",
|
||||
"util/tensor_slice_set.h",
|
||||
"util/tensor_slice_util.h",
|
||||
"util/xla_config_registry.h",
|
||||
]
|
||||
|
||||
tf_cuda_library(
|
||||
|
@ -605,6 +605,7 @@ cc_library(
|
||||
":shape_optimizer",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -617,6 +618,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler/utils:tpu",
|
||||
"//tensorflow/core/grappler/verifiers:graph_verifier",
|
||||
"//tensorflow/core/grappler/verifiers:structure_verifier",
|
||||
"//tensorflow/core/lib/gtl:map_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -51,6 +51,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
#include "tensorflow/core/util/xla_config_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -126,6 +127,38 @@ bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsXlaGlobalJitOn(
|
||||
const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
|
||||
xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
|
||||
xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
|
||||
// Return true only if XLA JIT is ON for both single-gpu and multi-gpu
|
||||
// graphs. This is a conservative approach that turns off the memory optimizer
|
||||
// when we are sure that all graphs will be processed by XLA JIT.
|
||||
bool is_on = (xla_global_jit_level.single_gpu == OptimizerOptions::ON_1 ||
|
||||
xla_global_jit_level.single_gpu == OptimizerOptions::ON_2) &&
|
||||
(xla_global_jit_level.general == OptimizerOptions::ON_1 ||
|
||||
xla_global_jit_level.general == OptimizerOptions::ON_2);
|
||||
return is_on;
|
||||
}
|
||||
|
||||
// A helper function to decide whether to enable the memory optimizer.
|
||||
bool MemoryOptimizerEnabled(
|
||||
RewriterConfig::MemOptType mem_opt_type,
|
||||
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
|
||||
// Disable the default memory optimizer when XLA JIT is ON as it hurts the
|
||||
// XLA JIT performance. The (current) XLA clustering can result in loss of
|
||||
// concurrency between kernel compute and memory copies. As such, it usually
|
||||
// loses the concurrency needed to hide the latencies of the inserted swap-ins
|
||||
// and swap-outs and incurs great performance overhead. Remove this check when
|
||||
// the XLA JIT can better deal with the concurrency.
|
||||
if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
|
||||
IsXlaGlobalJitOn(jit_level_in_session_opts)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return mem_opt_type != RewriterConfig::NO_MEM_OPT;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#define MK_OPT(NAME, VALUE) \
|
||||
@ -216,7 +249,9 @@ Status MetaOptimizer::InitializeOptimizers(
|
||||
optimizers->push_back(
|
||||
MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
|
||||
}
|
||||
if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
|
||||
auto global_jit_level =
|
||||
config_proto_.graph_options().optimizer_options().global_jit_level();
|
||||
if (MemoryOptimizerEnabled(cfg_.memory_optimization(), global_jit_level)) {
|
||||
if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
|
||||
optimizers->push_back(
|
||||
// Use the default target node name prefix "gradients/"
|
||||
|
55
tensorflow/core/util/xla_config_registry.cc
Normal file
55
tensorflow/core/util/xla_config_registry.cc
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/xla_config_registry.h"
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace xla_config_registry {
|
||||
|
||||
namespace {
|
||||
// State shared by the registration and lookup functions in this file: the
// registered getter callback together with the mutex that guards it.
struct GlobalJitLevelState {
|
||||
// Protects `getter`.
mutex mu;
|
||||
// Callback stored by RegisterGlobalJitLevelGetter(); empty until one is
// registered.
GlobalJitLevelGetterTy getter GUARDED_BY(mu);
|
||||
};
|
||||
|
||||
// Returns the single GlobalJitLevelState instance. It is heap-allocated the
// first time this function runs and is never deleted here.
GlobalJitLevelState* GetSingletonState() {
  static GlobalJitLevelState* const singleton = new GlobalJitLevelState;
  return singleton;
}
|
||||
} // namespace
|
||||
|
||||
// Stores `getter` in the singleton state so that GetGlobalJitLevel() can
// consult it. CHECK-fails if a getter has already been registered.
void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter) {
|
||||
GlobalJitLevelState* state = GetSingletonState();
|
||||
// The lock covers both the "not yet registered" check and the store.
mutex_lock l(state->mu);
|
||||
CHECK(!state->getter);
|
||||
state->getter = std::move(getter);
|
||||
}
|
||||
|
||||
// Returns the XLA JIT levels that correspond to `jit_level_in_session_opts`.
//
// When a getter has been registered it is invoked (while holding the state
// mutex) and its result is returned. Otherwise the session-level value is
// reported unchanged for both the single-gpu and the general case.
XlaGlobalJitLevel GetGlobalJitLevel(
    OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
  GlobalJitLevelState* const registry = GetSingletonState();
  mutex_lock lock(registry->mu);
  if (registry->getter) {
    return registry->getter(jit_level_in_session_opts);
  }
  // No getter registered: echo the session option into both fields.
  return {jit_level_in_session_opts, jit_level_in_session_opts};
}
|
||||
|
||||
} // namespace xla_config_registry
|
||||
|
||||
} // namespace tensorflow
|
63
tensorflow/core/util/xla_config_registry.h
Normal file
63
tensorflow/core/util/xla_config_registry.h
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_
|
||||
#define TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/core/framework/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace xla_config_registry {
|
||||
|
||||
// XlaGlobalJitLevel is used by XLA to expose its JIT level for processing
|
||||
// single gpu and general (multi-gpu) graphs.
|
||||
struct XlaGlobalJitLevel {
|
||||
// JIT level for single gpu graphs.
OptimizerOptions::GlobalJitLevel single_gpu;
|
||||
// JIT level for general (multi-gpu) graphs.
OptimizerOptions::GlobalJitLevel general;
|
||||
};
|
||||
|
||||
// Input is the jit_level in session config, and return value is the jit_level
|
||||
// from XLA, reflecting the effect of the environment variable flags.
|
||||
typedef std::function<XlaGlobalJitLevel(
|
||||
const OptimizerOptions::GlobalJitLevel&)>
|
||||
GlobalJitLevelGetterTy;
|
||||
|
||||
// Installs `getter` as the callback consulted by GetGlobalJitLevel(). The
// implementation CHECK-fails if a getter has already been registered.
void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter);
|
||||
|
||||
// Returns the XLA JIT levels corresponding to `jit_level_in_session_opts`. The
// registered getter is invoked when there is one; otherwise both fields of the
// result are set to `jit_level_in_session_opts`.
XlaGlobalJitLevel GetGlobalJitLevel(
|
||||
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts);
|
||||
|
||||
// Registers `getter` via RegisterGlobalJitLevelGetter() from the initializer
// of a static variable. __COUNTER__ is passed along so that each use of the
// macro gets a distinct variable name.
#define REGISTER_XLA_CONFIG_GETTER(getter) \
|
||||
REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(__COUNTER__, getter)
|
||||
|
||||
// Extra level of indirection so that `ctr` is macro-expanded before it is
// token-pasted by REGISTER_XLA_CONFIG_GETTER_UNIQ.
#define REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(ctr, getter) \
|
||||
REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter)
|
||||
|
||||
// Defines a static bool named xla_config_registry_registration_<ctr> whose
// initializer calls RegisterGlobalJitLevelGetter(getter); the comma expression
// makes the initializer evaluate to true.
#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \
|
||||
static bool xla_config_registry_registration_##ctr = \
|
||||
(::tensorflow::xla_config_registry::RegisterGlobalJitLevelGetter( \
|
||||
getter), \
|
||||
true)
|
||||
|
||||
} // namespace xla_config_registry
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_
|
Loading…
Reference in New Issue
Block a user