Provides an environment variable via TF_XLA_FLAGS to turn on the MLIR bridge.

PiperOrigin-RevId: 316116765
Change-Id: I005c5b6712a4e7cdd72f4302caae93f58c5f840e
This commit is contained in:
A. Unique TensorFlower 2020-06-12 09:20:49 -07:00 committed by TensorFlower Gardener
parent 3efb46044d
commit 2b7fb42e3b
4 changed files with 32 additions and 6 deletions

View File

@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;
std::vector<Flag>* flag_list;
absl::once_flag flags_init;
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")});
"element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
return *jitter_flags;
}
MlirCommonFlags* GetMlirCommonFlags() {
absl::call_once(flags_init, &AllocateAndParseFlags);
return mlir_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
std::vector<string> tensor_names;
};
// Flags for common MLIR configurations.
struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//

View File

@ -451,7 +451,6 @@ class Context(object):
self._inter_op_parallelism_threads = None
self._soft_device_placement = None
self._log_device_placement = None
self._enable_mlir_bridge = None
self._enable_mlir_graph_optimization = None
self._optimizer_experimental_options = {}
@ -927,8 +926,7 @@ class Context(object):
if self._log_device_placement is not None:
config.log_device_placement = self._log_device_placement
if self._enable_mlir_bridge is not None:
config.experimental.enable_mlir_bridge = self._enable_mlir_bridge
config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled()
if self._enable_mlir_graph_optimization is not None:
config.experimental.enable_mlir_graph_optimization = (
self._enable_mlir_graph_optimization)
@ -1466,7 +1464,7 @@ class Context(object):
@property
def enable_mlir_bridge(self):
return self._enable_mlir_bridge
return pywrap_tfe.TF_IsMlirBridgeEnabled()
@property
def enable_mlir_graph_optimization(self):
@ -1474,7 +1472,7 @@ class Context(object):
@enable_mlir_bridge.setter
def enable_mlir_bridge(self, enabled):
self._enable_mlir_bridge = enabled
pywrap_tfe.TF_EnableMlirBridge(enabled)
self._thread_local_data.function_call_options = None
@enable_mlir_graph_optimization.setter

View File

@ -364,6 +364,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });
// MLIR Logic
m.def("TF_IsMlirBridgeEnabled", [] {
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
});
m.def("TF_EnableMlirBridge", [](bool enabled) {
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
});
// // TFE_Context Logic
m.def(
"TFE_NewContext",