Remove dependency on Dialect global registration from //tensorflow/compiler/mlir/lite/...
PiperOrigin-RevId: 328109152 Change-Id: Ia460e89f785e9a2aaf21538083733e7e13730299
This commit is contained in:
parent
227b34875f
commit
051ed1cbfd
@ -760,7 +760,7 @@ tf_cc_binary(
|
||||
deps = [
|
||||
":flatbuffer_translate_registeration",
|
||||
# TODO(b/155809683): Link only necessary dialects.
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
],
|
||||
)
|
||||
|
||||
@ -812,7 +812,7 @@ tf_cc_binary(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
# TODO(b/155809683): Link only necessary dialects.
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
@ -836,19 +836,18 @@ tf_cc_binary(
|
||||
deps = [
|
||||
":flatbuffer_translate_lib",
|
||||
":flatbuffer_translate_registeration",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
# TODO(b/155809683): Link only necessary dialects.
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Support",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
)
|
||||
|
||||
@ -875,7 +874,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
@ -909,7 +908,7 @@ cc_library(
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
@ -30,12 +30,16 @@ limitations under the License.
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
@ -98,7 +102,10 @@ int main(int argc, char** argv) {
|
||||
|
||||
// Load the MLIR module.
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
context.getDialectRegistry()
|
||||
.insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
|
||||
mlir::StandardOpsDialect>();
|
||||
|
||||
llvm::SourceMgr source_mgr;
|
||||
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
|
||||
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
|
||||
|
@ -49,7 +49,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
const GraphDef& input,
|
||||
string* result) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
GraphImportConfig specs;
|
||||
mlir::TFL::QuantizationSpecs quant_specs;
|
||||
|
||||
|
@ -122,7 +122,6 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
string* result) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::TFL::QuantizationSpecs quant_specs;
|
||||
|
||||
// Parse input arrays.
|
||||
|
@ -62,6 +62,10 @@ class ImportQuantStatsPass
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<quant::QuantizationDialect>();
|
||||
}
|
||||
|
||||
// Parses the serialized quant stats protobuf and initialize the internal
|
||||
// data structure. This method must be called after the pass is created.
|
||||
bool ParseQuantStats(const std::string &stats_str);
|
||||
|
@ -28,6 +28,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:common",
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
@ -74,6 +75,6 @@ tf_cc_binary(
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
],
|
||||
)
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||
@ -52,7 +53,7 @@ TfLiteStatus QuantizeModel(
|
||||
}
|
||||
|
||||
MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
|
||||
StatusScopedDiagnosticHandler statusHandler(&context,
|
||||
/*propagate=*/true);
|
||||
|
||||
|
@ -37,7 +37,6 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
|
||||
flatbuffers::FlatBufferBuilder* builder,
|
||||
tflite::ErrorReporter* error_reporter) {
|
||||
MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
StatusScopedDiagnosticHandler statusHandler(&context,
|
||||
/*propagate=*/true);
|
||||
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
@ -84,6 +85,11 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
|
||||
: unfold_batch_matmul_(unfold_batch_matmul) {}
|
||||
void runOnFunction() override;
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
|
||||
TFL::TensorFlowLiteDialect>();
|
||||
}
|
||||
|
||||
private:
|
||||
bool unfold_batch_matmul_;
|
||||
};
|
||||
|
@ -93,8 +93,9 @@ class LstmUtilsTest : public ::testing::Test {
|
||||
LstmUtilsTest() {}
|
||||
|
||||
void SetUp() override {
|
||||
RegisterDialects();
|
||||
context_ = std::make_unique<mlir::MLIRContext>();
|
||||
context_->loadDialect<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
|
||||
TensorFlowLiteDialect>();
|
||||
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
|
||||
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
|
||||
fused_lstm_func_cifg_ =
|
||||
@ -109,12 +110,6 @@ class LstmUtilsTest : public ::testing::Test {
|
||||
builder_.reset();
|
||||
}
|
||||
|
||||
void RegisterDialects() {
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<TensorFlowLiteDialect>();
|
||||
}
|
||||
|
||||
FuncOp fused_lstm_func_;
|
||||
FuncOp fused_lstm_func_cifg_;
|
||||
FuncOp fused_ln_lstm_func_;
|
||||
|
Loading…
x
Reference in New Issue
Block a user