Add explicit dependent dialects registration in TF MLIR

PiperOrigin-RevId: 328488901
Change-Id: Icb334bd4bf937a7634f6cd708f382026cae27051
This commit is contained in:
Mehdi Amini 2020-08-26 01:34:19 -07:00 committed by TensorFlower Gardener
parent e83b36531b
commit 0d336074b7
20 changed files with 72 additions and 3 deletions

View File

@ -43,6 +43,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
"//tensorflow/core:lib",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:MlirOptLib",

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Threading.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
@ -74,6 +75,10 @@ constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
// Legalize operations in functions.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>();
}
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF&) {}

View File

@ -33,6 +33,10 @@ namespace {
// cond and body regions.
struct LegalizeWhile
: public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<TFL::TensorFlowLiteDialect>();
}
void RunOnFunction(FuncOp func);
void runOnOperation() override {

View File

@ -110,6 +110,10 @@ class ConvertEmbeddedLookupFunc {
class PrepareCompositeFunctionsPass
: public PassWrapper<PrepareCompositeFunctionsPass,
OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<TFL::TensorFlowLiteDialect>();
}
public:
explicit PrepareCompositeFunctionsPass() {}

View File

@ -1318,6 +1318,7 @@ cc_library(
deps = [
":convert_graphdef",
":mlir_roundtrip_flags",
":tensorflow",
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",

View File

@ -115,6 +115,11 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) {
TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
: Dialect(/*name=*/"tf_saved_model", context,
TypeID::get<TensorFlowSavedModelDialect>()) {
// The TensorFlow Dialect is needed in the verifier and other routines
// associated to this dialect. It makes little sense anyway to use the
// SavedModel dialect without the TensorFlow Dialect.
context->loadDialect<TF::TensorFlowDialect>();
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"

View File

@ -39,6 +39,10 @@ namespace {
struct ClusterFormationPass
: public PassWrapper<ClusterFormationPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<tf_device::TensorFlowDeviceDialect>();
}
void runOnFunction() override;
};

View File

@ -615,6 +615,10 @@ class ConvertReduceOpToTfMin : public OpConversionPattern<mhlo::ReduceOp> {
};
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TF::TensorFlowDialect>();
}
public:
LegalizeHloToTf() = default;
LegalizeHloToTf(const LegalizeHloToTf &) {}

View File

@ -39,6 +39,10 @@ namespace {
struct ParallelizeEmbeddingParamsOpsPass
: public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<tf_device::TensorFlowDeviceDialect>();
}
void runOnFunction() override;
};

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
@ -43,6 +44,10 @@ namespace tensorflow {
class GraphOptPass
: public mlir::PassWrapper<GraphOptPass,
mlir::OperationPass<mlir::ModuleOp>> {
void getDependentDialects(mlir::DialectRegistry& registry) const override {
mlir::RegisterAllTensorFlowDialects(registry);
}
public:
explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
: passes_(std::move(passes)) {}

View File

@ -78,6 +78,10 @@ using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
struct TPUClusterFormation
: public TF::PerFunctionAggregateAnalysisConsumerPass<
TPUClusterFormation, TF::ResourceAliasAnalysis> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<tf_device::TensorFlowDeviceDialect>();
}
void runOnFunction(
FuncOp func,
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);

View File

@ -43,6 +43,10 @@ namespace {
class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
BreakUpIslands, TF::SideEffectAnalysis> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<tf_executor::TensorFlowExecutorDialect>();
}
public:
void runOnFunction(FuncOp func,
const TF::SideEffectAnalysis::Info& side_effect_analysis);

View File

@ -144,8 +144,9 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
void LoadImporterDialects(mlir::MLIRContext& context) {
// Load dialects involved in the conversion
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
context.getDialectRegistry().loadAll(&context);
mlir::DialectRegistry registry;
mlir::RegisterAllTensorFlowDialects(registry);
registry.loadAll(&context);
}
// This class is used to generate new MLIR function name strings that are both

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/core/platform/init_main.h"
int main(int argc, char **argv) {
@ -34,6 +35,7 @@ int main(int argc, char **argv) {
mlir::mhlo::registerAllMhloDialects(registry);
registry.insert<mlir::shape::ShapeDialect>();
registry.insert<mlir::TFL::TensorFlowLiteDialect>();
registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
return failed(
mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry));
}

View File

@ -67,6 +67,10 @@ class UnrankedTensorStoreTestOnlyPattern
};
struct BufferizePass : public BufferizePassBase<BufferizePass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo::LmhloDialect>();
}
public:
void runOnOperation() override {
OwningRewritePatternList patterns;

View File

@ -36,6 +36,10 @@ static constexpr StringRef kTFEntry = "tf_entry";
// * std.dealloc becomes tf_framework.dealloc_raw.
class EmbedTFFrameworkPass
: public EmbedTFFrameworkPassBase<EmbedTFFrameworkPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
}
public:
void runOnOperation() override {
ModuleOp m = getOperation();

View File

@ -38,6 +38,10 @@ namespace {
struct ShapeToDescriptorsPass
: public ShapeToDescriptorsPassBase<ShapeToDescriptorsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect>();
}
public:
void runOnOperation() override {
MLIRContext &ctx = getContext();

View File

@ -33,6 +33,10 @@ namespace {
class TestTFFrameworkToLLVMPass
: public TestTFFrameworkLegalizeToLLVMPassBase<TestTFFrameworkToLLVMPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}
public:
void runOnOperation() override {
ModuleOp m = getOperation();

View File

@ -72,7 +72,8 @@ constexpr char kShardingAttr[] = "mhlo.sharding";
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect>();
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
shape::ShapeDialect, StandardOpsDialect>();
}
public:

View File

@ -60,6 +60,10 @@ const char kXlaHostTransferOriginalTypeAttr[] =
// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
class LegalizeTFCommunication
: public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<mhlo::MhloDialect>();
}
public:
void runOnOperation() override;
};