Remove dependency on Dialect global registration from //tensorflow/compiler/mlir/lite/...
PiperOrigin-RevId: 328109152 Change-Id: Ia460e89f785e9a2aaf21538083733e7e13730299
This commit is contained in:
parent
227b34875f
commit
051ed1cbfd
@ -760,7 +760,7 @@ tf_cc_binary(
|
|||||||
deps = [
|
deps = [
|
||||||
":flatbuffer_translate_registeration",
|
":flatbuffer_translate_registeration",
|
||||||
# TODO(b/155809683): Link only necessary dialects.
|
# TODO(b/155809683): Link only necessary dialects.
|
||||||
"@llvm-project//mlir:AllPassesAndDialects",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -812,7 +812,7 @@ tf_cc_binary(
|
|||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
# TODO(b/155809683): Link only necessary dialects.
|
# TODO(b/155809683): Link only necessary dialects.
|
||||||
"@llvm-project//mlir:AllPassesAndDialects",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
@ -836,19 +836,18 @@ tf_cc_binary(
|
|||||||
deps = [
|
deps = [
|
||||||
":flatbuffer_translate_lib",
|
":flatbuffer_translate_lib",
|
||||||
":flatbuffer_translate_registeration",
|
":flatbuffer_translate_registeration",
|
||||||
"@com_google_absl//absl/strings",
|
":tensorflow_lite",
|
||||||
"@llvm-project//llvm:Support",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
# TODO(b/155809683): Link only necessary dialects.
|
|
||||||
"@llvm-project//mlir:AllPassesAndDialects",
|
|
||||||
"@llvm-project//mlir:IR",
|
|
||||||
"@llvm-project//mlir:Parser",
|
|
||||||
"@llvm-project//mlir:Support",
|
|
||||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/delegates/flex:delegate",
|
"//tensorflow/lite/delegates/flex:delegate",
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
"//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:Parser",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -875,7 +874,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||||
"//tensorflow/core:core_cpu_base",
|
"//tensorflow/core:core_cpu_base",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:AllPassesAndDialects",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
@ -909,7 +908,7 @@ cc_library(
|
|||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:AllPassesAndDialects",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Parser",
|
"@llvm-project//mlir:Parser",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
@ -30,12 +30,16 @@ limitations under the License.
|
|||||||
#include "llvm/Support/MemoryBuffer.h"
|
#include "llvm/Support/MemoryBuffer.h"
|
||||||
#include "llvm/Support/SMLoc.h"
|
#include "llvm/Support/SMLoc.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||||
#include "mlir/IR/Function.h" // from @llvm-project
|
#include "mlir/IR/Function.h" // from @llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "mlir/Parser.h" // from @llvm-project
|
#include "mlir/Parser.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
|
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
|
||||||
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||||
@ -98,7 +102,10 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
// Load the MLIR module.
|
// Load the MLIR module.
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
context.loadAllGloballyRegisteredDialects();
|
context.getDialectRegistry()
|
||||||
|
.insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
|
||||||
|
mlir::StandardOpsDialect>();
|
||||||
|
|
||||||
llvm::SourceMgr source_mgr;
|
llvm::SourceMgr source_mgr;
|
||||||
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
|
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
|
||||||
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
|
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
|
||||||
|
@ -49,7 +49,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
const GraphDef& input,
|
const GraphDef& input,
|
||||||
string* result) {
|
string* result) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
context.loadAllGloballyRegisteredDialects();
|
|
||||||
GraphImportConfig specs;
|
GraphImportConfig specs;
|
||||||
mlir::TFL::QuantizationSpecs quant_specs;
|
mlir::TFL::QuantizationSpecs quant_specs;
|
||||||
|
|
||||||
|
@ -122,7 +122,6 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
|||||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||||
string* result) {
|
string* result) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
context.loadAllGloballyRegisteredDialects();
|
|
||||||
mlir::TFL::QuantizationSpecs quant_specs;
|
mlir::TFL::QuantizationSpecs quant_specs;
|
||||||
|
|
||||||
// Parse input arrays.
|
// Parse input arrays.
|
||||||
|
@ -62,6 +62,10 @@ class ImportQuantStatsPass
|
|||||||
|
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
|
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<quant::QuantizationDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
// Parses the serialized quant stats protobuf and initialize the internal
|
// Parses the serialized quant stats protobuf and initialize the internal
|
||||||
// data structure. This method must be called after the pass is created.
|
// data structure. This method must be called after the pass is created.
|
||||||
bool ParseQuantStats(const std::string &stats_str);
|
bool ParseQuantStats(const std::string &stats_str);
|
||||||
|
@ -28,6 +28,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/mlir/lite:common",
|
"//tensorflow/compiler/mlir/lite:common",
|
||||||
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
|
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
|
||||||
|
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
|
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
|
||||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||||
@ -74,6 +75,6 @@ tf_cc_binary(
|
|||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:AllPassesAndDialects",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
|
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
|
||||||
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||||
@ -52,7 +53,7 @@ TfLiteStatus QuantizeModel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
context.loadAllGloballyRegisteredDialects();
|
context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
|
||||||
StatusScopedDiagnosticHandler statusHandler(&context,
|
StatusScopedDiagnosticHandler statusHandler(&context,
|
||||||
/*propagate=*/true);
|
/*propagate=*/true);
|
||||||
|
|
||||||
|
@ -37,7 +37,6 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
|
|||||||
flatbuffers::FlatBufferBuilder* builder,
|
flatbuffers::FlatBufferBuilder* builder,
|
||||||
tflite::ErrorReporter* error_reporter) {
|
tflite::ErrorReporter* error_reporter) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
context.loadAllGloballyRegisteredDialects();
|
|
||||||
StatusScopedDiagnosticHandler statusHandler(&context,
|
StatusScopedDiagnosticHandler statusHandler(&context,
|
||||||
/*propagate=*/true);
|
/*propagate=*/true);
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
|
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
@ -84,6 +85,11 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
|
|||||||
: unfold_batch_matmul_(unfold_batch_matmul) {}
|
: unfold_batch_matmul_(unfold_batch_matmul) {}
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
|
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
|
||||||
|
TFL::TensorFlowLiteDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool unfold_batch_matmul_;
|
bool unfold_batch_matmul_;
|
||||||
};
|
};
|
||||||
|
@ -93,8 +93,9 @@ class LstmUtilsTest : public ::testing::Test {
|
|||||||
LstmUtilsTest() {}
|
LstmUtilsTest() {}
|
||||||
|
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
RegisterDialects();
|
|
||||||
context_ = std::make_unique<mlir::MLIRContext>();
|
context_ = std::make_unique<mlir::MLIRContext>();
|
||||||
|
context_->loadDialect<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
|
||||||
|
TensorFlowLiteDialect>();
|
||||||
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
|
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
|
||||||
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
|
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
|
||||||
fused_lstm_func_cifg_ =
|
fused_lstm_func_cifg_ =
|
||||||
@ -109,12 +110,6 @@ class LstmUtilsTest : public ::testing::Test {
|
|||||||
builder_.reset();
|
builder_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
void RegisterDialects() {
|
|
||||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
|
||||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
|
||||||
mlir::registerDialect<TensorFlowLiteDialect>();
|
|
||||||
}
|
|
||||||
|
|
||||||
FuncOp fused_lstm_func_;
|
FuncOp fused_lstm_func_;
|
||||||
FuncOp fused_lstm_func_cifg_;
|
FuncOp fused_lstm_func_cifg_;
|
||||||
FuncOp fused_ln_lstm_func_;
|
FuncOp fused_ln_lstm_func_;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user