diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 73e0ccf63bb..3863bcf3131 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -219,7 +219,7 @@ void RemoveFromXlaCluster(NodeDef* node_def) { void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); } namespace { -typedef XlaConfigRegistry::XlaGlobalJitLevel XlaGlobalJitLevel; +typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel; XlaGlobalJitLevel GetXlaGlobalJitLevel( const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index c01818e370d..f0ded3b635e 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -129,8 +129,8 @@ bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) { bool IsXlaGlobalJitOn( const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) { - XlaConfigRegistry::XlaGlobalJitLevel xla_global_jit_level = - XlaConfigRegistry::GetGlobalJitLevel(jit_level_in_session_opts); + xla_config_registry::XlaGlobalJitLevel xla_global_jit_level = + xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts); // Return true only if XLA JIT is ON for both single-gpu and multi-gpu // graphs. This is a conservative approach that turns off the memory optimizer // when we are sure that all graphs will be processed by XLA JIT. diff --git a/tensorflow/core/util/xla_config_registry.cc b/tensorflow/core/util/xla_config_registry.cc index 3028845895c..a3270620c02 100644 --- a/tensorflow/core/util/xla_config_registry.cc +++ b/tensorflow/core/util/xla_config_registry.cc @@ -17,11 +17,37 @@ limitations under the License. namespace tensorflow { -/*static*/ -mutex XlaConfigRegistry::mu_(LINKER_INITIALIZED); +namespace xla_config_registry { -/*static*/ -XlaConfigRegistry::GlobalJitLevelGetterTy - XlaConfigRegistry::global_jit_level_getter_; +namespace { +struct GlobalJitLevelState { + mutex mu; + GlobalJitLevelGetterTy getter; +}; + +GlobalJitLevelState* GetSingletonState() { + static GlobalJitLevelState* state = new GlobalJitLevelState; + return state; +} +} // anonymous + +void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter) { + GlobalJitLevelState* state = GetSingletonState(); + mutex_lock l(state->mu); + CHECK(!state->getter); + state->getter = std::move(getter); +} + +XlaGlobalJitLevel GetGlobalJitLevel( + OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) { + GlobalJitLevelState* state = GetSingletonState(); + mutex_lock l(state->mu); + if (!state->getter) { + return {jit_level_in_session_opts, jit_level_in_session_opts}; + } + return state->getter(jit_level_in_session_opts); +} + +} // namespace xla_config_registry } // namespace tensorflow diff --git a/tensorflow/core/util/xla_config_registry.h b/tensorflow/core/util/xla_config_registry.h index b947396e9c1..a7ad8d9fb61 100644 --- a/tensorflow/core/util/xla_config_registry.h +++ b/tensorflow/core/util/xla_config_registry.h @@ -23,52 +23,39 @@ limitations under the License. namespace tensorflow { -// A registry class where XLA can register callbacks for Tensorflow to query -// its status. -class XlaConfigRegistry { - public: - // XlaGlobalJitLevel is used by XLA to expose its JIT level for processing - // single gpu and general (multi-gpu) graphs. - struct XlaGlobalJitLevel { - OptimizerOptions::GlobalJitLevel single_gpu; - OptimizerOptions::GlobalJitLevel general; - }; +namespace xla_config_registry { - // Input is the jit_level in session config, and return value is the jit_level - // from XLA, reflecting the effect of the environment variable flags. - typedef std::function - GlobalJitLevelGetterTy; - - static void Register(XlaConfigRegistry::GlobalJitLevelGetterTy getter) { - mutex_lock l(mu_); - CHECK(!global_jit_level_getter_); - global_jit_level_getter_ = std::move(getter); - } - - static XlaGlobalJitLevel GetGlobalJitLevel( - OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) { - mutex_lock l(mu_); - if (!global_jit_level_getter_) { - return {jit_level_in_session_opts, jit_level_in_session_opts}; - } - return global_jit_level_getter_(jit_level_in_session_opts); - } - - private: - static mutex mu_; - static GlobalJitLevelGetterTy global_jit_level_getter_ GUARDED_BY(mu_); +// XlaGlobalJitLevel is used by XLA to expose its JIT level for processing +// single gpu and general (multi-gpu) graphs. +struct XlaGlobalJitLevel { + OptimizerOptions::GlobalJitLevel single_gpu; + OptimizerOptions::GlobalJitLevel general; }; +// Input is the jit_level in session config, and return value is the jit_level +// from XLA, reflecting the effect of the environment variable flags. +typedef std::function + GlobalJitLevelGetterTy; + +void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter); + +XlaGlobalJitLevel GetGlobalJitLevel( + OptimizerOptions::GlobalJitLevel jit_level_in_session_opts); + #define REGISTER_XLA_CONFIG_GETTER(getter) \ REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(__COUNTER__, getter) #define REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(ctr, getter) \ REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) -#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \ - static bool xla_config_registry_registration_##ctr = \ - (::tensorflow::XlaConfigRegistry::Register(getter), true) +#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \ + static bool xla_config_registry_registration_##ctr = \ + (::tensorflow::xla_config_registry::RegisterGlobalJitLevelGetter( \ + getter), \ + true) + +} // namespace xla_config_registry } // namespace tensorflow