Provides an environment variable via TF_XLA_FLAGS to turn on the MLIR bridge.
PiperOrigin-RevId: 316116765 Change-Id: I005c5b6712a4e7cdd72f4302caae93f58c5f840e
This commit is contained in:
parent
3efb46044d
commit
2b7fb42e3b
@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
|
||||
XlaDeviceFlags* device_flags;
|
||||
XlaOpsCommonFlags* ops_flags;
|
||||
IntroduceFloatingPointJitterPassFlags* jitter_flags;
|
||||
MlirCommonFlags* mlir_flags;
|
||||
|
||||
std::vector<Flag>* flag_list;
|
||||
absl::once_flag flags_init;
|
||||
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
|
||||
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
|
||||
jitter_flags->jitter_amount = 1e-5;
|
||||
|
||||
mlir_flags = new MlirCommonFlags;
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge = false;
|
||||
|
||||
auto setter_for_jitter_tensor_names = [](string sequence) {
|
||||
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
|
||||
return true;
|
||||
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
|
||||
Flag("tf_introduce_floating_point_jitter_amount",
|
||||
&jitter_flags->jitter_amount,
|
||||
"The amount of jitter to introduce. This amount is added to each "
|
||||
"element in the tensors named in `tensor_names.")});
|
||||
"element in the tensors named in `tensor_names."),
|
||||
|
||||
Flag("tf_mlir_enable_mlir_bridge",
|
||||
&mlir_flags->tf_mlir_enable_mlir_bridge,
|
||||
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
|
||||
|
||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
|
||||
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
|
||||
return *jitter_flags;
|
||||
}
|
||||
|
||||
MlirCommonFlags* GetMlirCommonFlags() {
|
||||
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||
return mlir_flags;
|
||||
}
|
||||
|
||||
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
|
||||
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||
|
@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
|
||||
std::vector<string> tensor_names;
|
||||
};
|
||||
|
||||
// Flags for common MLIR configurations.
|
||||
struct MlirCommonFlags {
|
||||
bool tf_mlir_enable_mlir_bridge;
|
||||
};
|
||||
|
||||
// Return a pointer to the DumpGraphFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
|
||||
const IntroduceFloatingPointJitterPassFlags&
|
||||
GetIntroduceFloatingPointJitterPassFlags();
|
||||
|
||||
MlirCommonFlags* GetMlirCommonFlags();
|
||||
|
||||
// Appends the flag definitions associated with
|
||||
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
|
||||
//
|
||||
|
@ -451,7 +451,6 @@ class Context(object):
|
||||
self._inter_op_parallelism_threads = None
|
||||
self._soft_device_placement = None
|
||||
self._log_device_placement = None
|
||||
self._enable_mlir_bridge = None
|
||||
self._enable_mlir_graph_optimization = None
|
||||
self._optimizer_experimental_options = {}
|
||||
|
||||
@ -927,8 +926,7 @@ class Context(object):
|
||||
if self._log_device_placement is not None:
|
||||
config.log_device_placement = self._log_device_placement
|
||||
|
||||
if self._enable_mlir_bridge is not None:
|
||||
config.experimental.enable_mlir_bridge = self._enable_mlir_bridge
|
||||
config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled()
|
||||
if self._enable_mlir_graph_optimization is not None:
|
||||
config.experimental.enable_mlir_graph_optimization = (
|
||||
self._enable_mlir_graph_optimization)
|
||||
@ -1466,7 +1464,7 @@ class Context(object):
|
||||
|
||||
@property
|
||||
def enable_mlir_bridge(self):
|
||||
return self._enable_mlir_bridge
|
||||
return pywrap_tfe.TF_IsMlirBridgeEnabled()
|
||||
|
||||
@property
|
||||
def enable_mlir_graph_optimization(self):
|
||||
@ -1474,7 +1472,7 @@ class Context(object):
|
||||
|
||||
@enable_mlir_bridge.setter
|
||||
def enable_mlir_bridge(self, enabled):
|
||||
self._enable_mlir_bridge = enabled
|
||||
pywrap_tfe.TF_EnableMlirBridge(enabled)
|
||||
self._thread_local_data.function_call_options = None
|
||||
|
||||
@enable_mlir_graph_optimization.setter
|
||||
|
@ -364,6 +364,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
|
||||
m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });
|
||||
|
||||
// MLIR Logic
|
||||
m.def("TF_IsMlirBridgeEnabled", [] {
|
||||
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
|
||||
});
|
||||
m.def("TF_EnableMlirBridge", [](bool enabled) {
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
|
||||
});
|
||||
|
||||
// // TFE_Context Logic
|
||||
m.def(
|
||||
"TFE_NewContext",
|
||||
|
Loading…
Reference in New Issue
Block a user