Provides an environment variable via TF_XLA_FLAGS to turn on the MLIR bridge.
PiperOrigin-RevId: 316116765 Change-Id: I005c5b6712a4e7cdd72f4302caae93f58c5f840e
This commit is contained in:
parent
3efb46044d
commit
2b7fb42e3b
@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
|
|||||||
XlaDeviceFlags* device_flags;
|
XlaDeviceFlags* device_flags;
|
||||||
XlaOpsCommonFlags* ops_flags;
|
XlaOpsCommonFlags* ops_flags;
|
||||||
IntroduceFloatingPointJitterPassFlags* jitter_flags;
|
IntroduceFloatingPointJitterPassFlags* jitter_flags;
|
||||||
|
MlirCommonFlags* mlir_flags;
|
||||||
|
|
||||||
std::vector<Flag>* flag_list;
|
std::vector<Flag>* flag_list;
|
||||||
absl::once_flag flags_init;
|
absl::once_flag flags_init;
|
||||||
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
|
|||||||
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
|
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
|
||||||
jitter_flags->jitter_amount = 1e-5;
|
jitter_flags->jitter_amount = 1e-5;
|
||||||
|
|
||||||
|
mlir_flags = new MlirCommonFlags;
|
||||||
|
mlir_flags->tf_mlir_enable_mlir_bridge = false;
|
||||||
|
|
||||||
auto setter_for_jitter_tensor_names = [](string sequence) {
|
auto setter_for_jitter_tensor_names = [](string sequence) {
|
||||||
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
|
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
|
||||||
return true;
|
return true;
|
||||||
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
|
|||||||
Flag("tf_introduce_floating_point_jitter_amount",
|
Flag("tf_introduce_floating_point_jitter_amount",
|
||||||
&jitter_flags->jitter_amount,
|
&jitter_flags->jitter_amount,
|
||||||
"The amount of jitter to introduce. This amount is added to each "
|
"The amount of jitter to introduce. This amount is added to each "
|
||||||
"element in the tensors named in `tensor_names.")});
|
"element in the tensors named in `tensor_names."),
|
||||||
|
|
||||||
|
Flag("tf_mlir_enable_mlir_bridge",
|
||||||
|
&mlir_flags->tf_mlir_enable_mlir_bridge,
|
||||||
|
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
|
||||||
|
|
||||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||||
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
|
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
|
||||||
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
|
|||||||
return *jitter_flags;
|
return *jitter_flags;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirCommonFlags* GetMlirCommonFlags() {
|
||||||
|
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||||
|
return mlir_flags;
|
||||||
|
}
|
||||||
|
|
||||||
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
|
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
|
||||||
absl::call_once(flags_init, &AllocateAndParseFlags);
|
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||||
|
@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
|
|||||||
std::vector<string> tensor_names;
|
std::vector<string> tensor_names;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Flags for common MLIR configurations.
|
||||||
|
struct MlirCommonFlags {
|
||||||
|
bool tf_mlir_enable_mlir_bridge;
|
||||||
|
};
|
||||||
|
|
||||||
// Return a pointer to the DumpGraphFlags struct;
|
// Return a pointer to the DumpGraphFlags struct;
|
||||||
// repeated calls return the same pointer.
|
// repeated calls return the same pointer.
|
||||||
// This should be called only after Flags::Parse() has returned.
|
// This should be called only after Flags::Parse() has returned.
|
||||||
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
|
|||||||
const IntroduceFloatingPointJitterPassFlags&
|
const IntroduceFloatingPointJitterPassFlags&
|
||||||
GetIntroduceFloatingPointJitterPassFlags();
|
GetIntroduceFloatingPointJitterPassFlags();
|
||||||
|
|
||||||
|
MlirCommonFlags* GetMlirCommonFlags();
|
||||||
|
|
||||||
// Appends the flag definitions associated with
|
// Appends the flag definitions associated with
|
||||||
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
|
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
|
||||||
//
|
//
|
||||||
|
@ -451,7 +451,6 @@ class Context(object):
|
|||||||
self._inter_op_parallelism_threads = None
|
self._inter_op_parallelism_threads = None
|
||||||
self._soft_device_placement = None
|
self._soft_device_placement = None
|
||||||
self._log_device_placement = None
|
self._log_device_placement = None
|
||||||
self._enable_mlir_bridge = None
|
|
||||||
self._enable_mlir_graph_optimization = None
|
self._enable_mlir_graph_optimization = None
|
||||||
self._optimizer_experimental_options = {}
|
self._optimizer_experimental_options = {}
|
||||||
|
|
||||||
@ -927,8 +926,7 @@ class Context(object):
|
|||||||
if self._log_device_placement is not None:
|
if self._log_device_placement is not None:
|
||||||
config.log_device_placement = self._log_device_placement
|
config.log_device_placement = self._log_device_placement
|
||||||
|
|
||||||
if self._enable_mlir_bridge is not None:
|
config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled()
|
||||||
config.experimental.enable_mlir_bridge = self._enable_mlir_bridge
|
|
||||||
if self._enable_mlir_graph_optimization is not None:
|
if self._enable_mlir_graph_optimization is not None:
|
||||||
config.experimental.enable_mlir_graph_optimization = (
|
config.experimental.enable_mlir_graph_optimization = (
|
||||||
self._enable_mlir_graph_optimization)
|
self._enable_mlir_graph_optimization)
|
||||||
@ -1466,7 +1464,7 @@ class Context(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def enable_mlir_bridge(self):
|
def enable_mlir_bridge(self):
|
||||||
return self._enable_mlir_bridge
|
return pywrap_tfe.TF_IsMlirBridgeEnabled()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enable_mlir_graph_optimization(self):
|
def enable_mlir_graph_optimization(self):
|
||||||
@ -1474,7 +1472,7 @@ class Context(object):
|
|||||||
|
|
||||||
@enable_mlir_bridge.setter
|
@enable_mlir_bridge.setter
|
||||||
def enable_mlir_bridge(self, enabled):
|
def enable_mlir_bridge(self, enabled):
|
||||||
self._enable_mlir_bridge = enabled
|
pywrap_tfe.TF_EnableMlirBridge(enabled)
|
||||||
self._thread_local_data.function_call_options = None
|
self._thread_local_data.function_call_options = None
|
||||||
|
|
||||||
@enable_mlir_graph_optimization.setter
|
@enable_mlir_graph_optimization.setter
|
||||||
|
@ -364,6 +364,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
|
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
|
||||||
m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });
|
m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });
|
||||||
|
|
||||||
|
// MLIR Logic
|
||||||
|
m.def("TF_IsMlirBridgeEnabled", [] {
|
||||||
|
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
|
||||||
|
});
|
||||||
|
m.def("TF_EnableMlirBridge", [](bool enabled) {
|
||||||
|
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
|
||||||
|
});
|
||||||
|
|
||||||
// // TFE_Context Logic
|
// // TFE_Context Logic
|
||||||
m.def(
|
m.def(
|
||||||
"TFE_NewContext",
|
"TFE_NewContext",
|
||||||
|
Loading…
Reference in New Issue
Block a user