From 2b7fb42e3b7112fc712edf05f29bbfd865a5515a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 12 Jun 2020 09:20:49 -0700
Subject: [PATCH] Provides an environment variable via TF_XLA_FLAGS to turn on
 the MLIR bridge.

PiperOrigin-RevId: 316116765
Change-Id: I005c5b6712a4e7cdd72f4302caae93f58c5f840e
---
 tensorflow/compiler/jit/flags.cc   | 15 ++++++++++++++-
 tensorflow/compiler/jit/flags.h    |  7 +++++++
 tensorflow/python/eager/context.py |  8 +++-----
 tensorflow/python/tfe_wrapper.cc   |  8 ++++++++
 4 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index 927baf4fe72..d1301a8c40f 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
 XlaDeviceFlags* device_flags;
 XlaOpsCommonFlags* ops_flags;
 IntroduceFloatingPointJitterPassFlags* jitter_flags;
+MlirCommonFlags* mlir_flags;
 
 std::vector<Flag>* flag_list;
 absl::once_flag flags_init;
@@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
   jitter_flags = new IntroduceFloatingPointJitterPassFlags;
   jitter_flags->jitter_amount = 1e-5;
 
+  mlir_flags = new MlirCommonFlags;
+  mlir_flags->tf_mlir_enable_mlir_bridge = false;
+
   auto setter_for_jitter_tensor_names = [](string sequence) {
     jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
     return true;
@@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
        Flag("tf_introduce_floating_point_jitter_amount",
            &jitter_flags->jitter_amount,
            "The amount of jitter to introduce. This amount is added to each "
-           "element in the tensors named in `tensor_names.")});
+           "element in the tensors named in `tensor_names."),
+
+       Flag("tf_mlir_enable_mlir_bridge",
+           &mlir_flags->tf_mlir_enable_mlir_bridge,
+           "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
 
   AppendMarkForCompilationPassFlagsInternal(flag_list);
   xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
   return *jitter_flags;
 }
 
+MlirCommonFlags* GetMlirCommonFlags() {
+  absl::call_once(flags_init, &AllocateAndParseFlags);
+  return mlir_flags;
+}
+
 void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
   absl::call_once(flags_init, &AllocateAndParseFlags);
   AppendMarkForCompilationPassFlagsInternal(flag_list);
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index b77a009b49f..89e20d9f8ea 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
   std::vector<string> tensor_names;
 };
 
+// Flags for common MLIR configurations.
+struct MlirCommonFlags {
+  bool tf_mlir_enable_mlir_bridge;
+};
+
 // Return a pointer to the DumpGraphFlags struct;
 // repeated calls return the same pointer.
 // This should be called only after Flags::Parse() has returned.
@@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
 const IntroduceFloatingPointJitterPassFlags&
 GetIntroduceFloatingPointJitterPassFlags();
 
+MlirCommonFlags* GetMlirCommonFlags();
+
 // Appends the flag definitions associated with
 // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
 //
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 1c083ffe294..b01f0795c72 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -451,7 +451,6 @@ class Context(object):
     self._inter_op_parallelism_threads = None
     self._soft_device_placement = None
     self._log_device_placement = None
-    self._enable_mlir_bridge = None
    self._enable_mlir_graph_optimization = None
     self._optimizer_experimental_options = {}
 
@@ -927,8 +926,7 @@ class Context(object):
     if self._log_device_placement is not None:
       config.log_device_placement = self._log_device_placement
 
-    if self._enable_mlir_bridge is not None:
-      config.experimental.enable_mlir_bridge = self._enable_mlir_bridge
+    config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled()
     if self._enable_mlir_graph_optimization is not None:
       config.experimental.enable_mlir_graph_optimization = (
           self._enable_mlir_graph_optimization)
@@ -1466,7 +1464,7 @@ class Context(object):
 
   @property
   def enable_mlir_bridge(self):
-    return self._enable_mlir_bridge
+    return pywrap_tfe.TF_IsMlirBridgeEnabled()
 
   @property
   def enable_mlir_graph_optimization(self):
@@ -1474,7 +1472,7 @@ class Context(object):
 
   @enable_mlir_bridge.setter
   def enable_mlir_bridge(self, enabled):
-    self._enable_mlir_bridge = enabled
+    pywrap_tfe.TF_EnableMlirBridge(enabled)
     self._thread_local_data.function_call_options = None
 
   @enable_mlir_graph_optimization.setter
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 2901a63c829..00137f6f492 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -364,6 +364,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
   m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
   m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });
 
+  // MLIR Logic
+  m.def("TF_IsMlirBridgeEnabled", [] {
+    return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
+  });
+  m.def("TF_EnableMlirBridge", [](bool enabled) {
+    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
+  });
+
   //
   // TFE_Context Logic
   m.def(
       "TFE_NewContext",