Merge pull request #32245 from trentlo:no_mem_opt_if_jit_on
PiperOrigin-RevId: 270071858
Commit: caa5a8ea6b
@@ -629,6 +629,7 @@ cc_library(
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_bounds_check",
+        "//tensorflow/core:framework_internal",
         "//tensorflow/core:graph",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/stream_executor/lib",
@@ -637,7 +638,6 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
-        "@com_google_absl//absl/types:span",
     ],
 )
 
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/xla_config_registry.h"
 
 namespace tensorflow {
 
@@ -218,19 +219,12 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
 void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
 
 namespace {
-struct XlaGlobalJitLevel {
-  OptimizerOptions::GlobalJitLevel single_gpu;
-  OptimizerOptions::GlobalJitLevel general;
-};
+typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;
 
 XlaGlobalJitLevel GetXlaGlobalJitLevel(
-    const GraphOptimizationPassOptions& options) {
+    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
   XlaGlobalJitLevel result;
 
-  OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
-      options.session_options->config.graph_options()
-          .optimizer_options()
-          .global_jit_level();
   if (jit_level_in_session_opts == OptimizerOptions::DEFAULT) {
     // To set compilation to be on by default, change the following line.
     result.single_gpu = result.general = OptimizerOptions::OFF;
@@ -289,7 +283,12 @@ bool IsSingleGpuGraph(const Graph& g) {
 
 OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
     const GraphOptimizationPassOptions& options) {
-  XlaGlobalJitLevel xla_global_jit_level = GetXlaGlobalJitLevel(options);
+  OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
+      options.session_options->config.graph_options()
+          .optimizer_options()
+          .global_jit_level();
+  XlaGlobalJitLevel xla_global_jit_level =
+      GetXlaGlobalJitLevel(jit_level_in_session_opts);
   if (xla_global_jit_level.single_gpu == xla_global_jit_level.general) {
     VLOG(4) << "GetGlobalJitLevelForGraph returning "
             << xla_global_jit_level.single_gpu;
@@ -386,4 +385,8 @@ XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
 
   return result;
 }
+
+// Register a callback for querying XlaGlobalJitLevel.
+REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);
+
 }  // namespace tensorflow
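Note on the hunks above: GetXlaGlobalJitLevel() now takes the session-level jit level directly instead of a GraphOptimizationPassOptions, and it is exposed to the rest of TensorFlow through the new registry via REGISTER_XLA_CONFIG_GETTER. The session-level value is the same expression GetGlobalJitLevelForGraph() now computes inline; a minimal sketch of that extraction is below (the helper name JitLevelFromConfig is illustrative, not part of the commit):

// Sketch only: pulls the session-level global jit level out of a ConfigProto,
// the value that GetXlaGlobalJitLevel() and the grappler meta optimizer consume.
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {

OptimizerOptions::GlobalJitLevel JitLevelFromConfig(const ConfigProto& config) {
  return config.graph_options().optimizer_options().global_jit_level();
}

}  // namespace tensorflow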
@@ -2810,6 +2810,7 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
     "util/presized_cuckoo_map.h",
     "util/tensor_slice_set.h",
     "util/tensor_slice_util.h",
+    "util/xla_config_registry.h",
 ]
 
 tf_cuda_library(
@@ -605,6 +605,7 @@ cc_library(
         ":shape_optimizer",
         "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
@@ -617,6 +618,7 @@ cc_library(
        "//tensorflow/core/grappler/utils:tpu",
        "//tensorflow/core/grappler/verifiers:graph_verifier",
        "//tensorflow/core/grappler/verifiers:structure_verifier",
+        "//tensorflow/core/lib/gtl:map_util",
        "@com_google_absl//absl/strings",
    ],
 )
@@ -51,6 +51,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/util/dump_graph.h"
 #include "tensorflow/core/util/ptr_util.h"
+#include "tensorflow/core/util/xla_config_registry.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -126,6 +127,38 @@ bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
   return false;
 }
 
+bool IsXlaGlobalJitOn(
+    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
+  xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
+      xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
+  // Return true only if XLA JIT is ON for both single-gpu and multi-gpu
+  // graphs. This is a conservative approach that turns off the memory optimizer
+  // when we are sure that all graphs will be processed by XLA JIT.
+  bool is_on = (xla_global_jit_level.single_gpu == OptimizerOptions::ON_1 ||
+                xla_global_jit_level.single_gpu == OptimizerOptions::ON_2) &&
+               (xla_global_jit_level.general == OptimizerOptions::ON_1 ||
+                xla_global_jit_level.general == OptimizerOptions::ON_2);
+  return is_on;
+}
+
+// A helper function to decide whether to enable the memory optimizer.
+bool MemoryOptimizerEnabled(
+    RewriterConfig::MemOptType mem_opt_type,
+    OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
+  // Disable the default memory optimizer when XLA JIT is ON as it hurts the
+  // XLA JIT performance. The (current) XLA clustering can result in loss of
+  // concurrency between kernel compute and memory copies. As such, it usually
+  // loses the concurrency needed to hide the latencies of the inserted swap-ins
+  // and swap-outs and incurs great performance overhead. Remove this check when
+  // the XLA JIT can better deal with the concurrency.
+  if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
+      IsXlaGlobalJitOn(jit_level_in_session_opts)) {
+    return false;
+  }
+
+  return mem_opt_type != RewriterConfig::NO_MEM_OPT;
+}
+
 }  // namespace
 
 #define MK_OPT(NAME, VALUE) \
@@ -216,7 +249,9 @@ Status MetaOptimizer::InitializeOptimizers(
     optimizers->push_back(
         MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
   }
-  if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
+  auto global_jit_level =
+      config_proto_.graph_options().optimizer_options().global_jit_level();
+  if (MemoryOptimizerEnabled(cfg_.memory_optimization(), global_jit_level)) {
     if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
       optimizers->push_back(
           // Use the default target node name prefix "gradients/"
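Note on the hunks above: MemoryOptimizerEnabled() turns the memory optimizer off only for DEFAULT_MEM_OPT when XLA global JIT is ON for both single-gpu and general graphs; any explicit MemOptType other than NO_MEM_OPT still enables it. A hedged sketch of a session config that enables XLA JIT but keeps the memory optimizer by opting into an explicit mode (the helper name MakeJitConfigWithExplicitMemOpt is illustrative, and HEURISTICS is just one explicit choice):

// Sketch only: enables XLA global JIT while explicitly requesting a memory
// optimization mode, so MemoryOptimizerEnabled() still returns true.
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::ConfigProto MakeJitConfigWithExplicitMemOpt() {
  tensorflow::ConfigProto config;
  auto* graph_options = config.mutable_graph_options();
  graph_options->mutable_optimizer_options()->set_global_jit_level(
      tensorflow::OptimizerOptions::ON_1);
  graph_options->mutable_rewrite_options()->set_memory_optimization(
      tensorflow::RewriterConfig::HEURISTICS);
  return config;
}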
new file: tensorflow/core/util/xla_config_registry.cc (55 lines)
@@ -0,0 +1,55 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/xla_config_registry.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace xla_config_registry {
+
+namespace {
+struct GlobalJitLevelState {
+  mutex mu;
+  GlobalJitLevelGetterTy getter GUARDED_BY(mu);
+};
+
+GlobalJitLevelState* GetSingletonState() {
+  static GlobalJitLevelState* state = new GlobalJitLevelState;
+  return state;
+}
+}  // namespace
+
+void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter) {
+  GlobalJitLevelState* state = GetSingletonState();
+  mutex_lock l(state->mu);
+  CHECK(!state->getter);
+  state->getter = std::move(getter);
+}
+
+XlaGlobalJitLevel GetGlobalJitLevel(
+    OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
+  GlobalJitLevelState* state = GetSingletonState();
+  mutex_lock l(state->mu);
+  if (!state->getter) {
+    return {jit_level_in_session_opts, jit_level_in_session_opts};
+  }
+  return state->getter(jit_level_in_session_opts);
+}
+
+}  // namespace xla_config_registry
+
+}  // namespace tensorflow
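Note on the new file above: the registry is a process-wide singleton guarded by a mutex; RegisterGlobalJitLevelGetter() CHECK-fails on a second registration, and GetGlobalJitLevel() falls back to echoing the session-level jit level for both fields when nothing has been registered (for example, in a build that does not link the XLA JIT library). A small sketch of that fallback, assuming no getter has been registered in the process (the function name LogFallbackJitLevel is illustrative):

// Sketch only: with no registered getter, GetGlobalJitLevel() mirrors the
// session-level value into both single_gpu and general.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/xla_config_registry.h"

namespace tensorflow {

void LogFallbackJitLevel() {
  xla_config_registry::XlaGlobalJitLevel level =
      xla_config_registry::GetGlobalJitLevel(OptimizerOptions::ON_2);
  // Expect ON_2 for both fields when no getter is registered in this process.
  LOG(INFO) << "single_gpu=" << level.single_gpu
            << " general=" << level.general;
}

}  // namespace tensorflow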
new file: tensorflow/core/util/xla_config_registry.h (63 lines)
@@ -0,0 +1,63 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_
+#define TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_
+
+#include <functional>
+
+#include "tensorflow/core/framework/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+namespace xla_config_registry {
+
+// XlaGlobalJitLevel is used by XLA to expose its JIT level for processing
+// single gpu and general (multi-gpu) graphs.
+struct XlaGlobalJitLevel {
+  OptimizerOptions::GlobalJitLevel single_gpu;
+  OptimizerOptions::GlobalJitLevel general;
+};
+
+// Input is the jit_level in session config, and return value is the jit_level
+// from XLA, reflecting the effect of the environment variable flags.
+typedef std::function<XlaGlobalJitLevel(
+    const OptimizerOptions::GlobalJitLevel&)>
+    GlobalJitLevelGetterTy;
+
+void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter);
+
+XlaGlobalJitLevel GetGlobalJitLevel(
+    OptimizerOptions::GlobalJitLevel jit_level_in_session_opts);
+
+#define REGISTER_XLA_CONFIG_GETTER(getter) \
+  REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(__COUNTER__, getter)
+
+#define REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(ctr, getter) \
+  REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter)
+
+#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter)                     \
+  static bool xla_config_registry_registration_##ctr =                   \
+      (::tensorflow::xla_config_registry::RegisterGlobalJitLevelGetter(  \
+           getter),                                                      \
+       true)
+
+}  // namespace xla_config_registry
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_
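Note on the header above: REGISTER_XLA_CONFIG_GETTER expands to a static registration object, and the registry accepts exactly one getter per process (see the CHECK in the .cc). This commit registers GetXlaGlobalJitLevel from xla_cluster_util.cc; the sketch below shows the registration pattern with a stand-in getter (MyGetXlaGlobalJitLevel is hypothetical and would conflict with the real registration if both were linked into one binary):

// Sketch only: registering a stand-in getter with the new registry.
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/xla_config_registry.h"

namespace tensorflow {
namespace {

xla_config_registry::XlaGlobalJitLevel MyGetXlaGlobalJitLevel(
    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
  // Report the session-level setting for both single-gpu and general graphs.
  return {jit_level_in_session_opts, jit_level_in_session_opts};
}

REGISTER_XLA_CONFIG_GETTER(MyGetXlaGlobalJitLevel);

}  // namespace
}  // namespace tensorflow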