Support of shadow runs for MLIR TF bridge.
When MLIR bridge is only enabled by graph analysis, MLIR passes are executed in shadow mode and must not affect the original TF graph. Let TF graph to MLIR conversion, MLIR passes and MLIR to TF graph conversion run, but do not return failures in shadow mode, just capture stats in those cases. PiperOrigin-RevId: 341745436 Change-Id: I7a23c122955bf408f3757989b646a78bfa17a0e9
This commit is contained in:
parent
faf44b5391
commit
9e5339f2a9
@ -3,7 +3,11 @@
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_binary",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -126,12 +130,14 @@ cc_library(
|
||||
srcs = ["mlir_graph_optimization_pass.cc"],
|
||||
hdrs = ["mlir_graph_optimization_pass.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -204,6 +210,16 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "mlir_graph_optimization_pass_test",
|
||||
srcs = ["mlir_graph_optimization_pass_test.cc"],
|
||||
deps = [
|
||||
":mlir_graph_optimization_pass",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "litfiles",
|
||||
srcs = glob(["runlit*py"]),
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
@ -32,10 +33,19 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
auto* shadow_run_success = monitoring::Counter<0>::New(
|
||||
"/tensorflow/mlir/shadow_run_success", "Success count of MLIR shadow runs");
|
||||
|
||||
auto* shadow_run_failure = monitoring::Counter<2>::New(
|
||||
"/tensorflow/mlir/shadow_run_failure", "Failure count of MLIR shadow runs",
|
||||
"kind", "name");
|
||||
|
||||
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
return {ref.data(), ref.size()};
|
||||
}
|
||||
@ -123,6 +133,17 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
<< "(registered " << registry_->passes().size()
|
||||
<< " passes)";
|
||||
|
||||
// For scenarios when the new bridge is enabled by analysis we need to make
|
||||
// sure that MLIR transformations are executed in a shadow mode.
|
||||
// In this case, no changes should be done to the original `graph`
|
||||
// and no failures propagated to the user.
|
||||
bool enabled_by_analysis =
|
||||
mlir_rollout_policy_(**graph, config_proto) ==
|
||||
MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis;
|
||||
if (enabled_by_analysis) {
|
||||
LOG_FIRST_N(INFO, 1) << "Shadow run of MLIR enabled after graph analysis";
|
||||
}
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
mlir::MLIRContext context;
|
||||
RegisterDialects(context.getDialectRegistry());
|
||||
@ -130,10 +151,21 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
import_config.graph_as_function = true;
|
||||
import_config.control_outputs = *control_ret_node_names;
|
||||
import_config.upgrade_legacy = true;
|
||||
TF_ASSIGN_OR_RETURN(auto module_ref,
|
||||
ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||
import_config, &context));
|
||||
|
||||
auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||
import_config, &context);
|
||||
if (!module_ref_status.ok()) {
|
||||
if (enabled_by_analysis) {
|
||||
shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1);
|
||||
|
||||
// Do not fail, let the old bridge to run on the original `graph`.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return module_ref_status.status();
|
||||
}
|
||||
|
||||
auto module_ref = std::move(module_ref_status.ValueOrDie());
|
||||
AddDevicesToOp(*module_ref, &device_set);
|
||||
|
||||
for (auto& pass_registration : registry_->passes()) {
|
||||
@ -144,8 +176,17 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
pass_registration.pass->Run(config_proto, *module_ref, **graph));
|
||||
auto pass_status =
|
||||
pass_registration.pass->Run(config_proto, *module_ref, **graph);
|
||||
if (!pass_status.ok()) {
|
||||
if (enabled_by_analysis) {
|
||||
shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1);
|
||||
// Do not fail, let the old bridge to run on the original `graph`.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return pass_status;
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
|
||||
@ -154,6 +195,25 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
|
||||
GraphExportConfig export_config;
|
||||
absl::flat_hash_set<Node*> control_ret_nodes;
|
||||
|
||||
// In case MLIR is enabled by analysis, verify that MLIR could be converted
|
||||
// back to TF graph. Original `graph` must stay the same.
|
||||
if (enabled_by_analysis) {
|
||||
auto empty_graph = std::make_unique<Graph>(OpRegistry::Global());
|
||||
FunctionLibraryDefinition empty_flib = empty_graph->flib_def();
|
||||
|
||||
auto mlir_to_graph_status =
|
||||
ConvertMlirToGraph(*module_ref, export_config, &empty_graph,
|
||||
&empty_flib, &control_ret_nodes);
|
||||
if (mlir_to_graph_status.ok()) {
|
||||
shadow_run_success->GetCell()->IncrementBy(1);
|
||||
} else {
|
||||
shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
|
||||
&control_ret_nodes),
|
||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
@ -60,10 +63,14 @@ class MlirOptimizationPassRegistry {
|
||||
// Returns the global registry of MLIR optimization passes.
|
||||
static MlirOptimizationPassRegistry& Global();
|
||||
|
||||
// Register optimization `pass` with the given `priority`.
|
||||
void Add(int priority, std::unique_ptr<MlirOptimizationPass> pass) {
|
||||
passes_.insert({priority, std::move(pass)});
|
||||
}
|
||||
|
||||
// Free the memory allocated for all passes.
|
||||
void ClearPasses() { passes_.clear(); }
|
||||
|
||||
const Passes& passes() const { return passes_; }
|
||||
|
||||
private:
|
||||
@ -76,8 +83,11 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
|
||||
public:
|
||||
explicit MlirFunctionOptimizationPass(
|
||||
const MlirOptimizationPassRegistry* registry =
|
||||
&MlirOptimizationPassRegistry::Global())
|
||||
: registry_(registry) {}
|
||||
&MlirOptimizationPassRegistry::Global(),
|
||||
std::function<MlirBridgeRolloutPolicy(const Graph& graph,
|
||||
absl::optional<ConfigProto>)>
|
||||
mlir_rollout_policy = GetMlirBridgeRolloutPolicy)
|
||||
: registry_(registry), mlir_rollout_policy_(mlir_rollout_policy) {}
|
||||
|
||||
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
|
||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||
@ -86,6 +96,9 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
|
||||
|
||||
private:
|
||||
const MlirOptimizationPassRegistry* registry_;
|
||||
std::function<MlirBridgeRolloutPolicy(
|
||||
const tensorflow::Graph& graph, absl::optional<tensorflow::ConfigProto>)>
|
||||
mlir_rollout_policy_;
|
||||
};
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
|
121
tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc
Normal file
121
tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc
Normal file
@ -0,0 +1,121 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using ::testing::_;
|
||||
using ::testing::NiceMock;
|
||||
using ::testing::Return;
|
||||
using ::testing::Test;
|
||||
|
||||
class MockMlirOptimizationPass : public MlirOptimizationPass {
|
||||
public:
|
||||
MOCK_METHOD(llvm::StringRef, name, (), (const, override));
|
||||
MOCK_METHOD(bool, IsEnabled,
|
||||
(const ConfigProto& config_proto, const Graph& graph),
|
||||
(const override));
|
||||
MOCK_METHOD(Status, Run,
|
||||
(const ConfigProto& config_proto, mlir::ModuleOp module,
|
||||
const Graph& graph),
|
||||
(override));
|
||||
};
|
||||
|
||||
class MlirGraphOptimizationPassTest : public Test {
|
||||
public:
|
||||
void Init(MlirBridgeRolloutPolicy rollout_policy, Status pass_run_result) {
|
||||
graph_ = std::make_unique<Graph>(OpRegistry::Global());
|
||||
|
||||
function_optimization_pass_ = MlirFunctionOptimizationPass(
|
||||
&MlirOptimizationPassRegistry::Global(),
|
||||
[rollout_policy](const Graph& graph, absl::optional<ConfigProto>) {
|
||||
return rollout_policy;
|
||||
});
|
||||
|
||||
auto optimization_pass =
|
||||
std::make_unique<NiceMock<MockMlirOptimizationPass>>();
|
||||
|
||||
EXPECT_CALL(*optimization_pass, IsEnabled(_, _))
|
||||
.WillRepeatedly(Return(true));
|
||||
EXPECT_CALL(*optimization_pass, Run(_, _, _))
|
||||
.WillOnce(Return(pass_run_result));
|
||||
MlirOptimizationPassRegistry::Global().Add(0, std::move(optimization_pass));
|
||||
|
||||
flib_.reset(new FunctionLibraryDefinition(graph_->flib_def()));
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
MlirOptimizationPassRegistry::Global().ClearPasses();
|
||||
}
|
||||
|
||||
ConfigProto config_proto_;
|
||||
MlirFunctionOptimizationPass function_optimization_pass_;
|
||||
DeviceSet device_set_;
|
||||
std::unique_ptr<Graph> graph_;
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_;
|
||||
std::vector<std::string> control_ret_node_names_;
|
||||
bool control_rets_updated_{false};
|
||||
};
|
||||
|
||||
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) {
|
||||
Init(MlirBridgeRolloutPolicy::kEnabledByUser,
|
||||
Status(error::Code::ABORTED, "aborted"));
|
||||
|
||||
GraphDef original_graph_def;
|
||||
graph_->ToGraphDef(&original_graph_def);
|
||||
|
||||
EXPECT_EQ(function_optimization_pass_.Run(
|
||||
device_set_, config_proto_, &graph_, flib_.get(),
|
||||
&control_ret_node_names_, &control_rets_updated_),
|
||||
Status(error::Code::ABORTED, "aborted"));
|
||||
|
||||
// Proto matchers might be unavailable.
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
GraphDef resulted_graph_def;
|
||||
graph_->ToGraphDef(&resulted_graph_def);
|
||||
EXPECT_THAT(resulted_graph_def,
|
||||
::testing::proto::IgnoringRepeatedFieldOrdering(
|
||||
::testing::EquivToProto(original_graph_def)));
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
|
||||
Init(MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis,
|
||||
Status(error::Code::ABORTED, "aborted"));
|
||||
|
||||
GraphDef original_graph_def;
|
||||
graph_->ToGraphDef(&original_graph_def);
|
||||
|
||||
EXPECT_EQ(function_optimization_pass_.Run(
|
||||
device_set_, config_proto_, &graph_, flib_.get(),
|
||||
&control_ret_node_names_, &control_rets_updated_),
|
||||
Status::OK());
|
||||
|
||||
// Proto matchers might be unavailable.
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
GraphDef resulted_graph_def;
|
||||
graph_->ToGraphDef(&resulted_graph_def);
|
||||
EXPECT_THAT(resulted_graph_def,
|
||||
::testing::proto::IgnoringRepeatedFieldOrdering(
|
||||
::testing::EquivToProto(original_graph_def)));
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||
#include <gmock/gmock-generated-matchers.h>
|
||||
#include <gmock/gmock-matchers.h>
|
||||
#include <gmock/gmock-more-matchers.h>
|
||||
#include <gmock/gmock.h>
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
Loading…
Reference in New Issue
Block a user