Do not round trip to MLIR in the MLIR bridge if no TPUs are found
Currently, the filtering in the `IsEnabled` method doesn't inspect the devices but at the time of execution MLIR passes are skipped if there are no TPUs. This results in the conversion to MLIR and then back to GraphDef. This roundtrip does functionalization of control flow v1 and some other mutations so it is not idempotent. Doing checks in the `IsEnabled` method will skip the unnecessary roundtrip. PiperOrigin-RevId: 351665309 Change-Id: I94165e99dd1c0166cd5979e6cd32db4ab7b2637f
This commit is contained in:
parent
e8f5a369d7
commit
8f556acfa1
tensorflow/compiler
@ -120,7 +120,8 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
// Skip conversion from Graph to MLIR if none of the passes are enabled.
|
||||
const bool is_enabled =
|
||||
llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
|
||||
return pass_registration.pass->IsEnabled(config_proto, **graph);
|
||||
return pass_registration.pass->IsEnabled(&device_set, config_proto,
|
||||
**graph);
|
||||
});
|
||||
|
||||
if (!is_enabled) {
|
||||
@ -251,7 +252,8 @@ Status MlirV1CompatGraphOptimizationPass::Run(
|
||||
const bool is_enabled =
|
||||
absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
|
||||
return pass_registration.pass->IsEnabled(
|
||||
options.session_options->config, **options.graph);
|
||||
options.device_set, options.session_options->config,
|
||||
**options.graph);
|
||||
});
|
||||
|
||||
if (!is_enabled) {
|
||||
|
@ -37,7 +37,12 @@ class MlirOptimizationPass {
|
||||
public:
|
||||
virtual ~MlirOptimizationPass() = default;
|
||||
virtual llvm::StringRef name() const = 0;
|
||||
virtual bool IsEnabled(const ConfigProto& config_proto,
|
||||
|
||||
// Returns true if the pass is enabled for the given graph with specified
|
||||
// config. `device_set` can be nullptr if the devices information is not
|
||||
// available or no device specific filtering is required.
|
||||
virtual bool IsEnabled(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph) const = 0;
|
||||
|
||||
virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
|
||||
@ -114,7 +119,12 @@ class MlirV1CompatOptimizationPass {
|
||||
public:
|
||||
virtual ~MlirV1CompatOptimizationPass() = default;
|
||||
virtual llvm::StringRef name() const = 0;
|
||||
virtual bool IsEnabled(const ConfigProto& config_proto,
|
||||
|
||||
// Returns true if the pass is enabled for the given graph with specified
|
||||
// config. `device_set` can be nullptr if the devices information is not
|
||||
// available or no device specific filtering is required.
|
||||
virtual bool IsEnabled(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph) const = 0;
|
||||
|
||||
virtual Status Run(const GraphOptimizationPassOptions& options,
|
||||
|
@ -28,11 +28,12 @@ using ::testing::Test;
|
||||
|
||||
class MockMlirOptimizationPass : public MlirOptimizationPass {
|
||||
public:
|
||||
// MOCK_METHOD does not work on Windows build, using MOCK_METHODX
|
||||
// MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX
|
||||
// instead.
|
||||
MOCK_CONST_METHOD0(name, llvm::StringRef());
|
||||
MOCK_CONST_METHOD2(IsEnabled,
|
||||
bool(const ConfigProto& config_proto, const Graph& graph));
|
||||
MOCK_CONST_METHOD3(IsEnabled,
|
||||
bool(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto, const Graph& graph));
|
||||
MOCK_METHOD3(Run, Status(const ConfigProto& config_proto,
|
||||
mlir::ModuleOp module, const Graph& graph));
|
||||
};
|
||||
@ -51,7 +52,7 @@ class MlirGraphOptimizationPassTest : public Test {
|
||||
auto optimization_pass =
|
||||
std::make_unique<NiceMock<MockMlirOptimizationPass>>();
|
||||
|
||||
EXPECT_CALL(*optimization_pass, IsEnabled(_, _))
|
||||
EXPECT_CALL(*optimization_pass, IsEnabled(_, _, _))
|
||||
.WillRepeatedly(Return(true));
|
||||
EXPECT_CALL(*optimization_pass, Run(_, _, _))
|
||||
.WillOnce(Return(pass_run_result));
|
||||
|
@ -27,7 +27,8 @@ class MlirGraphOptimizationPass : public ::tensorflow::MlirOptimizationPass {
|
||||
public:
|
||||
llvm::StringRef name() const override { return "graph_optimization"; }
|
||||
|
||||
bool IsEnabled(const ::tensorflow::ConfigProto& config_proto,
|
||||
bool IsEnabled(const ::tensorflow::DeviceSet* device_set,
|
||||
const ::tensorflow::ConfigProto& config_proto,
|
||||
const tensorflow::Graph& graph) const override {
|
||||
return config_proto.experimental().enable_mlir_graph_optimization();
|
||||
}
|
||||
|
@ -233,6 +233,7 @@ cc_library(
|
||||
":tfr_decompose_ctx",
|
||||
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime:device_set",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
|
@ -29,7 +29,8 @@ auto* tf_core_op_expansion_graph_counter =
|
||||
|
||||
namespace tfr {
|
||||
|
||||
bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto,
|
||||
bool GraphDecomposePass::IsEnabled(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph) const {
|
||||
const char* tfr_lib_env_val = getenv(std::string(kTFRLibEnv).c_str());
|
||||
return tfr_lib_env_val != nullptr;
|
||||
@ -37,7 +38,7 @@ bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto,
|
||||
|
||||
Status GraphDecomposePass::Run(const ConfigProto& config_proto,
|
||||
mlir::ModuleOp module, const Graph& graph) {
|
||||
if (!IsEnabled(config_proto, graph)) {
|
||||
if (!IsEnabled(/*device_set=*/nullptr, config_proto, graph)) {
|
||||
LOG_FIRST_N(INFO, 1) << "Skipping Graph Decomposition Pass, decomposition"
|
||||
" library was not found";
|
||||
return Status::OK();
|
||||
|
@ -33,7 +33,7 @@ class GraphDecomposePass : public MlirOptimizationPass {
|
||||
|
||||
// Whether to run this pass. If this is enabled, the GraphDef will be imported
|
||||
// to MLIR even no tf composition file is found.
|
||||
bool IsEnabled(const ConfigProto& config_proto,
|
||||
bool IsEnabled(const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const override;
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
|
@ -876,6 +876,7 @@ cc_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime:device_set",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -20,25 +20,11 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Checks if the module has any TPU devices in its device list.
|
||||
bool HasTPUDevice(mlir::ModuleOp op) {
|
||||
mlir::TF::RuntimeDevices devices;
|
||||
if (failed(tensorflow::GetDevicesFromOp(op.getOperation(), &devices)))
|
||||
return false;
|
||||
|
||||
for (const auto& device : devices.device_names()) {
|
||||
if (device.has_type && device.type == "TPU") return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
auto* mlir_bridge_gauge_v1 = monitoring::Gauge<bool, 0>::New(
|
||||
@ -48,6 +34,29 @@ auto* mlir_bridge_gauge_v2 = monitoring::Gauge<bool, 0>::New(
|
||||
"/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2",
|
||||
"Tracks usage of the MLIR-based TF2XLA bridge among TF2 models");
|
||||
|
||||
namespace {
|
||||
|
||||
// Checks if the module has any TPU devices in its device list.
|
||||
bool HasTPUDevice(mlir::ModuleOp op) {
|
||||
mlir::TF::RuntimeDevices devices;
|
||||
if (failed(GetDevicesFromOp(op.getOperation(), &devices))) return false;
|
||||
|
||||
for (const auto& device : devices.device_names()) {
|
||||
if (device.has_type && device.type == "TPU") return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool HasTPUDevice(const DeviceSet& device_set) {
|
||||
for (const Device* device : device_set.devices()) {
|
||||
if (!device) continue;
|
||||
const DeviceNameUtils::ParsedName& name = device->parsed_name();
|
||||
if (name.has_type && name.type == "TPU") return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Analyzes the user requested policy as well as the contents of the graph and
|
||||
// determines whether the MLIR Bridge should be run.
|
||||
//
|
||||
@ -65,6 +74,15 @@ bool IsMlirBridgePassEnabled(const Graph& graph,
|
||||
policy == MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis);
|
||||
}
|
||||
|
||||
bool MlirBridgePass::IsEnabled(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph) const {
|
||||
// Skip MLIR TPU Bridge if no TPU devices found.
|
||||
if (device_set && !HasTPUDevice(*device_set)) return false;
|
||||
|
||||
return IsMlirBridgePassEnabled(graph, config_proto);
|
||||
}
|
||||
|
||||
// This runs the first phase of the "bridge", transforming the graph in a form
|
||||
// that can be executed with delegation of some computations to an accelerator.
|
||||
// This builds on the model of XLA where a subset of the graph is encapsulated
|
||||
@ -73,7 +91,9 @@ bool IsMlirBridgePassEnabled(const Graph& graph,
|
||||
// encapsulated graph to a particular device.
|
||||
Status MlirBridgePass::Run(const ConfigProto& config_proto,
|
||||
mlir::ModuleOp module, const Graph& graph) {
|
||||
if (!IsEnabled(config_proto, graph)) {
|
||||
// Set device_set to nullptr here as the device specific checks are performed
|
||||
// based on the devices in the module.
|
||||
if (!IsEnabled(/*device_set=*/nullptr, config_proto, graph)) {
|
||||
VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled";
|
||||
mlir_bridge_gauge_v2->GetCell()->Set(false);
|
||||
return Status::OK();
|
||||
@ -92,12 +112,29 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool MlirBridgeV1CompatPass::IsEnabled(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph) const {
|
||||
// Skip MLIR TPU Bridge if no TPU devices found.
|
||||
if (device_set && !HasTPUDevice(*device_set)) return false;
|
||||
|
||||
// Do not run the bridge if it's enabled by the graph analysis,
|
||||
// only run if it's enabled by the user explicitly.
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(graph, config_proto);
|
||||
return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
|
||||
}
|
||||
|
||||
Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
|
||||
mlir::ModuleOp module) {
|
||||
// Skip function graphs as MlirBridgePass will be used instead.
|
||||
if (options.is_function_graph) return Status::OK();
|
||||
|
||||
if (!IsEnabled(options.session_options->config, **options.graph)) {
|
||||
// Set device_set to nullptr here as the device specific checks are performed
|
||||
// based on the devices in the module.
|
||||
if (!IsEnabled(/*device_set=*/nullptr, options.session_options->config,
|
||||
**options.graph)) {
|
||||
VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled";
|
||||
mlir_bridge_gauge_v1->GetCell()->Set(false);
|
||||
return Status::OK();
|
||||
|
@ -32,10 +32,8 @@ class MlirBridgePass : public MlirOptimizationPass {
|
||||
public:
|
||||
llvm::StringRef name() const override { return "bridge"; }
|
||||
|
||||
bool IsEnabled(const ConfigProto& config_proto,
|
||||
const Graph& graph) const override {
|
||||
return IsMlirBridgePassEnabled(graph, config_proto);
|
||||
}
|
||||
bool IsEnabled(const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const override;
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
// API integrated with the Tensorflow runtime.
|
||||
@ -50,14 +48,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
|
||||
public:
|
||||
llvm::StringRef name() const override { return "bridge"; }
|
||||
|
||||
bool IsEnabled(const ConfigProto& config_proto,
|
||||
const Graph& graph) const override {
|
||||
// Do not run the bridge if it's enabled by the graph analysis,
|
||||
// only run if it's enabled by the user explicitly.
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(graph, config_proto);
|
||||
return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
|
||||
}
|
||||
bool IsEnabled(const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const override;
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
// API integrated with the Tensorflow runtime.
|
||||
|
Loading…
Reference in New Issue
Block a user