Expose a way to run an MLIR pass pipeline from Python.

PiperOrigin-RevId: 277178584
Change-Id: I1d55cb99bccaf4bfebe26d432234c3db3a3b59e3
Sean Silva, 2019-10-28 17:31:53 -07:00 (committed by TensorFlower Gardener)
commit 64d6166c50, parent 64f06f9455
2 changed files with 54 additions and 0 deletions
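For context, a minimal sketch of how the new entry point is meant to be used
from Python (the import path is an assumption; the wrapper is reached through
pywrap_tensorflow, as in the test diff below):

    from tensorflow.python import pywrap_tensorflow

    # Round-trip a trivial module through the canonicalizer.
    mlir_txt = """
    func @f(%arg0: f32) -> f32 {
      return %arg0 : f32
    }
    """
    print(pywrap_tensorflow.experimental_run_pass_pipeline(
        mlir_txt, 'canonicalize', show_debug_info=False))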


@@ -17,6 +17,7 @@ limitations under the License.
%{
#include "mlir/Parser.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/raw_ostream.h"
@@ -113,6 +114,41 @@ string ExperimentalConvertSavedModelToMlir(
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}
string ExperimentalRunPassPipeline(
    const string &mlir_txt,
    const string &pass_pipeline,
    bool show_debug_info,
    TF_Status* status) {
  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  {
    mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
    module = mlir::parseSourceString(mlir_txt, &context);
    if (!module) {
      Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
      return "// error";
    }
  }
  // Run the pass_pipeline on the module.
  mlir::PassManager pm(&context);
  std::string error;
  llvm::raw_string_ostream error_stream(error);
  if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
    return "// error";
  }
  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
  if (failed(pm.run(*module))) {
    Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
    return "// error";
  }
  return MlirModuleToString(*module, show_debug_info);
}
} // namespace swig
} // namespace tensorflow
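Note the error-handling convention above: on any failure the function still
returns the well-formed string "// error" so the caller always gets something
to decode, while the actual diagnostics travel through TF_Status. Assuming the
surrounding SWIG setup turns a non-OK TF_Status into a Python exception (as
TensorFlow's status typemaps normally do), a malformed pipeline would surface
roughly like this sketch:

    try:
      pywrap_tensorflow.experimental_run_pass_pipeline(
          mlir_txt, 'not-a-real-pass', False)
    except Exception as e:
      print(e)  # "Invalid pass_pipeline: ..." from mlir::parsePassPipeline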
@@ -124,6 +160,7 @@ string ExperimentalConvertSavedModelToMlir(
%unignore tensorflow::swig;
%unignore tensorflow::swig::ImportGraphDef;
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
%unignore tensorflow::swig::ExperimentalRunPassPipeline;
// Wrap this function
namespace tensorflow {
@@ -136,6 +173,11 @@ static string ExperimentalConvertSavedModelToMlir(
const string &exported_names,
bool show_debug_info,
TF_Status* status);
static string ExperimentalRunPassPipeline(
    const string &mlir_txt,
    const string &pass_pipeline,
    bool show_debug_info,
    TF_Status* status);
} // namespace swig
} // namespace tensorflow
@@ -151,6 +193,13 @@ def experimental_convert_saved_model_to_mlir(saved_model_path,
      str(exported_names).encode('utf-8'),
      show_debug_info
  ).decode('utf-8')

def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
  return ExperimentalRunPassPipeline(
      mlir_txt.encode('utf-8'),
      pass_pipeline.encode('utf-8'),
      show_debug_info
  ).decode('utf-8')
%}
%unignoreall
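Since pass_pipeline is parsed with mlir::parsePassPipeline, it accepts MLIR's
textual pipeline syntax, so several passes can be chained in one call. A
hedged example (assuming both passes are registered in this build):

    mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
        mlir, 'canonicalize,cse', show_debug_info)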


@@ -81,6 +81,11 @@ def do_test(create_module_fn, exported_names=None, show_debug_info=False):
    logging.info('Saved model to: %s', save_model_path)

    mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
        save_model_path, ','.join(exported_names), show_debug_info)
    # We don't strictly need this, but it serves as a handy sanity check
    # for that API, which is otherwise a bit annoying to test.
    # The canonicalization shouldn't affect these tests in any way.
    mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
        mlir, 'canonicalize', show_debug_info)
    print(mlir)

  app.run(app_main)