Add explicit dependent dialects registration in TF MLIR
PiperOrigin-RevId: 328488901 Change-Id: Icb334bd4bf937a7634f6cd708f382026cae27051
This commit is contained in:
		
							parent
							
								
									e83b36531b
								
							
						
					
					
						commit
						0d336074b7
					
				| @ -43,6 +43,7 @@ cc_library( | |||||||
|         "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", |         "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", | ||||||
|         "//tensorflow/compiler/mlir/lite:tensorflow_lite", |         "//tensorflow/compiler/mlir/lite:tensorflow_lite", | ||||||
|         "//tensorflow/compiler/mlir/tensorflow", |         "//tensorflow/compiler/mlir/tensorflow", | ||||||
|  |         "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", | ||||||
|         "//tensorflow/core:lib", |         "//tensorflow/core:lib", | ||||||
|         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", |         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", | ||||||
|         "@llvm-project//mlir:MlirOptLib", |         "@llvm-project//mlir:MlirOptLib", | ||||||
|  | |||||||
| @ -30,6 +30,7 @@ limitations under the License. | |||||||
| #include "llvm/ADT/StringSwitch.h" | #include "llvm/ADT/StringSwitch.h" | ||||||
| #include "llvm/Support/Threading.h" | #include "llvm/Support/Threading.h" | ||||||
| #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 | #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 | ||||||
|  | #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 | ||||||
| #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
 | #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
 | ||||||
| #include "mlir/IR/Attributes.h"  // from @llvm-project
 | #include "mlir/IR/Attributes.h"  // from @llvm-project
 | ||||||
| #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 | #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 | ||||||
| @ -74,6 +75,10 @@ constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; | |||||||
| 
 | 
 | ||||||
| // Legalize operations in functions.
 | // Legalize operations in functions.
 | ||||||
| class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> { | class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   LegalizeTF() = default; |   LegalizeTF() = default; | ||||||
|   LegalizeTF(const LegalizeTF&) {} |   LegalizeTF(const LegalizeTF&) {} | ||||||
|  | |||||||
| @ -33,6 +33,10 @@ namespace { | |||||||
| // cond and body regions.
 | // cond and body regions.
 | ||||||
| struct LegalizeWhile | struct LegalizeWhile | ||||||
|     : public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> { |     : public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<TFL::TensorFlowLiteDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   void RunOnFunction(FuncOp func); |   void RunOnFunction(FuncOp func); | ||||||
| 
 | 
 | ||||||
|   void runOnOperation() override { |   void runOnOperation() override { | ||||||
|  | |||||||
| @ -110,6 +110,10 @@ class ConvertEmbeddedLookupFunc { | |||||||
| class PrepareCompositeFunctionsPass | class PrepareCompositeFunctionsPass | ||||||
|     : public PassWrapper<PrepareCompositeFunctionsPass, |     : public PassWrapper<PrepareCompositeFunctionsPass, | ||||||
|                          OperationPass<ModuleOp>> { |                          OperationPass<ModuleOp>> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<TFL::TensorFlowLiteDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   explicit PrepareCompositeFunctionsPass() {} |   explicit PrepareCompositeFunctionsPass() {} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1318,6 +1318,7 @@ cc_library( | |||||||
|     deps = [ |     deps = [ | ||||||
|         ":convert_graphdef", |         ":convert_graphdef", | ||||||
|         ":mlir_roundtrip_flags", |         ":mlir_roundtrip_flags", | ||||||
|  |         ":tensorflow", | ||||||
|         "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", |         "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", | ||||||
|         "//tensorflow/core:core_cpu", |         "//tensorflow/core:core_cpu", | ||||||
|         "//tensorflow/core:framework", |         "//tensorflow/core:framework", | ||||||
|  | |||||||
| @ -115,6 +115,11 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) { | |||||||
| TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) | TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) | ||||||
|     : Dialect(/*name=*/"tf_saved_model", context, |     : Dialect(/*name=*/"tf_saved_model", context, | ||||||
|               TypeID::get<TensorFlowSavedModelDialect>()) { |               TypeID::get<TensorFlowSavedModelDialect>()) { | ||||||
|  |   // The TensorFlow Dialect is needed in the verifier and other routines
 | ||||||
|  |   // associated to this dialect. It makes little sense anyway to use the
 | ||||||
|  |   // SavedModel dialect without the TensorFlow Dialect.
 | ||||||
|  |   context->loadDialect<TF::TensorFlowDialect>(); | ||||||
|  | 
 | ||||||
|   addOperations< |   addOperations< | ||||||
| #define GET_OP_LIST | #define GET_OP_LIST | ||||||
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" | ||||||
|  | |||||||
| @ -39,6 +39,10 @@ namespace { | |||||||
| 
 | 
 | ||||||
| struct ClusterFormationPass | struct ClusterFormationPass | ||||||
|     : public PassWrapper<ClusterFormationPass, FunctionPass> { |     : public PassWrapper<ClusterFormationPass, FunctionPass> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<tf_device::TensorFlowDeviceDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   void runOnFunction() override; |   void runOnFunction() override; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -615,6 +615,10 @@ class ConvertReduceOpToTfMin : public OpConversionPattern<mhlo::ReduceOp> { | |||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> { | class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> { | ||||||
|  |   void getDependentDialects(DialectRegistry ®istry) const override { | ||||||
|  |     registry.insert<TF::TensorFlowDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   LegalizeHloToTf() = default; |   LegalizeHloToTf() = default; | ||||||
|   LegalizeHloToTf(const LegalizeHloToTf &) {} |   LegalizeHloToTf(const LegalizeHloToTf &) {} | ||||||
|  | |||||||
| @ -39,6 +39,10 @@ namespace { | |||||||
| 
 | 
 | ||||||
| struct ParallelizeEmbeddingParamsOpsPass | struct ParallelizeEmbeddingParamsOpsPass | ||||||
|     : public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> { |     : public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<tf_device::TensorFlowDeviceDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   void runOnFunction() override; |   void runOnFunction() override; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -20,6 +20,7 @@ limitations under the License. | |||||||
| #include "mlir/IR/Identifier.h"  // from @llvm-project
 | #include "mlir/IR/Identifier.h"  // from @llvm-project
 | ||||||
| #include "mlir/IR/Location.h"  // from @llvm-project
 | #include "mlir/IR/Location.h"  // from @llvm-project
 | ||||||
| #include "mlir/Pass/Pass.h"  // from @llvm-project
 | #include "mlir/Pass/Pass.h"  // from @llvm-project
 | ||||||
|  | #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" | ||||||
| #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" | #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" | ||||||
| #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" | #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" | ||||||
| #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" | #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" | ||||||
| @ -43,6 +44,10 @@ namespace tensorflow { | |||||||
| class GraphOptPass | class GraphOptPass | ||||||
|     : public mlir::PassWrapper<GraphOptPass, |     : public mlir::PassWrapper<GraphOptPass, | ||||||
|                                mlir::OperationPass<mlir::ModuleOp>> { |                                mlir::OperationPass<mlir::ModuleOp>> { | ||||||
|  |   void getDependentDialects(mlir::DialectRegistry& registry) const override { | ||||||
|  |     mlir::RegisterAllTensorFlowDialects(registry); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes) |   explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes) | ||||||
|       : passes_(std::move(passes)) {} |       : passes_(std::move(passes)) {} | ||||||
|  | |||||||
| @ -78,6 +78,10 @@ using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, | |||||||
| struct TPUClusterFormation | struct TPUClusterFormation | ||||||
|     : public TF::PerFunctionAggregateAnalysisConsumerPass< |     : public TF::PerFunctionAggregateAnalysisConsumerPass< | ||||||
|           TPUClusterFormation, TF::ResourceAliasAnalysis> { |           TPUClusterFormation, TF::ResourceAliasAnalysis> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<tf_device::TensorFlowDeviceDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   void runOnFunction( |   void runOnFunction( | ||||||
|       FuncOp func, |       FuncOp func, | ||||||
|       const TF::ResourceAliasAnalysis::Info& resource_alias_analysis); |       const TF::ResourceAliasAnalysis::Info& resource_alias_analysis); | ||||||
|  | |||||||
| @ -43,6 +43,10 @@ namespace { | |||||||
| 
 | 
 | ||||||
| class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass< | class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass< | ||||||
|                            BreakUpIslands, TF::SideEffectAnalysis> { |                            BreakUpIslands, TF::SideEffectAnalysis> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<tf_executor::TensorFlowExecutorDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   void runOnFunction(FuncOp func, |   void runOnFunction(FuncOp func, | ||||||
|                      const TF::SideEffectAnalysis::Info& side_effect_analysis); |                      const TF::SideEffectAnalysis::Info& side_effect_analysis); | ||||||
|  | |||||||
| @ -144,8 +144,9 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, | |||||||
| 
 | 
 | ||||||
| void LoadImporterDialects(mlir::MLIRContext& context) { | void LoadImporterDialects(mlir::MLIRContext& context) { | ||||||
|   // Load dialects involved in the conversion
 |   // Load dialects involved in the conversion
 | ||||||
|   mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); |   mlir::DialectRegistry registry; | ||||||
|   context.getDialectRegistry().loadAll(&context); |   mlir::RegisterAllTensorFlowDialects(registry); | ||||||
|  |   registry.loadAll(&context); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // This class is used to generate new MLIR function name strings that are both
 | // This class is used to generate new MLIR function name strings that are both
 | ||||||
|  | |||||||
| @ -21,6 +21,7 @@ limitations under the License. | |||||||
| #include "tensorflow/compiler/mlir/init_mlir.h" | #include "tensorflow/compiler/mlir/init_mlir.h" | ||||||
| #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" | #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" | ||||||
| #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" | #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" | ||||||
|  | #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" | ||||||
| #include "tensorflow/core/platform/init_main.h" | #include "tensorflow/core/platform/init_main.h" | ||||||
| 
 | 
 | ||||||
| int main(int argc, char **argv) { | int main(int argc, char **argv) { | ||||||
| @ -34,6 +35,7 @@ int main(int argc, char **argv) { | |||||||
|   mlir::mhlo::registerAllMhloDialects(registry); |   mlir::mhlo::registerAllMhloDialects(registry); | ||||||
|   registry.insert<mlir::shape::ShapeDialect>(); |   registry.insert<mlir::shape::ShapeDialect>(); | ||||||
|   registry.insert<mlir::TFL::TensorFlowLiteDialect>(); |   registry.insert<mlir::TFL::TensorFlowLiteDialect>(); | ||||||
|  |   registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>(); | ||||||
|   return failed( |   return failed( | ||||||
|       mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry)); |       mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry)); | ||||||
| } | } | ||||||
|  | |||||||
| @ -67,6 +67,10 @@ class UnrankedTensorStoreTestOnlyPattern | |||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| struct BufferizePass : public BufferizePassBase<BufferizePass> { | struct BufferizePass : public BufferizePassBase<BufferizePass> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<lmhlo::LmhloDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   void runOnOperation() override { |   void runOnOperation() override { | ||||||
|     OwningRewritePatternList patterns; |     OwningRewritePatternList patterns; | ||||||
|  | |||||||
| @ -36,6 +36,10 @@ static constexpr StringRef kTFEntry = "tf_entry"; | |||||||
| // * std.dealloc becomes tf_framework.dealloc_raw.
 | // * std.dealloc becomes tf_framework.dealloc_raw.
 | ||||||
| class EmbedTFFrameworkPass | class EmbedTFFrameworkPass | ||||||
|     : public EmbedTFFrameworkPassBase<EmbedTFFrameworkPass> { |     : public EmbedTFFrameworkPassBase<EmbedTFFrameworkPass> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   void runOnOperation() override { |   void runOnOperation() override { | ||||||
|     ModuleOp m = getOperation(); |     ModuleOp m = getOperation(); | ||||||
|  | |||||||
| @ -38,6 +38,10 @@ namespace { | |||||||
| 
 | 
 | ||||||
| struct ShapeToDescriptorsPass | struct ShapeToDescriptorsPass | ||||||
|     : public ShapeToDescriptorsPassBase<ShapeToDescriptorsPass> { |     : public ShapeToDescriptorsPassBase<ShapeToDescriptorsPass> { | ||||||
|  |   void getDependentDialects(DialectRegistry ®istry) const override { | ||||||
|  |     registry.insert<scf::SCFDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   void runOnOperation() override { |   void runOnOperation() override { | ||||||
|     MLIRContext &ctx = getContext(); |     MLIRContext &ctx = getContext(); | ||||||
|  | |||||||
| @ -33,6 +33,10 @@ namespace { | |||||||
| 
 | 
 | ||||||
| class TestTFFrameworkToLLVMPass | class TestTFFrameworkToLLVMPass | ||||||
|     : public TestTFFrameworkLegalizeToLLVMPassBase<TestTFFrameworkToLLVMPass> { |     : public TestTFFrameworkLegalizeToLLVMPassBase<TestTFFrameworkToLLVMPass> { | ||||||
|  |   void getDependentDialects(DialectRegistry ®istry) const override { | ||||||
|  |     registry.insert<LLVM::LLVMDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   void runOnOperation() override { |   void runOnOperation() override { | ||||||
|     ModuleOp m = getOperation(); |     ModuleOp m = getOperation(); | ||||||
|  | |||||||
| @ -72,7 +72,8 @@ constexpr char kShardingAttr[] = "mhlo.sharding"; | |||||||
| 
 | 
 | ||||||
| class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> { | class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> { | ||||||
|   void getDependentDialects(DialectRegistry ®istry) const override { |   void getDependentDialects(DialectRegistry ®istry) const override { | ||||||
|     registry.insert<chlo::HloClientDialect, mhlo::MhloDialect>(); |     registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, | ||||||
|  |                     shape::ShapeDialect, StandardOpsDialect>(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  public: |  public: | ||||||
|  | |||||||
| @ -60,6 +60,10 @@ const char kXlaHostTransferOriginalTypeAttr[] = | |||||||
| // ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
 | // ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
 | ||||||
| class LegalizeTFCommunication | class LegalizeTFCommunication | ||||||
|     : public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> { |     : public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> { | ||||||
|  |   void getDependentDialects(DialectRegistry& registry) const override { | ||||||
|  |     registry.insert<mhlo::MhloDialect>(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   void runOnOperation() override; |   void runOnOperation() override; | ||||||
| }; | }; | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user