Remove dependency on the MLIR global dialect registry from third_party/tensorflow/compiler/mlir/python/... (NFC)

PiperOrigin-RevId: 328241726
Change-Id: Ife6f7d0717c39000f04ad8c95f9e34286628d801
This commit is contained in:
Mehdi Amini 2020-08-24 17:49:22 -07:00 committed by TensorFlower Gardener
parent dcd11beefe
commit 8779a1bff6
4 changed files with 14 additions and 9 deletions
tensorflow
compiler/mlir/python
python/tf_program

View File

@ -10,6 +10,7 @@ cc_library(
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
@ -35,6 +36,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
],
alwayslink = 1,
)

View File

@ -16,11 +16,13 @@ limitations under the License.
#include <string>
#include "llvm/Support/raw_ostream.h"
#include "mlir/InitAllPasses.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
@ -41,7 +43,6 @@ std::string ImportGraphDef(const std::string &proto,
GraphDebugInfo debug_info;
GraphImportConfig specs;
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
@ -86,7 +87,6 @@ std::string ExperimentalConvertSavedModelToMlir(
std::vector<string> exported_names =
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto module_or = ConvertSavedModelToMlir(
&bundle, &context, absl::Span<std::string>(exported_names));
if (!module_or.status().ok()) {
@ -117,7 +117,6 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto module_or =
ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy);
if (!module_or.status().ok()) {
@ -153,6 +152,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
bool show_debug_info,
TF_Status *status) {
mlir::MLIRContext context;
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
mlir::OwningModuleRef module;
{
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
@ -167,6 +167,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
mlir::PassManager pm(&context);
std::string error;
llvm::raw_string_ostream error_stream(error);
mlir::registerAllPasses();
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
("Invalid pass_pipeline: " + error_stream.str()).c_str());

View File

@ -22,23 +22,25 @@ limitations under the License.
#include "mlir/Parser.h" // from @llvm-project
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
PYBIND11_MODULE(mlir_wrapper, m) {
m.def("registerDialects", []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) {
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
context.getDialectRegistry().loadAll(&context);
});
m.def("verify", [](std::string input) {
llvm::SourceMgr SM = llvm::SourceMgr();
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
llvm::SMLoc());
mlir::MLIRContext ctx;
ctx.loadAllGloballyRegisteredDialects();
mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
ctx.getDialectRegistry().loadAll(&ctx);
auto module = mlir::parseSourceFile(SM, &ctx);
if (!module) {
return false;

View File

@ -137,8 +137,8 @@ class TFProgram(object):
"""Python wrap for a Tensorflow Program (essentially an mlir Module)."""
def __init__(self):
mlir.registerDialects()
self.ctx = mlir.MLIRContext()
mlir.preloadTensorFlowDialects(self.ctx)
self.builder = mlir.Builder(self.ctx)
self.module = mlir.ModuleOp.create(mlir.UnknownLoc.get(self.ctx))
self.curr_func = None