Allow the TF_XLA_FLAGS tf_mlir_enable_mlir_bridge to still enable the experimental mlir bridge
for TensorFlow V1 even when the user overrides Session's default ConfigProto. In TensorFlow V1, when the user doesn't override Session's default ConfigProto, the ConfigProto uses the value from TF_XLA_FLAGS (via context.config). However, when the user passes in their own ConfigProto, enable_mlir_bridge is defaulted to false. PiperOrigin-RevId: 320661288 Change-Id: I5deacb4e12b6551a57c2496c35830ac3af25fe28
This commit is contained in:
parent
3adc7cf2c9
commit
c1921df0e2
@ -704,6 +704,7 @@ cc_library(
|
||||
srcs = ["mlir_bridge_pass.cc"],
|
||||
hdrs = ["mlir_bridge_pass.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
@ -56,7 +56,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
|
||||
// Skip function graphs as MlirBridgePass will be used instead.
|
||||
if (options.is_function_graph) return Status::OK();
|
||||
|
||||
if (!options.session_options->config.experimental().enable_mlir_bridge()) {
|
||||
if (!IsEnabled(options.session_options->config)) {
|
||||
VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled";
|
||||
mlir_bridge_gauge_v1->GetCell()->Set(false);
|
||||
return Status::OK();
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -45,7 +46,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
|
||||
llvm::StringRef name() const override { return "bridge"; }
|
||||
|
||||
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||
return config_proto.experimental().enable_mlir_bridge();
|
||||
return config_proto.experimental().enable_mlir_bridge() ||
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
|
||||
}
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
|
Loading…
Reference in New Issue
Block a user