Expose a way to run an MLIR pass pipeline from python.
PiperOrigin-RevId: 277178584 Change-Id: I1d55cb99bccaf4bfebe26d432234c3db3a3b59e3
This commit is contained in:
parent
64f06f9455
commit
64d6166c50
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
%{
|
||||
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
@ -113,6 +114,41 @@ string ExperimentalConvertSavedModelToMlir(
|
||||
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||
}
|
||||
|
||||
|
||||
string ExperimentalRunPassPipeline(
|
||||
const string &mlir_txt,
|
||||
const string &pass_pipeline,
|
||||
bool show_debug_info,
|
||||
TF_Status* status) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
{
|
||||
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||
module = mlir::parseSourceString(mlir_txt, &context);
|
||||
if (!module) {
|
||||
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
}
|
||||
|
||||
// Run the pass_pipeline on the module.
|
||||
mlir::PassManager pm(&context);
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||
if (failed(pm.run(*module))) {
|
||||
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
return MlirModuleToString(*module, show_debug_info);
|
||||
}
|
||||
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -124,6 +160,7 @@ string ExperimentalConvertSavedModelToMlir(
|
||||
%unignore tensorflow::swig;
|
||||
%unignore tensorflow::swig::ImportGraphDef;
|
||||
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
|
||||
%unignore tensorflow::swig::ExperimentalRunPassPipeline;
|
||||
|
||||
// Wrap this function
|
||||
namespace tensorflow {
|
||||
@ -136,6 +173,11 @@ static string ExperimentalConvertSavedModelToMlir(
|
||||
const string &exported_names,
|
||||
bool show_debug_info,
|
||||
TF_Status* status);
|
||||
static string ExperimentalRunPassPipeline(
|
||||
const string &mlir_txt,
|
||||
const string &pass_pipeline,
|
||||
bool show_debug_info,
|
||||
TF_Status* status);
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -151,6 +193,13 @@ def experimental_convert_saved_model_to_mlir(saved_model_path,
|
||||
str(exported_names).encode('utf-8'),
|
||||
show_debug_info
|
||||
).decode('utf-8');
|
||||
|
||||
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
|
||||
return ExperimentalRunPassPipeline(
|
||||
mlir_txt.encode('utf-8'),
|
||||
pass_pipeline.encode('utf-8'),
|
||||
show_debug_info
|
||||
).decode('utf-8');
|
||||
%}
|
||||
|
||||
%unignoreall
|
||||
|
@ -81,6 +81,11 @@ def do_test(create_module_fn, exported_names=None, show_debug_info=False):
|
||||
logging.info('Saved model to: %s', save_model_path)
|
||||
mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
|
||||
save_model_path, ','.join(exported_names), show_debug_info)
|
||||
# We don't strictly need this, but it serves as a handy sanity check
|
||||
# for that API, which is otherwise a bit annoying to test.
|
||||
# The canonicalization shouldn't affect these tests in any way.
|
||||
mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
|
||||
mlir, 'canonicalize', show_debug_info)
|
||||
print(mlir)
|
||||
|
||||
app.run(app_main)
|
||||
|
Loading…
Reference in New Issue
Block a user