Add explicit dependent dialects registration in TF MLIR
PiperOrigin-RevId: 328488901 Change-Id: Icb334bd4bf937a7634f6cd708f382026cae27051
This commit is contained in:
parent
e83b36531b
commit
0d336074b7
@ -43,6 +43,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Threading.h"
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
@ -74,6 +75,10 @@ constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
||||
|
||||
// Legalize operations in functions.
|
||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
LegalizeTF() = default;
|
||||
LegalizeTF(const LegalizeTF&) {}
|
||||
|
@ -33,6 +33,10 @@ namespace {
|
||||
// cond and body regions.
|
||||
struct LegalizeWhile
|
||||
: public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<TFL::TensorFlowLiteDialect>();
|
||||
}
|
||||
|
||||
void RunOnFunction(FuncOp func);
|
||||
|
||||
void runOnOperation() override {
|
||||
|
@ -110,6 +110,10 @@ class ConvertEmbeddedLookupFunc {
|
||||
class PrepareCompositeFunctionsPass
|
||||
: public PassWrapper<PrepareCompositeFunctionsPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<TFL::TensorFlowLiteDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
explicit PrepareCompositeFunctionsPass() {}
|
||||
|
||||
|
@ -1318,6 +1318,7 @@ cc_library(
|
||||
deps = [
|
||||
":convert_graphdef",
|
||||
":mlir_roundtrip_flags",
|
||||
":tensorflow",
|
||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -115,6 +115,11 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) {
|
||||
TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"tf_saved_model", context,
|
||||
TypeID::get<TensorFlowSavedModelDialect>()) {
|
||||
// The TensorFlow Dialect is needed in the verifier and other routines
|
||||
// associated to this dialect. It makes little sense anyway to use the
|
||||
// SavedModel dialect without the TensorFlow Dialect.
|
||||
context->loadDialect<TF::TensorFlowDialect>();
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
|
||||
|
@ -39,6 +39,10 @@ namespace {
|
||||
|
||||
struct ClusterFormationPass
|
||||
: public PassWrapper<ClusterFormationPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<tf_device::TensorFlowDeviceDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
|
@ -615,6 +615,10 @@ class ConvertReduceOpToTfMin : public OpConversionPattern<mhlo::ReduceOp> {
|
||||
};
|
||||
|
||||
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TF::TensorFlowDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
LegalizeHloToTf() = default;
|
||||
LegalizeHloToTf(const LegalizeHloToTf &) {}
|
||||
|
@ -39,6 +39,10 @@ namespace {
|
||||
|
||||
struct ParallelizeEmbeddingParamsOpsPass
|
||||
: public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<tf_device::TensorFlowDeviceDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
@ -43,6 +44,10 @@ namespace tensorflow {
|
||||
class GraphOptPass
|
||||
: public mlir::PassWrapper<GraphOptPass,
|
||||
mlir::OperationPass<mlir::ModuleOp>> {
|
||||
void getDependentDialects(mlir::DialectRegistry& registry) const override {
|
||||
mlir::RegisterAllTensorFlowDialects(registry);
|
||||
}
|
||||
|
||||
public:
|
||||
explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
|
||||
: passes_(std::move(passes)) {}
|
||||
|
@ -78,6 +78,10 @@ using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
|
||||
struct TPUClusterFormation
|
||||
: public TF::PerFunctionAggregateAnalysisConsumerPass<
|
||||
TPUClusterFormation, TF::ResourceAliasAnalysis> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<tf_device::TensorFlowDeviceDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction(
|
||||
FuncOp func,
|
||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
|
||||
|
@ -43,6 +43,10 @@ namespace {
|
||||
|
||||
class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
|
||||
BreakUpIslands, TF::SideEffectAnalysis> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<tf_executor::TensorFlowExecutorDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnFunction(FuncOp func,
|
||||
const TF::SideEffectAnalysis::Info& side_effect_analysis);
|
||||
|
@ -144,8 +144,9 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
|
||||
|
||||
void LoadImporterDialects(mlir::MLIRContext& context) {
|
||||
// Load dialects involved in the conversion
|
||||
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
||||
context.getDialectRegistry().loadAll(&context);
|
||||
mlir::DialectRegistry registry;
|
||||
mlir::RegisterAllTensorFlowDialects(registry);
|
||||
registry.loadAll(&context);
|
||||
}
|
||||
|
||||
// This class is used to generate new MLIR function name strings that are both
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
@ -34,6 +35,7 @@ int main(int argc, char **argv) {
|
||||
mlir::mhlo::registerAllMhloDialects(registry);
|
||||
registry.insert<mlir::shape::ShapeDialect>();
|
||||
registry.insert<mlir::TFL::TensorFlowLiteDialect>();
|
||||
registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
|
||||
return failed(
|
||||
mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry));
|
||||
}
|
||||
|
@ -67,6 +67,10 @@ class UnrankedTensorStoreTestOnlyPattern
|
||||
};
|
||||
|
||||
struct BufferizePass : public BufferizePassBase<BufferizePass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<lmhlo::LmhloDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -36,6 +36,10 @@ static constexpr StringRef kTFEntry = "tf_entry";
|
||||
// * std.dealloc becomes tf_framework.dealloc_raw.
|
||||
class EmbedTFFrameworkPass
|
||||
: public EmbedTFFrameworkPassBase<EmbedTFFrameworkPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
ModuleOp m = getOperation();
|
||||
|
@ -38,6 +38,10 @@ namespace {
|
||||
|
||||
struct ShapeToDescriptorsPass
|
||||
: public ShapeToDescriptorsPassBase<ShapeToDescriptorsPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<scf::SCFDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext &ctx = getContext();
|
||||
|
@ -33,6 +33,10 @@ namespace {
|
||||
|
||||
class TestTFFrameworkToLLVMPass
|
||||
: public TestTFFrameworkLegalizeToLLVMPassBase<TestTFFrameworkToLLVMPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
ModuleOp m = getOperation();
|
||||
|
@ -72,7 +72,8 @@ constexpr char kShardingAttr[] = "mhlo.sharding";
|
||||
|
||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect>();
|
||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
|
||||
shape::ShapeDialect, StandardOpsDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -60,6 +60,10 @@ const char kXlaHostTransferOriginalTypeAttr[] =
|
||||
// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
|
||||
class LegalizeTFCommunication
|
||||
: public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<mhlo::MhloDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user