Allow the TF_XLA_FLAGS tf_mlir_enable_mlir_bridge to still enable the experimental mlir bridge

for TensorFlow V1 even when the user overrides Session's default ConfigProto.

In TensorFlow V1, when the user doesn't override Session's default ConfigProto, the ConfigProto
uses the value from TF_XLA_FLAGS (via context.config). However, when the user passes in their
own ConfigProto, enable_mlir_bridge is defaulted to false.

PiperOrigin-RevId: 320661288
Change-Id: I5deacb4e12b6551a57c2496c35830ac3af25fe28
This commit is contained in:
Marissa Ikonomidis 2020-07-10 13:09:53 -07:00 committed by TensorFlower Gardener
parent 3adc7cf2c9
commit c1921df0e2
3 changed files with 5 additions and 2 deletions

View File

@ -704,6 +704,7 @@ cc_library(
srcs = ["mlir_bridge_pass.cc"],
hdrs = ["mlir_bridge_pass.h"],
deps = [
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:core_cpu",

View File

@ -56,7 +56,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
// Skip function graphs as MlirBridgePass will be used instead.
if (options.is_function_graph) return Status::OK();
if (!options.session_options->config.experimental().enable_mlir_bridge()) {
if (!IsEnabled(options.session_options->config)) {
VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled";
mlir_bridge_gauge_v1->GetCell()->Set(false);
return Status::OK();

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
namespace tensorflow {
@ -45,7 +46,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
llvm::StringRef name() const override { return "bridge"; }
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge();
return config_proto.experimental().enable_mlir_bridge() ||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
}
// This should be used as a thin mapper around mlir::ModulePass::runOnModule