Remove dependency on Dialect global registration from //tensorflow/compiler/mlir/lite/...

PiperOrigin-RevId: 328109152
Change-Id: Ia460e89f785e9a2aaf21538083733e7e13730299
This commit is contained in:
Mehdi Amini 2020-08-24 03:17:29 -07:00 committed by TensorFlower Gardener
parent 227b34875f
commit 051ed1cbfd
10 changed files with 35 additions and 25 deletions

View File

@ -760,7 +760,7 @@ tf_cc_binary(
deps = [ deps = [
":flatbuffer_translate_registeration", ":flatbuffer_translate_registeration",
# TODO(b/155809683): Link only necessary dialects. # TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
], ],
) )
@ -812,7 +812,7 @@ tf_cc_binary(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects. # TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -836,19 +836,18 @@ tf_cc_binary(
deps = [ deps = [
":flatbuffer_translate_lib", ":flatbuffer_translate_lib",
":flatbuffer_translate_registeration", ":flatbuffer_translate_registeration",
"@com_google_absl//absl/strings", ":tensorflow_lite",
"@llvm-project//llvm:Support", "//tensorflow/compiler/mlir/tensorflow",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:StandardOps",
], ],
) )
@ -875,7 +874,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:core_cpu_base", "//tensorflow/core:core_cpu_base",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",
@ -909,7 +908,7 @@ cc_library(
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser", "@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",

View File

@ -30,12 +30,16 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h" #include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/delegate.h" #include "tensorflow/lite/delegates/flex/delegate.h"
@ -98,7 +102,10 @@ int main(int argc, char** argv) {
// Load the MLIR module. // Load the MLIR module.
mlir::MLIRContext context; mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects(); context.getDialectRegistry()
.insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
mlir::StandardOpsDialect>();
llvm::SourceMgr source_mgr; llvm::SourceMgr source_mgr;
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context)); mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));

View File

@ -49,7 +49,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const GraphDef& input, const GraphDef& input,
string* result) { string* result) {
mlir::MLIRContext context; mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
GraphImportConfig specs; GraphImportConfig specs;
mlir::TFL::QuantizationSpecs quant_specs; mlir::TFL::QuantizationSpecs quant_specs;

View File

@ -122,7 +122,6 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
string* result) { string* result) {
mlir::MLIRContext context; mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::TFL::QuantizationSpecs quant_specs; mlir::TFL::QuantizationSpecs quant_specs;
// Parse input arrays. // Parse input arrays.

View File

@ -62,6 +62,10 @@ class ImportQuantStatsPass
void runOnFunction() override; void runOnFunction() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<quant::QuantizationDialect>();
}
// Parses the serialized quant stats protobuf and initialize the internal // Parses the serialized quant stats protobuf and initialize the internal
// data structure. This method must be called after the pass is created. // data structure. This method must be called after the pass is created.
bool ParseQuantStats(const std::string &stats_str); bool ParseQuantStats(const std::string &stats_str);

View File

@ -28,6 +28,7 @@ cc_library(
deps = [ deps = [
"//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:error_util",
@ -74,6 +75,6 @@ tf_cc_binary(
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
], ],
) )

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
@ -52,7 +53,7 @@ TfLiteStatus QuantizeModel(
} }
MLIRContext context; MLIRContext context;
context.loadAllGloballyRegisteredDialects(); context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
StatusScopedDiagnosticHandler statusHandler(&context, StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true); /*propagate=*/true);

View File

@ -37,7 +37,6 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
flatbuffers::FlatBufferBuilder* builder, flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter) { tflite::ErrorReporter* error_reporter) {
MLIRContext context; MLIRContext context;
context.loadAllGloballyRegisteredDialects();
StatusScopedDiagnosticHandler statusHandler(&context, StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true); /*propagate=*/true);

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project
@ -84,6 +85,11 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
: unfold_batch_matmul_(unfold_batch_matmul) {} : unfold_batch_matmul_(unfold_batch_matmul) {}
void runOnFunction() override; void runOnFunction() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
TFL::TensorFlowLiteDialect>();
}
private: private:
bool unfold_batch_matmul_; bool unfold_batch_matmul_;
}; };

View File

@ -93,8 +93,9 @@ class LstmUtilsTest : public ::testing::Test {
LstmUtilsTest() {} LstmUtilsTest() {}
void SetUp() override { void SetUp() override {
RegisterDialects();
context_ = std::make_unique<mlir::MLIRContext>(); context_ = std::make_unique<mlir::MLIRContext>();
context_->loadDialect<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
TensorFlowLiteDialect>();
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get())); builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false); fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
fused_lstm_func_cifg_ = fused_lstm_func_cifg_ =
@ -109,12 +110,6 @@ class LstmUtilsTest : public ::testing::Test {
builder_.reset(); builder_.reset();
} }
void RegisterDialects() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<TensorFlowLiteDialect>();
}
FuncOp fused_lstm_func_; FuncOp fused_lstm_func_;
FuncOp fused_lstm_func_cifg_; FuncOp fused_lstm_func_cifg_;
FuncOp fused_ln_lstm_func_; FuncOp fused_ln_lstm_func_;