diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index bf6406a796b..cac72925dfd 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -704,6 +704,7 @@ cc_library(
     srcs = ["mlir_bridge_pass.cc"],
     hdrs = ["mlir_bridge_pass.h"],
     deps = [
+        "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
index c398e5f129e..eefef26dc24 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
@@ -56,7 +56,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
   // Skip function graphs as MlirBridgePass will be used instead.
   if (options.is_function_graph) return Status::OK();
 
-  if (!options.session_options->config.experimental().enable_mlir_bridge()) {
+  if (!IsEnabled(options.session_options->config)) {
     VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled";
     mlir_bridge_gauge_v1->GetCell()->Set(false);
     return Status::OK();
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
index b7f8ef203f7..f7541e634d4 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
 
 #include "llvm/ADT/StringRef.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
 
 namespace tensorflow {
@@ -45,7 +46,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
   llvm::StringRef name() const override { return "bridge"; }
 
   bool IsEnabled(const ConfigProto& config_proto) const override {
-    return config_proto.experimental().enable_mlir_bridge();
+    return config_proto.experimental().enable_mlir_bridge() ||
+           tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
   }
 
   // This should be used as a thin mapper around mlir::ModulePass::runOnModule
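
For context, a minimal standalone sketch of the enablement logic this change introduces: after the patch, the V1-compat pass runs when either the per-session ConfigProto setting (experimental().enable_mlir_bridge()) or the process-wide flag from tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge is set, and Run() consults IsEnabled() instead of reading the session config directly. The types below (SessionConfig, MlirCommonFlags) are stand-ins for illustration only, not the actual TensorFlow classes.

    // Sketch: models "session flag OR global flag" as added in IsEnabled().
    #include <iostream>

    // Stand-in for ConfigProto.experimental().enable_mlir_bridge().
    struct SessionConfig {
      bool enable_mlir_bridge = false;
    };

    // Stand-in for the flags returned by tensorflow::GetMlirCommonFlags().
    struct MlirCommonFlags {
      bool tf_mlir_enable_mlir_bridge = false;
    };

    MlirCommonFlags* GetMlirCommonFlags() {
      static MlirCommonFlags flags;  // process-wide singleton, like the real flag object
      return &flags;
    }

    // Mirrors the post-change check: enabled if either source turns it on.
    bool IsEnabled(const SessionConfig& config) {
      return config.enable_mlir_bridge ||
             GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
    }

    int main() {
      SessionConfig config;  // session flag left off
      std::cout << "session flag only: " << IsEnabled(config) << "\n";  // prints 0

      GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = true;
      std::cout << "global flag set:   " << IsEnabled(config) << "\n";  // prints 1
      return 0;
    }

The new //tensorflow/compiler/jit:flags dependency in the BUILD file and the tensorflow/compiler/jit/flags.h include in the header exist to make GetMlirCommonFlags() available to the pass.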