diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 133d0e7b387..498ca4b59d9 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -120,7 +120,8 @@ Status MlirFunctionOptimizationPass::Run( // Skip conversion from Graph to MLIR if none of the passes are enabled. const bool is_enabled = llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool { - return pass_registration.pass->IsEnabled(config_proto, **graph); + return pass_registration.pass->IsEnabled(&device_set, config_proto, + **graph); }); if (!is_enabled) { @@ -251,7 +252,8 @@ Status MlirV1CompatGraphOptimizationPass::Run( const bool is_enabled = absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool { return pass_registration.pass->IsEnabled( - options.session_options->config, **options.graph); + options.device_set, options.session_options->config, + **options.graph); }); if (!is_enabled) { diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h index 02a0aaf8629..b01f6b57e64 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -37,7 +37,12 @@ class MlirOptimizationPass { public: virtual ~MlirOptimizationPass() = default; virtual llvm::StringRef name() const = 0; - virtual bool IsEnabled(const ConfigProto& config_proto, + + // Returns true if the pass is enabled for the given graph with specified + // config. `device_set` can be nullptr if the device information is not + // available or no device-specific filtering is required. 
+ virtual bool IsEnabled(const DeviceSet* device_set, + const ConfigProto& config_proto, const Graph& graph) const = 0; virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module, @@ -114,7 +119,12 @@ class MlirV1CompatOptimizationPass { public: virtual ~MlirV1CompatOptimizationPass() = default; virtual llvm::StringRef name() const = 0; - virtual bool IsEnabled(const ConfigProto& config_proto, + + // Returns true if the pass is enabled for the given graph with specified + // config. `device_set` can be nullptr if the device information is not + // available or no device-specific filtering is required. + virtual bool IsEnabled(const DeviceSet* device_set, + const ConfigProto& config_proto, const Graph& graph) const = 0; virtual Status Run(const GraphOptimizationPassOptions& options, diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index adf33922fd3..74992f67532 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -28,11 +28,12 @@ using ::testing::Test; class MockMlirOptimizationPass : public MlirOptimizationPass { public: - // MOCK_METHOD does not work on Windows build, using MOCK_METHODX + // MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX // instead. 
MOCK_CONST_METHOD0(name, llvm::StringRef()); - MOCK_CONST_METHOD2(IsEnabled, - bool(const ConfigProto& config_proto, const Graph& graph)); + MOCK_CONST_METHOD3(IsEnabled, + bool(const DeviceSet* device_set, + const ConfigProto& config_proto, const Graph& graph)); MOCK_METHOD3(Run, Status(const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph)); }; @@ -51,7 +52,7 @@ class MlirGraphOptimizationPassTest : public Test { auto optimization_pass = std::make_unique>(); - EXPECT_CALL(*optimization_pass, IsEnabled(_, _)) + EXPECT_CALL(*optimization_pass, IsEnabled(_, _, _)) .WillRepeatedly(Return(true)); EXPECT_CALL(*optimization_pass, Run(_, _, _)) .WillOnce(Return(pass_run_result)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h index 9272574b7a5..8179e7faab1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h @@ -27,7 +27,8 @@ class MlirGraphOptimizationPass : public ::tensorflow::MlirOptimizationPass { public: llvm::StringRef name() const override { return "graph_optimization"; } - bool IsEnabled(const ::tensorflow::ConfigProto& config_proto, + bool IsEnabled(const ::tensorflow::DeviceSet* device_set, + const ::tensorflow::ConfigProto& config_proto, const tensorflow::Graph& graph) const override { return config_proto.experimental().enable_mlir_graph_optimization(); } diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 7b77700a52f..3a380e41912 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -233,6 +233,7 @@ cc_library( ":tfr_decompose_ctx", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/core:lib", + "//tensorflow/core/common_runtime:device_set", "//tensorflow/stream_executor/lib", "@llvm-project//mlir:IR", ], diff --git 
a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc index 5b7248eb008..5567d6fe55b 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc @@ -29,7 +29,8 @@ auto* tf_core_op_expansion_graph_counter = namespace tfr { -bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto, +bool GraphDecomposePass::IsEnabled(const DeviceSet* device_set, + const ConfigProto& config_proto, const Graph& graph) const { const char* tfr_lib_env_val = getenv(std::string(kTFRLibEnv).c_str()); return tfr_lib_env_val != nullptr; @@ -37,7 +38,7 @@ bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto, Status GraphDecomposePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph) { - if (!IsEnabled(config_proto, graph)) { + if (!IsEnabled(/*device_set=*/nullptr, config_proto, graph)) { LOG_FIRST_N(INFO, 1) << "Skipping Graph Decomposition Pass, decomposition" " library was not found"; return Status::OK(); diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h index 37685f39779..28eb3d939df 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h @@ -33,7 +33,7 @@ class GraphDecomposePass : public MlirOptimizationPass { // Whether to run this pass. If this is enabled, the GraphDef will be imported // to MLIR even no tf composition file is found. 
- bool IsEnabled(const ConfigProto& config_proto, + bool IsEnabled(const DeviceSet* device_set, const ConfigProto& config_proto, const Graph& graph) const override; // This should be used as a thin mapper around mlir::ModulePass::runOnModule diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index d1b83cf6d56..e399eece0e1 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -876,6 +876,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/common_runtime:device_set", "@llvm-project//llvm:Support", ], alwayslink = 1, diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 871e2a77a77..04546d4ceba 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -20,25 +20,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" -namespace { - -// Checks if the module has any TPU devices in its device list. 
-bool HasTPUDevice(mlir::ModuleOp op) { - mlir::TF::RuntimeDevices devices; - if (failed(tensorflow::GetDevicesFromOp(op.getOperation(), &devices))) - return false; - - for (const auto& device : devices.device_names()) { - if (device.has_type && device.type == "TPU") return true; - } - return false; -} -} // namespace - namespace tensorflow { auto* mlir_bridge_gauge_v1 = monitoring::Gauge::New( @@ -48,6 +34,29 @@ auto* mlir_bridge_gauge_v2 = monitoring::Gauge::New( "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2", "Tracks usage of the MLIR-based TF2XLA bridge among TF2 models"); +namespace { + +// Checks if the module has any TPU devices in its device list. +bool HasTPUDevice(mlir::ModuleOp op) { + mlir::TF::RuntimeDevices devices; + if (failed(GetDevicesFromOp(op.getOperation(), &devices))) return false; + + for (const auto& device : devices.device_names()) { + if (device.has_type && device.type == "TPU") return true; + } + return false; +} + +bool HasTPUDevice(const DeviceSet& device_set) { + for (const Device* device : device_set.devices()) { + if (!device) continue; + const DeviceNameUtils::ParsedName& name = device->parsed_name(); + if (name.has_type && name.type == "TPU") return true; + } + return false; +} +} // namespace + // Analyzes the user requested policy as well as the contents of the graph and // determines whether the MLIR Bridge should be run. // @@ -65,6 +74,15 @@ bool IsMlirBridgePassEnabled(const Graph& graph, policy == MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis); } +bool MlirBridgePass::IsEnabled(const DeviceSet* device_set, + const ConfigProto& config_proto, + const Graph& graph) const { + // Skip MLIR TPU Bridge if no TPU devices found. + if (device_set && !HasTPUDevice(*device_set)) return false; + + return IsMlirBridgePassEnabled(graph, config_proto); +} + // This runs the first phase of the "bridge", transforming the graph in a form // that can be executed with delegation of some computations to an accelerator. 
// This builds on the model of XLA where a subset of the graph is encapsulated @@ -73,7 +91,9 @@ bool IsMlirBridgePassEnabled(const Graph& graph, // encapsulated graph to a particular device. Status MlirBridgePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph) { - if (!IsEnabled(config_proto, graph)) { + // Set device_set to nullptr here as the device specific checks are performed + // based on the devices in the module. + if (!IsEnabled(/*device_set=*/nullptr, config_proto, graph)) { VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled"; mlir_bridge_gauge_v2->GetCell()->Set(false); return Status::OK(); @@ -92,12 +112,29 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto, return Status::OK(); } + +bool MlirBridgeV1CompatPass::IsEnabled(const DeviceSet* device_set, + const ConfigProto& config_proto, + const Graph& graph) const { + // Skip MLIR TPU Bridge if no TPU devices found. + if (device_set && !HasTPUDevice(*device_set)) return false; + + // Do not run the bridge if it's enabled by the graph analysis, + // only run if it's enabled by the user explicitly. + MlirBridgeRolloutPolicy policy = + GetMlirBridgeRolloutPolicy(graph, config_proto); + return policy == MlirBridgeRolloutPolicy::kEnabledByUser; +} + Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, mlir::ModuleOp module) { // Skip function graphs as MlirBridgePass will be used instead. if (options.is_function_graph) return Status::OK(); - if (!IsEnabled(options.session_options->config, **options.graph)) { + // Set device_set to nullptr here as the device specific checks are performed + // based on the devices in the module. 
+ if (!IsEnabled(/*device_set=*/nullptr, options.session_options->config, + **options.graph)) { VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled"; mlir_bridge_gauge_v1->GetCell()->Set(false); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index ba3dea924de..cfee112fa69 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -32,10 +32,8 @@ class MlirBridgePass : public MlirOptimizationPass { public: llvm::StringRef name() const override { return "bridge"; } - bool IsEnabled(const ConfigProto& config_proto, - const Graph& graph) const override { - return IsMlirBridgePassEnabled(graph, config_proto); - } + bool IsEnabled(const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph) const override; // This should be used as a thin mapper around mlir::ModulePass::runOnModule // API integrated with the Tensorflow runtime. @@ -50,14 +48,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { public: llvm::StringRef name() const override { return "bridge"; } - bool IsEnabled(const ConfigProto& config_proto, - const Graph& graph) const override { - // Do not run the bridge if it's enabled by the graph analysis, - // only run if it's enabled by the user explicitly. - MlirBridgeRolloutPolicy policy = - GetMlirBridgeRolloutPolicy(graph, config_proto); - return policy == MlirBridgeRolloutPolicy::kEnabledByUser; - } + bool IsEnabled(const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph) const override; // This should be used as a thin mapper around mlir::ModulePass::runOnModule // API integrated with the Tensorflow runtime.