Use `new` instead of a static object to avoid running a destructor.
Because no destructor runs at exit, this should work better in multi-threaded programs.
This commit is contained in:
parent
a908920e93
commit
1d997791dc
@ -219,7 +219,7 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
|
||||
// Removes the kXlaClusterAttr attribute from `node` by calling ClearAttr on it.
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
|
||||
|
||||
namespace {
|
||||
typedef XlaConfigRegistry::XlaGlobalJitLevel XlaGlobalJitLevel;
|
||||
typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;
|
||||
|
||||
XlaGlobalJitLevel GetXlaGlobalJitLevel(
|
||||
const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
|
||||
|
@ -129,8 +129,8 @@ bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
|
||||
|
||||
bool IsXlaGlobalJitOn(
|
||||
const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
|
||||
XlaConfigRegistry::XlaGlobalJitLevel xla_global_jit_level =
|
||||
XlaConfigRegistry::GetGlobalJitLevel(jit_level_in_session_opts);
|
||||
xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
|
||||
xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
|
||||
// Return true only if XLA JIT is ON for both single-gpu and multi-gpu
|
||||
// graphs. This is a conservative approach that turns off the memory optimizer
|
||||
// when we are sure that all graphs will be processed by XLA JIT.
|
||||
|
@ -17,11 +17,37 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/*static*/
|
||||
mutex XlaConfigRegistry::mu_(LINKER_INITIALIZED);
|
||||
namespace xla_config_registry {
|
||||
|
||||
/*static*/
|
||||
XlaConfigRegistry::GlobalJitLevelGetterTy
|
||||
XlaConfigRegistry::global_jit_level_getter_;
|
||||
namespace {
|
||||
// Bundles the registered global-JIT-level getter with the mutex that
// RegisterGlobalJitLevelGetter and GetGlobalJitLevel lock before touching it.
struct GlobalJitLevelState {
|
||||
// Held by RegisterGlobalJitLevelGetter and GetGlobalJitLevel around every
// access to `getter`.
mutex mu;
|
||||
// Callback installed through RegisterGlobalJitLevelGetter; tested for
// emptiness before it is set or invoked.
GlobalJitLevelGetterTy getter;
|
||||
};
|
||||
|
||||
// Returns the single GlobalJitLevelState shared by the registry functions.
//
// The instance is heap-allocated on first call and intentionally never
// deleted, so no destructor runs for it.
GlobalJitLevelState* GetSingletonState() {
  static GlobalJitLevelState* const instance = new GlobalJitLevelState;
  return instance;
}
|
||||
} // anonymous
|
||||
|
||||
// Installs `getter` as the callback that GetGlobalJitLevel consults.
//
// Only one getter may ever be registered: the CHECK fails if one is already
// present. The check and the assignment both happen under the state mutex.
void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter) {
  GlobalJitLevelState& registry = *GetSingletonState();
  mutex_lock guard(registry.mu);
  CHECK(!registry.getter);
  registry.getter = std::move(getter);
}
|
||||
|
||||
// Resolves the global JIT level for the given session-option JIT level.
//
// When a getter has been registered, its result for
// `jit_level_in_session_opts` is returned. Otherwise the session value is
// used for both fields of the returned XlaGlobalJitLevel. The getter is looked
// up and invoked while the state mutex is held.
XlaGlobalJitLevel GetGlobalJitLevel(
    OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
  GlobalJitLevelState& registry = *GetSingletonState();
  mutex_lock guard(registry.mu);
  if (registry.getter) {
    return registry.getter(jit_level_in_session_opts);
  }
  // Nothing registered: fall back to the session-level setting for both fields.
  return {jit_level_in_session_opts, jit_level_in_session_opts};
}
|
||||
|
||||
} // namespace xla_config_registry
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -23,52 +23,39 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A registry class where XLA can register callbacks for Tensorflow to query
|
||||
// its status.
|
||||
class XlaConfigRegistry {
|
||||
public:
|
||||
// XlaGlobalJitLevel is used by XLA to expose its JIT level for processing
|
||||
// single gpu and general (multi-gpu) graphs.
|
||||
struct XlaGlobalJitLevel {
|
||||
OptimizerOptions::GlobalJitLevel single_gpu;
|
||||
OptimizerOptions::GlobalJitLevel general;
|
||||
};
|
||||
namespace xla_config_registry {
|
||||
|
||||
// Input is the jit_level in session config, and return value is the jit_level
|
||||
// from XLA, reflecting the effect of the environment variable flags.
|
||||
typedef std::function<XlaGlobalJitLevel(
|
||||
const OptimizerOptions::GlobalJitLevel&)>
|
||||
GlobalJitLevelGetterTy;
|
||||
|
||||
static void Register(XlaConfigRegistry::GlobalJitLevelGetterTy getter) {
|
||||
mutex_lock l(mu_);
|
||||
CHECK(!global_jit_level_getter_);
|
||||
global_jit_level_getter_ = std::move(getter);
|
||||
}
|
||||
|
||||
static XlaGlobalJitLevel GetGlobalJitLevel(
|
||||
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
|
||||
mutex_lock l(mu_);
|
||||
if (!global_jit_level_getter_) {
|
||||
return {jit_level_in_session_opts, jit_level_in_session_opts};
|
||||
}
|
||||
return global_jit_level_getter_(jit_level_in_session_opts);
|
||||
}
|
||||
|
||||
private:
|
||||
static mutex mu_;
|
||||
static GlobalJitLevelGetterTy global_jit_level_getter_ GUARDED_BY(mu_);
|
||||
// XlaGlobalJitLevel is used by XLA to expose its JIT level for processing
|
||||
// single gpu and general (multi-gpu) graphs.
|
||||
struct XlaGlobalJitLevel {
|
||||
// JIT level used for single-gpu graphs.
OptimizerOptions::GlobalJitLevel single_gpu;
|
||||
// JIT level used for general (multi-gpu) graphs.
OptimizerOptions::GlobalJitLevel general;
|
||||
};
|
||||
|
||||
// Input is the jit_level in session config, and return value is the jit_level
|
||||
// from XLA, reflecting the effect of the environment variable flags.
|
||||
typedef std::function<XlaGlobalJitLevel(
|
||||
const OptimizerOptions::GlobalJitLevel&)>
|
||||
GlobalJitLevelGetterTy;
|
||||
|
||||
// Installs `getter` as the callback consulted by GetGlobalJitLevel.
// CHECK-fails if a getter has already been registered.
void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter);
|
||||
|
||||
// Returns the result of the registered getter for `jit_level_in_session_opts`.
// If no getter has been registered, returns an XlaGlobalJitLevel whose
// single_gpu and general fields are both `jit_level_in_session_opts`.
XlaGlobalJitLevel GetGlobalJitLevel(
|
||||
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts);
|
||||
|
||||
// Registers `getter` from a static initializer. __COUNTER__ is passed along so
// the generated registration variable gets a distinct name per use.
#define REGISTER_XLA_CONFIG_GETTER(getter) \
|
||||
REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(__COUNTER__, getter)
|
||||
|
||||
// Extra forwarding level so that `ctr` is macro-expanded before
// REGISTER_XLA_CONFIG_GETTER_UNIQ pastes it into an identifier with ##.
#define REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(ctr, getter) \
|
||||
REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter)
|
||||
|
||||
#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \
|
||||
static bool xla_config_registry_registration_##ctr = \
|
||||
(::tensorflow::XlaConfigRegistry::Register(getter), true)
|
||||
// Defines a static bool named xla_config_registry_registration_<ctr>. Its
// initializer is a comma expression: it calls RegisterGlobalJitLevelGetter
// with `getter` for the side effect and then evaluates to true.
#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \
|
||||
static bool xla_config_registry_registration_##ctr = \
|
||||
(::tensorflow::xla_config_registry::RegisterGlobalJitLevelGetter( \
|
||||
getter), \
|
||||
true)
|
||||
|
||||
} // namespace xla_config_registry
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user