From e2efac4eca51dc0b2e52d38c5de39e0725505b18 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 24 Aug 2020 20:05:45 -0700 Subject: [PATCH] Remove dependency on the MLIR Global Dialect registry from third_party/tensorflow/compiler/mlir/tensorflow/... (NFC) PiperOrigin-RevId: 328256230 Change-Id: I180650c53c9bbb790bead9d47ae546a3938387d1 --- .../c/c_api_unified_experimental_mlir.cc | 15 +++++---------- .../mlir/tensorflow/translate/import_model.cc | 15 ++++++++++----- .../translate/tf_mlir_translate_registration.cc | 6 +++++- .../translate/translate_tf_dialect_op.cc | 4 +++- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index c62d62a2d3d..bd21ba015bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -74,15 +75,9 @@ using tensorflow::tracing::TracingTensorHandle; namespace { -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; +void RegisterDialects(mlir::MLIRContext& ctx) { + mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry()); + ctx.getDialectRegistry().loadAll(&ctx); } Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder, @@ -239,6 +234,7 @@ class MlirFunctionContext : public TracingContext { : TracingContext(kMlir), context_(std::make_unique()), builder_(context_.get()) { + RegisterDialects(*context_); // TODO(aminim) figure out the location story here module_ = ModuleOp::create(builder_.getUnknownLoc()); func_ = FuncOp::create(builder_.getUnknownLoc(), name, @@ -666,7 +662,6 @@ Status MlirFunctionContext::Finalize(OutputList* outputs, extern "C" { TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) { - RegisterDialects(); return new MlirFunctionContext(fn_name); } } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 692d0eaf962..c539ce9b468 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -64,6 +64,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -141,6 +142,12 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, return false; } +void LoadImporterDialects(mlir::MLIRContext& context) { + // Load dialects involved in the conversion + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); + context.getDialectRegistry().loadAll(&context); +} + // This class is used to generate new MLIR function name strings that are both // unique in the TF function library `flib_` and unique among the name strings // generated by the class object during its lifetime. @@ -2136,11 +2143,7 @@ StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, llvm::StringRef func_name) { - // Load dialects involved in the conversion - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - + LoadImporterDialects(*context); mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); std::unordered_map tf_name_to_mlir_name; @@ -3197,6 +3200,7 @@ Status CreateSavedModelIR( StatusOr SavedModelObjectGraphImporter::Convert( SavedModelV2Bundle* saved_model, absl::Span exported_names, mlir::MLIRContext* context, bool add_default_attributes) { + LoadImporterDialects(*context); GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info; @@ -3276,6 +3280,7 @@ class SavedModelSignatureDefImporter { static StatusOr Convert( const SavedModelBundle& bundle, absl::Span exported_names, mlir::MLIRContext* context, bool upgrade_legacy) { + LoadImporterDialects(*context); SavedModelSignatureDefImporter importer(bundle, exported_names, context); TF_RETURN_IF_ERROR(importer.InitializeGraph(upgrade_legacy)); return importer.ConvertSignatures(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index b646e14b71d..f63cb091a09 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" @@ -86,6 +87,9 @@ static LogicalResult MlirToGraphdefTranslateFunction( } static TranslateFromMLIRRegistration mlir_to_graphdef_translate( - "mlir-to-graphdef", MlirToGraphdefTranslateFunction); + "mlir-to-graphdef", MlirToGraphdefTranslateFunction, + [](DialectRegistry& registry) { + mlir::RegisterAllTensorFlowDialects(registry); + }); } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc index 5236bdeffbf..22e6559a0f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" namespace mlir { @@ -67,6 +68,7 @@ static LogicalResult MlirToTfNodeDef(ModuleOp module, // Test only translation to convert a simple MLIR module with a single TF // dialect op to NodeDef. static TranslateFromMLIRRegistration translate_from_mlir_registration( - "test-only-mlir-to-tf-nodedef", MlirToTfNodeDef); + "test-only-mlir-to-tf-nodedef", MlirToTfNodeDef, + mlir::RegisterAllTensorFlowDialects); } // namespace mlir