Provides an environment variable via TF_XLA_FLAGS to turn on the MLIR bridge.

PiperOrigin-RevId: 316116765
Change-Id: I005c5b6712a4e7cdd72f4302caae93f58c5f840e
This commit is contained in:
A. Unique TensorFlower 2020-06-12 09:20:49 -07:00 committed by TensorFlower Gardener
parent 3efb46044d
commit 2b7fb42e3b
4 changed files with 32 additions and 6 deletions

View File

@@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags; XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags; XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags; IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;
std::vector<Flag>* flag_list; std::vector<Flag>* flag_list;
absl::once_flag flags_init; absl::once_flag flags_init;
@@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5; jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
auto setter_for_jitter_tensor_names = [](string sequence) { auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ','); jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true; return true;
@@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
Flag("tf_introduce_floating_point_jitter_amount", Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount, &jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each " "The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")}); "element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list); AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
return *jitter_flags; return *jitter_flags;
} }
// Returns a pointer to the process-wide MlirCommonFlags singleton;
// repeated calls return the same pointer. The first call (from any of the
// Get*Flags accessors) allocates all flag structs and parses TF_XLA_FLAGS
// exactly once via absl::call_once, so this is safe to call concurrently.
// Returns a mutable pointer (unlike the const-ref accessors above) so
// callers such as the pybind TF_EnableMlirBridge binding can toggle
// tf_mlir_enable_mlir_bridge at runtime.
MlirCommonFlags* GetMlirCommonFlags() {
// Must run before reading mlir_flags: AllocateAndParseFlags initializes it.
absl::call_once(flags_init, &AllocateAndParseFlags);
return mlir_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) { void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags); absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list); AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
std::vector<string> tensor_names; std::vector<string> tensor_names;
}; };
// Flags for common MLIR configurations.
struct MlirCommonFlags {
// If true, enables the experimental MLIR-based TensorFlow compiler bridge.
// Defaults to false; settable via TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge
// or at runtime through GetMlirCommonFlags().
bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct; // Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer. // repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned. // This should be called only after Flags::Parse() has returned.
@@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags& const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags(); GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
// Appends the flag definitions associated with // Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
// //

View File

@@ -451,7 +451,6 @@ class Context(object):
self._inter_op_parallelism_threads = None self._inter_op_parallelism_threads = None
self._soft_device_placement = None self._soft_device_placement = None
self._log_device_placement = None self._log_device_placement = None
self._enable_mlir_bridge = None
self._enable_mlir_graph_optimization = None self._enable_mlir_graph_optimization = None
self._optimizer_experimental_options = {} self._optimizer_experimental_options = {}
@@ -927,8 +926,7 @@ class Context(object):
if self._log_device_placement is not None: if self._log_device_placement is not None:
config.log_device_placement = self._log_device_placement config.log_device_placement = self._log_device_placement
if self._enable_mlir_bridge is not None: config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled()
config.experimental.enable_mlir_bridge = self._enable_mlir_bridge
if self._enable_mlir_graph_optimization is not None: if self._enable_mlir_graph_optimization is not None:
config.experimental.enable_mlir_graph_optimization = ( config.experimental.enable_mlir_graph_optimization = (
self._enable_mlir_graph_optimization) self._enable_mlir_graph_optimization)
@@ -1466,7 +1464,7 @@ class Context(object):
@property @property
def enable_mlir_bridge(self): def enable_mlir_bridge(self):
return self._enable_mlir_bridge return pywrap_tfe.TF_IsMlirBridgeEnabled()
@property @property
def enable_mlir_graph_optimization(self): def enable_mlir_graph_optimization(self):
@@ -1474,7 +1472,7 @@ class Context(object):
@enable_mlir_bridge.setter @enable_mlir_bridge.setter
def enable_mlir_bridge(self, enabled): def enable_mlir_bridge(self, enabled):
self._enable_mlir_bridge = enabled pywrap_tfe.TF_EnableMlirBridge(enabled)
self._thread_local_data.function_call_options = None self._thread_local_data.function_call_options = None
@enable_mlir_graph_optimization.setter @enable_mlir_graph_optimization.setter

View File

@@ -364,6 +364,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize); m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); }); m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });
// MLIR Logic
// Python bindings backing Context.enable_mlir_bridge: both read/write the
// shared MlirCommonFlags singleton, so the Python-visible state stays in
// sync with the TF_XLA_FLAGS-parsed --tf_mlir_enable_mlir_bridge flag.
// Returns whether the experimental MLIR-based TF compiler bridge is on.
m.def("TF_IsMlirBridgeEnabled", [] {
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
});
// Toggles the MLIR bridge for the whole process (not per-context).
m.def("TF_EnableMlirBridge", [](bool enabled) {
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
});
// // TFE_Context Logic // // TFE_Context Logic
m.def( m.def(
"TFE_NewContext", "TFE_NewContext",