Use a `new`-allocated object instead of a static object to avoid running its destructor.

By avoiding the destructor, it should work better with multi-threaded programs.
This commit is contained in:
Trent Lo 2019-09-12 15:52:11 -07:00
parent a908920e93
commit 1d997791dc
4 changed files with 58 additions and 45 deletions

View File

@ -219,7 +219,7 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
namespace {
typedef XlaConfigRegistry::XlaGlobalJitLevel XlaGlobalJitLevel;
typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;
XlaGlobalJitLevel GetXlaGlobalJitLevel(
const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {

View File

@ -129,8 +129,8 @@ bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
bool IsXlaGlobalJitOn(
const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
XlaConfigRegistry::XlaGlobalJitLevel xla_global_jit_level =
XlaConfigRegistry::GetGlobalJitLevel(jit_level_in_session_opts);
xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
// Return true only if XLA JIT is ON for both single-gpu and multi-gpu
// graphs. This is a conservative approach that turns off the memory optimizer
// when we are sure that all graphs will be processed by XLA JIT.

View File

@ -17,11 +17,37 @@ limitations under the License.
namespace tensorflow {
/*static*/
mutex XlaConfigRegistry::mu_(LINKER_INITIALIZED);
namespace xla_config_registry {
/*static*/
XlaConfigRegistry::GlobalJitLevelGetterTy
XlaConfigRegistry::global_jit_level_getter_;
namespace {
// State behind the registry functions in this file: the registered getter and
// the mutex that both RegisterGlobalJitLevelGetter() and GetGlobalJitLevel()
// lock before touching it.
struct GlobalJitLevelState {
// Locked by the registry functions around every access to `getter`.
mutex mu;
// Registered callback; empty until RegisterGlobalJitLevelGetter() is called.
GlobalJitLevelGetterTy getter;
};
// Returns the single GlobalJitLevelState instance used by the registry
// functions in this file. The instance is created with `new` on the first call
// and is never deleted, so its destructor never runs.
GlobalJitLevelState* GetSingletonState() {
  static GlobalJitLevelState* const singleton = new GlobalJitLevelState;
  return singleton;
}
} // anonymous
// Installs `getter` as the callback consulted by GetGlobalJitLevel().
// CHECK-fails if a getter has already been installed.
void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter) {
  GlobalJitLevelState* const state = GetSingletonState();
  // The lock covers both the "already registered?" check and the store.
  mutex_lock lock(state->mu);
  CHECK(!state->getter);
  state->getter = std::move(getter);
}
// Returns the JIT levels produced by the registered getter for
// `jit_level_in_session_opts`. When no getter has been registered, the session
// option level is returned unchanged for both fields.
XlaGlobalJitLevel GetGlobalJitLevel(
    OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
  GlobalJitLevelState* const registry = GetSingletonState();
  mutex_lock lock(registry->mu);
  // The getter, if any, is invoked while the lock is still held.
  return registry->getter
             ? registry->getter(jit_level_in_session_opts)
             : XlaGlobalJitLevel{jit_level_in_session_opts,
                                 jit_level_in_session_opts};
}
} // namespace xla_config_registry
} // namespace tensorflow

View File

@ -23,52 +23,39 @@ limitations under the License.
namespace tensorflow {
// A registry class where XLA can register callbacks for Tensorflow to query
// its status.
class XlaConfigRegistry {
public:
// XlaGlobalJitLevel is used by XLA to expose its JIT level for processing
// single gpu and general (multi-gpu) graphs.
struct XlaGlobalJitLevel {
// JIT level applied when processing single gpu graphs.
OptimizerOptions::GlobalJitLevel single_gpu;
// JIT level applied when processing general (multi-gpu) graphs.
OptimizerOptions::GlobalJitLevel general;
};
namespace xla_config_registry {
// Callback type: takes the jit_level from the session config and returns the
// jit_level from XLA, reflecting the effect of the environment variable flags.
using GlobalJitLevelGetterTy =
    std::function<XlaGlobalJitLevel(const OptimizerOptions::GlobalJitLevel&)>;
// Installs `getter` as the callback consulted by GetGlobalJitLevel().
// CHECK-fails if a getter has already been installed.
static void Register(XlaConfigRegistry::GlobalJitLevelGetterTy getter) {
  // The lock covers both the "already registered?" check and the store.
  mutex_lock lock(mu_);
  CHECK(!global_jit_level_getter_);
  global_jit_level_getter_ = std::move(getter);
}
// Returns the JIT levels produced by the registered getter for
// `jit_level_in_session_opts`. When no getter has been registered, the session
// option level is returned unchanged for both fields.
static XlaGlobalJitLevel GetGlobalJitLevel(
    OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
  mutex_lock lock(mu_);
  const bool has_getter = static_cast<bool>(global_jit_level_getter_);
  if (has_getter) {
    // The getter is invoked while the lock is still held.
    return global_jit_level_getter_(jit_level_in_session_opts);
  }
  return {jit_level_in_session_opts, jit_level_in_session_opts};
}
private:
static mutex mu_;
static GlobalJitLevelGetterTy global_jit_level_getter_ GUARDED_BY(mu_);
// XlaGlobalJitLevel is used by XLA to expose its JIT level for processing
// single gpu and general (multi-gpu) graphs.
struct XlaGlobalJitLevel {
// JIT level applied when processing single gpu graphs.
OptimizerOptions::GlobalJitLevel single_gpu;
// JIT level applied when processing general (multi-gpu) graphs.
OptimizerOptions::GlobalJitLevel general;
};
// Callback type: takes the jit_level from the session config and returns the
// jit_level from XLA, reflecting the effect of the environment variable flags.
using GlobalJitLevelGetterTy =
    std::function<XlaGlobalJitLevel(const OptimizerOptions::GlobalJitLevel&)>;
// Registers `getter` as the global JIT level getter. CHECK-fails if a getter
// has already been registered.
void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter);
// Returns the result of the registered getter for `jit_level_in_session_opts`.
// If no getter has been registered, both fields of the result are set to
// `jit_level_in_session_opts`.
XlaGlobalJitLevel GetGlobalJitLevel(
OptimizerOptions::GlobalJitLevel jit_level_in_session_opts);
// Registers `getter` through the initializer of a static bool whose name is
// made unique with __COUNTER__.
#define REGISTER_XLA_CONFIG_GETTER(getter) \
REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(__COUNTER__, getter)
// Extra expansion level so that __COUNTER__ is replaced by its numeric value
// before REGISTER_XLA_CONFIG_GETTER_UNIQ pastes it into an identifier.
#define REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(ctr, getter) \
REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter)
// Defines a static bool named xla_config_registry_registration_<ctr> whose
// initializer calls XlaConfigRegistry::Register(getter) and then yields true.
#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \
static bool xla_config_registry_registration_##ctr = \
(::tensorflow::XlaConfigRegistry::Register(getter), true)
// Defines a static bool named xla_config_registry_registration_<ctr> whose
// initializer calls xla_config_registry::RegisterGlobalJitLevelGetter(getter)
// and then yields true.
#define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \
static bool xla_config_registry_registration_##ctr = \
(::tensorflow::xla_config_registry::RegisterGlobalJitLevelGetter( \
getter), \
true)
} // namespace xla_config_registry
} // namespace tensorflow