Remove dependency on the MLIR global dialect registry from third_party/tensorflow/compiler/mlir/python/... (NFC)
PiperOrigin-RevId: 328241726 Change-Id: Ife6f7d0717c39000f04ad8c95f9e34286628d801
This commit is contained in:
parent
dcd11beefe
commit
8779a1bff6
tensorflow
@ -10,6 +10,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
|
||||
@ -35,6 +36,7 @@ cc_library(
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/InitAllPasses.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
@ -41,7 +43,6 @@ std::string ImportGraphDef(const std::string &proto,
|
||||
GraphDebugInfo debug_info;
|
||||
GraphImportConfig specs;
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
|
||||
if (!module.ok()) {
|
||||
Set_TF_Status_from_Status(status, module.status());
|
||||
@ -86,7 +87,6 @@ std::string ExperimentalConvertSavedModelToMlir(
|
||||
std::vector<string> exported_names =
|
||||
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
auto module_or = ConvertSavedModelToMlir(
|
||||
&bundle, &context, absl::Span<std::string>(exported_names));
|
||||
if (!module_or.status().ok()) {
|
||||
@ -117,7 +117,6 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
// Convert the SavedModelBundle to an MLIR module.
|
||||
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
auto module_or =
|
||||
ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy);
|
||||
if (!module_or.status().ok()) {
|
||||
@ -153,6 +152,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
|
||||
bool show_debug_info,
|
||||
TF_Status *status) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
||||
mlir::OwningModuleRef module;
|
||||
{
|
||||
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||
@ -167,6 +167,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
|
||||
mlir::PassManager pm(&context);
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
mlir::registerAllPasses();
|
||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||
|
@ -22,23 +22,25 @@ limitations under the License.
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
PYBIND11_MODULE(mlir_wrapper, m) {
|
||||
m.def("registerDialects", []() {
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) {
|
||||
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
||||
context.getDialectRegistry().loadAll(&context);
|
||||
});
|
||||
|
||||
m.def("verify", [](std::string input) {
|
||||
llvm::SourceMgr SM = llvm::SourceMgr();
|
||||
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
|
||||
llvm::SMLoc());
|
||||
mlir::MLIRContext ctx;
|
||||
ctx.loadAllGloballyRegisteredDialects();
|
||||
mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
|
||||
ctx.getDialectRegistry().loadAll(&ctx);
|
||||
auto module = mlir::parseSourceFile(SM, &ctx);
|
||||
if (!module) {
|
||||
return false;
|
||||
|
@ -137,8 +137,8 @@ class TFProgram(object):
|
||||
"""Python wrap for a Tensorflow Program (essentially an mlir Module)."""
|
||||
|
||||
def __init__(self):
|
||||
mlir.registerDialects()
|
||||
self.ctx = mlir.MLIRContext()
|
||||
mlir.preloadTensorFlowDialects(self.ctx)
|
||||
self.builder = mlir.Builder(self.ctx)
|
||||
self.module = mlir.ModuleOp.create(mlir.UnknownLoc.get(self.ctx))
|
||||
self.curr_func = None
|
||||
|
Loading…
Reference in New Issue
Block a user