Remove dependency on the MLIR Global Dialect registry from third_party/tensorflow/compiler/mlir/tensorflow/... (NFC)

PiperOrigin-RevId: 328256230
Change-Id: I180650c53c9bbb790bead9d47ae546a3938387d1
This commit is contained in:
Mehdi Amini 2020-08-24 20:05:45 -07:00 committed by TensorFlower Gardener
parent 9f069e0255
commit e2efac4eca
4 changed files with 23 additions and 17 deletions

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -74,15 +75,9 @@ using tensorflow::tracing::TracingTensorHandle;
namespace {
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
void RegisterDialects(mlir::MLIRContext& ctx) {
mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
ctx.getDialectRegistry().loadAll(&ctx);
}
Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
@ -239,6 +234,7 @@ class MlirFunctionContext : public TracingContext {
: TracingContext(kMlir),
context_(std::make_unique<MLIRContext>()),
builder_(context_.get()) {
RegisterDialects(*context_);
// TODO(aminim) figure out the location story here
module_ = ModuleOp::create(builder_.getUnknownLoc());
func_ = FuncOp::create(builder_.getUnknownLoc(), name,
@ -666,7 +662,6 @@ Status MlirFunctionContext::Finalize(OutputList* outputs,
extern "C" {
TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
RegisterDialects();
return new MlirFunctionContext(fn_name);
}
}

View File

@ -64,6 +64,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -141,6 +142,12 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
return false;
}
void LoadImporterDialects(mlir::MLIRContext& context) {
// Load dialects involved in the conversion
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
context.getDialectRegistry().loadAll(&context);
}
// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
@ -2136,11 +2143,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
const GraphImportConfig& specs, llvm::StringRef func_name) {
// Load dialects involved in the conversion
context->loadDialect<mlir::StandardOpsDialect>();
context->loadDialect<mlir::TF::TensorFlowDialect>();
context->loadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
LoadImporterDialects(*context);
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
@ -3197,6 +3200,7 @@ Status CreateSavedModelIR(
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, bool add_default_attributes) {
LoadImporterDialects(*context);
GraphDebugInfo dummy_debug_info;
const GraphDebugInfo& debug_info =
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
@ -3276,6 +3280,7 @@ class SavedModelSignatureDefImporter {
static StatusOr<mlir::OwningModuleRef> Convert(
const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, bool upgrade_legacy) {
LoadImporterDialects(*context);
SavedModelSignatureDefImporter importer(bundle, exported_names, context);
TF_RETURN_IF_ERROR(importer.InitializeGraph(upgrade_legacy));
return importer.ConvertSignatures();

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
@ -86,6 +87,9 @@ static LogicalResult MlirToGraphdefTranslateFunction(
}
static TranslateFromMLIRRegistration mlir_to_graphdef_translate(
"mlir-to-graphdef", MlirToGraphdefTranslateFunction);
"mlir-to-graphdef", MlirToGraphdefTranslateFunction,
[](DialectRegistry& registry) {
mlir::RegisterAllTensorFlowDialects(registry);
});
} // namespace mlir

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
namespace mlir {
@ -67,6 +68,7 @@ static LogicalResult MlirToTfNodeDef(ModuleOp module,
// Test only translation to convert a simple MLIR module with a single TF
// dialect op to NodeDef.
static TranslateFromMLIRRegistration translate_from_mlir_registration(
"test-only-mlir-to-tf-nodedef", MlirToTfNodeDef);
"test-only-mlir-to-tf-nodedef", MlirToTfNodeDef,
mlir::RegisterAllTensorFlowDialects);
} // namespace mlir