diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i index 98db04dac00..ba5bfb98948 100644 --- a/tensorflow/compiler/mlir/python/mlir.i +++ b/tensorflow/compiler/mlir/python/mlir.i @@ -17,6 +17,7 @@ limitations under the License. %{ +#include "mlir/Parser.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Pass/PassManager.h" #include "llvm/Support/raw_ostream.h" @@ -113,6 +114,41 @@ string ExperimentalConvertSavedModelToMlir( return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info); } + +string ExperimentalRunPassPipeline( + const string &mlir_txt, + const string &pass_pipeline, + bool show_debug_info, + TF_Status* status) { + mlir::MLIRContext context; + mlir::OwningModuleRef module; + { + mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); + module = mlir::parseSourceString(mlir_txt, &context); + if (!module) { + Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + return "// error"; + } + } + + // Run the pass_pipeline on the module.
+ mlir::PassManager pm(&context); + std::string error; + llvm::raw_string_ostream error_stream(error); + if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + ("Invalid pass_pipeline: " + error_stream.str()).c_str()); + return "// error"; + } + + mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); + if (failed(pm.run(*module))) { + Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + return "// error"; + } + return MlirModuleToString(*module, show_debug_info); +} + } // namespace swig } // namespace tensorflow @@ -124,6 +160,7 @@ string ExperimentalConvertSavedModelToMlir( %unignore tensorflow::swig; %unignore tensorflow::swig::ImportGraphDef; %unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir; +%unignore tensorflow::swig::ExperimentalRunPassPipeline; // Wrap this function namespace tensorflow { @@ -136,6 +173,11 @@ static string ExperimentalConvertSavedModelToMlir( const string &exported_names, bool show_debug_info, TF_Status* status); +static string ExperimentalRunPassPipeline( + const string &mlir_txt, + const string &pass_pipeline, + bool show_debug_info, + TF_Status* status); } // namespace swig } // namespace tensorflow @@ -151,6 +193,13 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, str(exported_names).encode('utf-8'), show_debug_info ).decode('utf-8'); + +def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info): + return ExperimentalRunPassPipeline( + mlir_txt.encode('utf-8'), + pass_pipeline.encode('utf-8'), + show_debug_info + ).decode('utf-8'); %} %unignoreall diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py index 77b7b3a4662..ebdbe3afa4c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py @@ -81,6 +81,11 @@ 
def do_test(create_module_fn, exported_names=None, show_debug_info=False): logging.info('Saved model to: %s', save_model_path) mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir( save_model_path, ','.join(exported_names), show_debug_info) + # We don't strictly need this, but it serves as a handy sanity check + # for that API, which is otherwise a bit annoying to test. + # The canonicalization shouldn't affect these tests in any way. + mlir = pywrap_tensorflow.experimental_run_pass_pipeline( + mlir, 'canonicalize', show_debug_info) print(mlir) app.run(app_main)