Remove dependency on Dialect global registration from //tensorflow/compiler/mlir/lite/...

PiperOrigin-RevId: 328109152
Change-Id: Ia460e89f785e9a2aaf21538083733e7e13730299
This commit is contained in:
Mehdi Amini 2020-08-24 03:17:29 -07:00 committed by TensorFlower Gardener
parent 227b34875f
commit 051ed1cbfd
10 changed files with 35 additions and 25 deletions

View File

@ -760,7 +760,7 @@ tf_cc_binary(
deps = [
":flatbuffer_translate_registeration",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
],
)
@ -812,7 +812,7 @@ tf_cc_binary(
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -836,19 +836,18 @@ tf_cc_binary(
deps = [
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/lite:framework",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:StandardOps",
],
)
@ -875,7 +874,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:core_cpu_base",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
@ -909,7 +908,7 @@ cc_library(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",

View File

@ -30,12 +30,16 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
@ -98,7 +102,10 @@ int main(int argc, char** argv) {
// Load the MLIR module.
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
context.getDialectRegistry()
.insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
mlir::StandardOpsDialect>();
llvm::SourceMgr source_mgr;
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));

View File

@ -49,7 +49,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const GraphDef& input,
string* result) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
GraphImportConfig specs;
mlir::TFL::QuantizationSpecs quant_specs;

View File

@ -122,7 +122,6 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
string* result) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::TFL::QuantizationSpecs quant_specs;
// Parse input arrays.

View File

@ -62,6 +62,10 @@ class ImportQuantStatsPass
void runOnFunction() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<quant::QuantizationDialect>();
}
// Parses the serialized quant stats protobuf and initialize the internal
// data structure. This method must be called after the pass is created.
bool ParseQuantStats(const std::string &stats_str);

View File

@ -28,6 +28,7 @@ cc_library(
deps = [
"//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/tensorflow:error_util",
@ -74,6 +75,6 @@ tf_cc_binary(
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
],
)

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
@ -52,7 +53,7 @@ TfLiteStatus QuantizeModel(
}
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true);

View File

@ -37,7 +37,6 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true);

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
@ -84,6 +85,11 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
: unfold_batch_matmul_(unfold_batch_matmul) {}
void runOnFunction() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
TFL::TensorFlowLiteDialect>();
}
private:
bool unfold_batch_matmul_;
};

View File

@ -93,8 +93,9 @@ class LstmUtilsTest : public ::testing::Test {
LstmUtilsTest() {}
void SetUp() override {
RegisterDialects();
context_ = std::make_unique<mlir::MLIRContext>();
context_->loadDialect<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
TensorFlowLiteDialect>();
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
fused_lstm_func_cifg_ =
@ -109,12 +110,6 @@ class LstmUtilsTest : public ::testing::Test {
builder_.reset();
}
void RegisterDialects() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<TensorFlowLiteDialect>();
}
FuncOp fused_lstm_func_;
FuncOp fused_lstm_func_cifg_;
FuncOp fused_ln_lstm_func_;