Let MlirFunctionOptimizationPass query individual registered passes about whether they need to run in shadow mode.

The full MlirFunctionOptimizationPass pipeline is executed in shadow mode only if all of the registered passes request shadow mode execution. In full shadow mode, TF graph round-tripping (TF graph -> MLIR -> TF graph) can only soft-fail: no errors are propagated back to the caller. If there is a mix of shadow enabled and enabled passes, a shadow pass failure is reflected in a recorded metric, while failures in the enabled passes are reported back to the caller. Full shadow mode must not modify the original TF graph: conversion from TF graph to MLIR module, execution of all registered passes, and conversion back to TF graph leave the original TF graph unchanged.

Also removes the explicit disablement of the MLIR bridge for TFR test cases.

PiperOrigin-RevId: 353922802
Change-Id: I20805a5d99670dd8942746282cb68e653f19f9be
This commit is contained in:
parent
8027470e1e
commit
14f5280e39
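To make the shadow-mode rules above concrete, here is a minimal sketch of the state aggregation this change introduces. It is an illustration, not the TensorFlow sources: the `AggregateState` helper is hypothetical, while the enum mirrors the `MlirOptimizationPassState` added in the diff below. Any `Enabled` pass forces real execution (failures propagate to the caller); otherwise any `ShadowEnabled` pass puts the whole pipeline into shadow mode (failures are only recorded as metrics); otherwise the pipeline is skipped entirely.

```cpp
#include <vector>

// Mirrors the enum added in mlir_graph_optimization_pass.h below.
enum class MlirOptimizationPassState { Disabled, Enabled, ShadowEnabled };

// Hypothetical helper showing how MlirFunctionOptimizationPass::Run derives
// the overall pipeline state from the per-pass states it caches up front.
MlirOptimizationPassState AggregateState(
    const std::vector<MlirOptimizationPassState>& per_pass_state) {
  MlirOptimizationPassState overall = MlirOptimizationPassState::Disabled;
  for (MlirOptimizationPassState state : per_pass_state) {
    if (state == MlirOptimizationPassState::Enabled) {
      // One enabled pass is enough: run for real and report errors.
      return MlirOptimizationPassState::Enabled;
    }
    if (state == MlirOptimizationPassState::ShadowEnabled) {
      // Shadow mode, unless some later pass turns out to be Enabled.
      overall = MlirOptimizationPassState::ShadowEnabled;
    }
  }
  return overall;
}
```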
@@ -117,33 +117,44 @@ Status MlirFunctionOptimizationPass::Run(
    std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
    std::vector<std::string>* control_ret_node_names,
    bool* control_rets_updated) {
  // Skip conversion from Graph to MLIR if none of the passes are enabled.
  const bool is_enabled =
      llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
        return pass_registration.pass->IsEnabled(&device_set, config_proto,
                                                 **graph);
      });
  // This tracks whether at least one pass is enabled, all passes are disabled,
  // or there is a mix of disabled and shadow enabled passes.
  MlirOptimizationPassState overall_state = MlirOptimizationPassState::Disabled;

  if (!is_enabled) {
    LOG_FIRST_N(INFO, 1)
        << "None of the MLIR optimization passes are enabled "
        << "(registered " << registry_->passes().size() << ")";
  // Cache per pass state and reuse it during pass execution.
  std::vector<MlirOptimizationPassState> per_pass_state;
  per_pass_state.reserve(registry_->passes().size());

  for (const auto& pass_registration : registry_->passes()) {
    MlirOptimizationPassState pass_state = pass_registration.pass->GetPassState(
        &device_set, config_proto, **graph);
    per_pass_state.push_back(pass_state);
    if (pass_state == MlirOptimizationPassState::ShadowEnabled &&
        overall_state == MlirOptimizationPassState::Disabled) {
      overall_state = MlirOptimizationPassState::ShadowEnabled;
    } else if (pass_state == MlirOptimizationPassState::Enabled) {
      overall_state = MlirOptimizationPassState::Enabled;
    }
  }

  // TODO(b/176852151): Remove this after dark launch completed.
  // Capture stats relevant to graph properties used in dark launch.
  GetMlirBridgeRolloutPolicy(**graph, config_proto, /*record_stats=*/true);

  if (overall_state == MlirOptimizationPassState::Disabled) {
    LOG_FIRST_N(INFO, 1) << "None of the MLIR Optimization Passes are enabled "
                         << "(registered " << registry_->passes().size() << ")";
    return Status::OK();
  }

  LOG_FIRST_N(INFO, 1) << "Running MLIR Graph Optimization Passes "
                       << "(registered " << registry_->passes().size()
                       << " passes)";

  // For scenarios when the new bridge is enabled by analysis we need to make
  // sure that MLIR transformations are executed in a shadow mode.
  // In this case, no changes should be done to the original `graph`
  // and no failures propagated to the user.
  bool enabled_by_analysis =
      mlir_rollout_policy_(**graph, config_proto, /*record_stats=*/true) ==
      MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis;
  if (enabled_by_analysis) {
    LOG_FIRST_N(INFO, 1) << "Shadow run of MLIR enabled after graph analysis";
  if (overall_state == MlirOptimizationPassState::Enabled) {
    LOG_FIRST_N(INFO, 1) << "At least one MLIR Graph Optimization Pass enabled"
                         << "(registered " << registry_->passes().size()
                         << " passes)";
  } else if (overall_state == MlirOptimizationPassState::ShadowEnabled) {
    LOG_FIRST_N(INFO, 1)
        << "All MLIR Graph Optimization Passes are shadow enabled"
        << "(registered " << registry_->passes().size() << " passes)";
  }

  GraphDebugInfo debug_info;
@@ -163,19 +174,21 @@ Status MlirFunctionOptimizationPass::Run(
  auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
                                              import_config, &context);
  if (!module_ref_status.ok()) {
    if (enabled_by_analysis) {
      shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1);

      // Do not fail, let the old bridge run on the original `graph`.
      return Status::OK();
    // If at least one pass is enabled, return failure to the caller
    // immediately.
    if (overall_state == MlirOptimizationPassState::Enabled) {
      return module_ref_status.status();
    }

    return module_ref_status.status();
    shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1);
    // Do not fail, just keep the original TF graph unchanged in shadow mode.
    return Status::OK();
  }

  auto module_ref = std::move(module_ref_status.ValueOrDie());
  AddDevicesToOp(*module_ref, &device_set);

  int per_pass_state_index = 0;
  for (auto& pass_registration : registry_->passes()) {
    llvm::StringRef name = pass_registration.pass->name();
    VLOG(2) << "Run MLIR graph optimization pass: " << StringRefToView(name);
@@ -184,16 +197,36 @@ Status MlirFunctionOptimizationPass::Run(
      DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
    }

    auto pass_status =
        pass_registration.pass->Run(config_proto, *module_ref, **graph);
    if (!pass_status.ok()) {
      if (enabled_by_analysis) {
        shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1);
        // Do not fail, let the old bridge run on the original `graph`.
        return Status::OK();
      }
    Status pass_status = Status::OK();
    auto pass_state = per_pass_state[per_pass_state_index++];
    // There will not be an MLIR module conversion back to the TF graph at the
    // very end if the overall state is ShadowEnabled.
    // Avoid making MLIR module copies in this case.
    if (pass_state == MlirOptimizationPassState::Enabled ||
        (pass_state == MlirOptimizationPassState::ShadowEnabled &&
         overall_state == MlirOptimizationPassState::ShadowEnabled)) {
      pass_status =
          pass_registration.pass->Run(config_proto, *module_ref, **graph);
    } else if (pass_state == MlirOptimizationPassState::ShadowEnabled) {
      // Make sure that the pass does not modify the MLIR module if it's
      // shadow enabled.
      auto module_ref_clone = module_ref->clone();
      pass_status =
          pass_registration.pass->Run(config_proto, module_ref_clone, **graph);
      module_ref_clone->destroy();
    }

    return pass_status;
    if (!pass_status.ok()) {
      // If the pass failed and the pass is:
      //   ShadowEnabled - only collect metrics, do not propagate the
      //                   error to the caller.
      //   Enabled       - return the error back to the caller.
      if (pass_state == MlirOptimizationPassState::ShadowEnabled) {
        shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1);
      } else if (pass_state == MlirOptimizationPassState::Enabled) {
        return pass_status;
      }
    }

    if (VLOG_IS_ON(1)) {
@@ -204,9 +237,9 @@ Status MlirFunctionOptimizationPass::Run(
  GraphExportConfig export_config;
  absl::flat_hash_set<Node*> control_ret_nodes;

  // In case MLIR is enabled by analysis, verify that MLIR could be converted
  // back to TF graph. Original `graph` must stay the same.
  if (enabled_by_analysis) {
  // All passes are shadow enabled. Just convert MLIR module back to
  // the dummy graph and record success/failure stats.
  if (overall_state == MlirOptimizationPassState::ShadowEnabled) {
    auto empty_graph = std::make_unique<Graph>(OpRegistry::Global());
    FunctionLibraryDefinition empty_flib = empty_graph->flib_def();

@@ -222,6 +255,8 @@ Status MlirFunctionOptimizationPass::Run(
    return Status::OK();
  }

  // Some or all passes are enabled. Convert the MLIR module and return the
  // resulting graph.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
                         &control_ret_nodes),
@@ -29,6 +29,8 @@ namespace tensorflow {
// MLIR passes running on Tensorflow function graphs (Tensorflow V2).
// -------------------------------------------------------------------------- //

enum class MlirOptimizationPassState { Disabled, Enabled, ShadowEnabled };

// An API for registering MLIR ModulePass with the Tensorflow runtime. These
// passes are running only for function graphs built by Tensorflow V2 and
// instantiated by the process_function_library_runtime (see
@@ -38,12 +40,19 @@ class MlirOptimizationPass {
  virtual ~MlirOptimizationPass() = default;
  virtual llvm::StringRef name() const = 0;

  // Returns true if the pass is enabled for the given graph with specified
  // config. `device_set` can be nullptr if the devices information is not
  // Returns an enum value:
  //   Enabled if the pass is enabled for the given graph with specified config.
  //   Disabled if the pass is disabled.
  //   ShadowEnabled if the pass needs to be executed in shadow mode.
  //
  // When the pass is ShadowEnabled, the pass is executed for metrics collection
  // and reporting purposes only, but none of the changes it makes to the MLIR
  // module will be committed.
  // `device_set` can be nullptr if the devices information is not
  // available or no device specific filtering is required.
  virtual bool IsEnabled(const DeviceSet* device_set,
                         const ConfigProto& config_proto,
                         const Graph& graph) const = 0;
  virtual MlirOptimizationPassState GetPassState(
      const DeviceSet* device_set, const ConfigProto& config_proto,
      const Graph& graph) const = 0;

  virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
                     const Graph& graph) = 0;
@@ -88,12 +97,23 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
 public:
  explicit MlirFunctionOptimizationPass(
      const MlirOptimizationPassRegistry* registry =
          &MlirOptimizationPassRegistry::Global(),
      std::function<MlirBridgeRolloutPolicy(
          const Graph&, absl::optional<ConfigProto>, bool record_stats)>
          mlir_rollout_policy = GetMlirBridgeRolloutPolicy)
      : registry_(registry), mlir_rollout_policy_(mlir_rollout_policy) {}
          &MlirOptimizationPassRegistry::Global())
      : registry_(registry) {}

  // Executes all of the underlying registered MlirOptimizationPasses.
  //
  // The MlirFunctionOptimizationPass will be executed in full shadow mode if
  // all of the underlying registered MlirOptimizationPasses are ShadowEnabled.
  // In this case, no changes should be done to the original TF graph and no
  // failures propagated back to the user. Failures during the conversion
  // of TF graph to MLIR module and back will be treated as soft
  // failures, e.g., relevant stats will be recorded and no error returned
  // back to the caller.
  //
  // In case some of the passes are shadow enabled while others are enabled,
  // failures in the enabled passes will be treated as real errors and
  // propagated back to the caller. Failure during the shadow pass execution
  // is a soft failure.
  Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
             std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
             std::vector<std::string>* control_ret_node_names,
@@ -101,9 +121,6 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass {

 private:
  const MlirOptimizationPassRegistry* registry_;
  std::function<MlirBridgeRolloutPolicy(
      const tensorflow::Graph&, absl::optional<tensorflow::ConfigProto>, bool)>
      mlir_rollout_policy_;
};

// -------------------------------------------------------------------------- //
@@ -31,32 +31,32 @@ class MockMlirOptimizationPass : public MlirOptimizationPass {
  // MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX
  // instead.
  MOCK_CONST_METHOD0(name, llvm::StringRef());
  MOCK_CONST_METHOD3(IsEnabled,
                     bool(const DeviceSet* device_set,
                          const ConfigProto& config_proto, const Graph& graph));
  MOCK_CONST_METHOD3(GetPassState,
                     MlirOptimizationPassState(const DeviceSet* device_set,
                                               const ConfigProto& config_proto,
                                               const Graph& graph));
  MOCK_METHOD3(Run, Status(const ConfigProto& config_proto,
                           mlir::ModuleOp module, const Graph& graph));
};

class MlirGraphOptimizationPassTest : public Test {
 public:
  void Init(MlirBridgeRolloutPolicy rollout_policy, Status pass_run_result) {
  void Init(Status pass_run_result,
            const std::vector<MlirOptimizationPassState>& pass_states) {
    graph_ = std::make_unique<Graph>(OpRegistry::Global());

    function_optimization_pass_ = MlirFunctionOptimizationPass(
        &MlirOptimizationPassRegistry::Global(),
        [rollout_policy](const Graph&, absl::optional<ConfigProto>, bool) {
          return rollout_policy;
        });
    int pass_priority = 0;
    for (const MlirOptimizationPassState& pass_state : pass_states) {
      auto optimization_pass =
          std::make_unique<NiceMock<MockMlirOptimizationPass>>();

    auto optimization_pass =
        std::make_unique<NiceMock<MockMlirOptimizationPass>>();

    EXPECT_CALL(*optimization_pass, IsEnabled(_, _, _))
        .WillRepeatedly(Return(true));
    EXPECT_CALL(*optimization_pass, Run(_, _, _))
        .WillOnce(Return(pass_run_result));
    MlirOptimizationPassRegistry::Global().Add(0, std::move(optimization_pass));
      ON_CALL(*optimization_pass, GetPassState(_, _, _))
          .WillByDefault(Return(pass_state));
      ON_CALL(*optimization_pass, Run(_, _, _))
          .WillByDefault(Return(pass_run_result));
      MlirOptimizationPassRegistry::Global().Add(pass_priority++,
                                                 std::move(optimization_pass));
    }

    flib_.reset(new FunctionLibraryDefinition(graph_->flib_def()));
  }
@@ -65,6 +65,17 @@ class MlirGraphOptimizationPassTest : public Test {
    MlirOptimizationPassRegistry::Global().ClearPasses();
  }

  void verifyGraphUnchanged(const GraphDef& original_graph_def) {
    // Proto matchers might be unavailable in the OSS.
#if defined(PLATFORM_GOOGLE)
    GraphDef resulted_graph_def;
    graph_->ToGraphDef(&resulted_graph_def);
    EXPECT_THAT(resulted_graph_def,
                ::testing::proto::IgnoringRepeatedFieldOrdering(
                    ::testing::EquivToProto(original_graph_def)));
#endif
  }

  ConfigProto config_proto_;
  MlirFunctionOptimizationPass function_optimization_pass_;
  DeviceSet device_set_;
@@ -75,8 +86,8 @@ class MlirGraphOptimizationPassTest : public Test {
};

TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) {
  Init(MlirBridgeRolloutPolicy::kEnabledByUser,
       Status(error::Code::ABORTED, "aborted"));
  Init(Status(error::Code::ABORTED, "aborted"),
       {MlirOptimizationPassState::Enabled});

  GraphDef original_graph_def;
  graph_->ToGraphDef(&original_graph_def);
@@ -85,20 +96,13 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) {
                device_set_, config_proto_, &graph_, flib_.get(),
                &control_ret_node_names_, &control_rets_updated_),
            Status(error::Code::ABORTED, "aborted"));

  // Proto matchers might be unavailable.
#if defined(PLATFORM_GOOGLE)
  GraphDef resulted_graph_def;
  graph_->ToGraphDef(&resulted_graph_def);
  EXPECT_THAT(resulted_graph_def,
              ::testing::proto::IgnoringRepeatedFieldOrdering(
                  ::testing::EquivToProto(original_graph_def)));
#endif
  verifyGraphUnchanged(original_graph_def);
}

TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
  Init(MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis,
       Status(error::Code::ABORTED, "aborted"));
  Init(Status(error::Code::ABORTED, "aborted"),
       {MlirOptimizationPassState::ShadowEnabled,
        MlirOptimizationPassState::ShadowEnabled});

  GraphDef original_graph_def;
  graph_->ToGraphDef(&original_graph_def);
@@ -107,15 +111,36 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
                device_set_, config_proto_, &graph_, flib_.get(),
                &control_ret_node_names_, &control_rets_updated_),
            Status::OK());
  verifyGraphUnchanged(original_graph_def);
}

  // Proto matchers might be unavailable.
#if defined(PLATFORM_GOOGLE)
  GraphDef resulted_graph_def;
  graph_->ToGraphDef(&resulted_graph_def);
  EXPECT_THAT(resulted_graph_def,
              ::testing::proto::IgnoringRepeatedFieldOrdering(
                  ::testing::EquivToProto(original_graph_def)));
#endif
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) {
  Init(Status::OK(), {MlirOptimizationPassState::Disabled,
                      MlirOptimizationPassState::ShadowEnabled});

  GraphDef original_graph_def;
  graph_->ToGraphDef(&original_graph_def);

  EXPECT_EQ(function_optimization_pass_.Run(
                device_set_, config_proto_, &graph_, flib_.get(),
                &control_ret_node_names_, &control_rets_updated_),
            Status::OK());
  verifyGraphUnchanged(original_graph_def);
}

TEST_F(MlirGraphOptimizationPassTest,
       OptimizationPassFailsMixShadowAndEnabled) {
  Init(Status(error::Code::ABORTED, "aborted"),
       {MlirOptimizationPassState::Disabled, MlirOptimizationPassState::Enabled,
        MlirOptimizationPassState::ShadowEnabled});

  GraphDef original_graph_def;
  graph_->ToGraphDef(&original_graph_def);

  EXPECT_EQ(function_optimization_pass_.Run(
                device_set_, config_proto_, &graph_, flib_.get(),
                &control_ret_node_names_, &control_rets_updated_),
            Status(error::Code::ABORTED, "aborted"));
}

TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriority) {
@@ -27,10 +27,13 @@ class MlirGraphOptimizationPass : public ::tensorflow::MlirOptimizationPass {
 public:
  llvm::StringRef name() const override { return "graph_optimization"; }

  bool IsEnabled(const ::tensorflow::DeviceSet* device_set,
                 const ::tensorflow::ConfigProto& config_proto,
                 const tensorflow::Graph& graph) const override {
    return config_proto.experimental().enable_mlir_graph_optimization();
  ::tensorflow::MlirOptimizationPassState GetPassState(
      const ::tensorflow::DeviceSet* device_set,
      const ::tensorflow::ConfigProto& config_proto,
      const tensorflow::Graph& graph) const override {
    return config_proto.experimental().enable_mlir_graph_optimization()
               ? tensorflow::MlirOptimizationPassState::Enabled
               : tensorflow::MlirOptimizationPassState::Disabled;
  }

  ::tensorflow::Status Run(const ::tensorflow::ConfigProto& config_proto,
@@ -254,7 +254,6 @@ tf_py_test(
    ],
    deps = [
        "//tensorflow/compiler/mlir/tfr/resources:composite_ops",
        "//tensorflow/python:is_mlir_bridge_test_false",
        "//tensorflow/python/eager:def_function",
    ],
)
@@ -39,7 +39,6 @@ tf_py_test(
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/compiler/mlir/tfr:test_utils",
        "//tensorflow/python:is_mlir_bridge_test_false",
        "//tensorflow/python:test_ops",
    ],
)
@@ -42,7 +42,6 @@ tf_py_test(
        ":mnist_ops_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/compiler/mlir/tfr:test_utils",
        "//tensorflow/python:is_mlir_bridge_test_false",
    ],
)
@@ -41,6 +41,5 @@ tf_py_test(
        ":pad_ops_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/compiler/mlir/tfr:test_utils",
        "//tensorflow/python:is_mlir_bridge_test_false",
    ],
)
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h"

#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@@ -29,16 +30,18 @@ auto* tf_core_op_expansion_graph_counter =

namespace tfr {

bool GraphDecomposePass::IsEnabled(const DeviceSet* device_set,
                                   const ConfigProto& config_proto,
                                   const Graph& graph) const {
MlirOptimizationPassState GraphDecomposePass::GetPassState(
    const DeviceSet* device_set, const ConfigProto& config_proto,
    const Graph& graph) const {
  const char* tfr_lib_env_val = getenv(std::string(kTFRLibEnv).c_str());
  return tfr_lib_env_val != nullptr;
  return tfr_lib_env_val != nullptr ? MlirOptimizationPassState::Enabled
                                    : MlirOptimizationPassState::Disabled;
}

Status GraphDecomposePass::Run(const ConfigProto& config_proto,
                               mlir::ModuleOp module, const Graph& graph) {
  if (!IsEnabled(/*device_set=*/nullptr, config_proto, graph)) {
  if (GetPassState(/*device_set=*/nullptr, config_proto, graph) ==
      MlirOptimizationPassState::Disabled) {
    LOG_FIRST_N(INFO, 1) << "Skipping Graph Decomposition Pass, decomposition"
                            " library was not found";
    return Status::OK();
@@ -33,8 +33,9 @@ class GraphDecomposePass : public MlirOptimizationPass {

  // Whether to run this pass. If this is enabled, the GraphDef will be imported
  // to MLIR even if no tf composition file is found.
  bool IsEnabled(const DeviceSet* device_set, const ConfigProto& config_proto,
                 const Graph& graph) const override;
  ::tensorflow::MlirOptimizationPassState GetPassState(
      const DeviceSet* device_set, const ConfigProto& config_proto,
      const Graph& graph) const override;

  // This should be used as a thin mapper around mlir::ModulePass::runOnModule
  // API integrated with the Tensorflow runtime.
@@ -66,21 +66,23 @@ bool HasTPUDevice(const DeviceSet& device_set) {
//
// The config_proto param is a required input for all TF1 graphs but it is
// redundant for TF2 graphs.
bool IsMlirBridgePassEnabled(const Graph& graph,
                             const absl::optional<ConfigProto>& config_proto) {
MlirOptimizationPassState MlirBridgePass::GetPassState(
    const DeviceSet* device_set, const ConfigProto& config_proto,
    const Graph& graph) const {
  // Skip MLIR TPU Bridge if no TPU devices found.
  if (device_set && !HasTPUDevice(*device_set)) {
    return MlirOptimizationPassState::Disabled;
  }

  MlirBridgeRolloutPolicy policy =
      GetMlirBridgeRolloutPolicy(graph, config_proto);
  return (policy == MlirBridgeRolloutPolicy::kEnabledByUser ||
          policy == MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis);
}

bool MlirBridgePass::IsEnabled(const DeviceSet* device_set,
                               const ConfigProto& config_proto,
                               const Graph& graph) const {
  // Skip MLIR TPU Bridge if no TPU devices found.
  if (device_set && !HasTPUDevice(*device_set)) return false;

  return IsMlirBridgePassEnabled(graph, config_proto);
  if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
    return MlirOptimizationPassState::Enabled;
  } else if (policy == MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis) {
    return MlirOptimizationPassState::ShadowEnabled;
  } else {
    return MlirOptimizationPassState::Disabled;
  }
}

// This runs the first phase of the "bridge", transforming the graph in a form
@@ -93,7 +95,8 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
                           mlir::ModuleOp module, const Graph& graph) {
  // Set device_set to nullptr here as the device specific checks are performed
  // based on the devices in the module.
  if (!IsEnabled(/*device_set=*/nullptr, config_proto, graph)) {
  if (GetPassState(/*device_set=*/nullptr, config_proto, graph) ==
      MlirOptimizationPassState::Disabled) {
    VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled";
    mlir_bridge_gauge_v2->GetCell()->Set(false);
    return Status::OK();
@@ -23,8 +23,6 @@ limitations under the License.

namespace tensorflow {

bool IsMlirBridgePassEnabled(const Graph& graph,
                             const absl::optional<ConfigProto>& config_proto);
// This pass uses MLIR to implement all the conversion steps to target XLA from
// a TensorFlow Function Graph. It is meant to expose a very limited set of
// functionalities during the bring-up of MLIR-based bridge.
@@ -32,8 +30,9 @@ class MlirBridgePass : public MlirOptimizationPass {
 public:
  llvm::StringRef name() const override { return "bridge"; }

  bool IsEnabled(const DeviceSet* device_set, const ConfigProto& config_proto,
                 const Graph& graph) const override;
  MlirOptimizationPassState GetPassState(const DeviceSet* device_set,
                                         const ConfigProto& config_proto,
                                         const Graph& graph) const override;

  // This should be used as a thin mapper around mlir::ModulePass::runOnModule
  // API integrated with the Tensorflow runtime.