Expose a way to run an MLIR pass pipeline from python.
PiperOrigin-RevId: 277178584 Change-Id: I1d55cb99bccaf4bfebe26d432234c3db3a3b59e3
This commit is contained in:
parent
64f06f9455
commit
64d6166c50
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
%{
|
%{
|
||||||
|
|
||||||
|
#include "mlir/Parser.h"
|
||||||
#include "mlir/Pass/PassRegistry.h"
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@ -113,6 +114,41 @@ string ExperimentalConvertSavedModelToMlir(
|
|||||||
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
string ExperimentalRunPassPipeline(
|
||||||
|
const string &mlir_txt,
|
||||||
|
const string &pass_pipeline,
|
||||||
|
bool show_debug_info,
|
||||||
|
TF_Status* status) {
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::OwningModuleRef module;
|
||||||
|
{
|
||||||
|
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||||
|
module = mlir::parseSourceString(mlir_txt, &context);
|
||||||
|
if (!module) {
|
||||||
|
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
|
||||||
|
return "// error";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the pass_pipeline on the module.
|
||||||
|
mlir::PassManager pm(&context);
|
||||||
|
std::string error;
|
||||||
|
llvm::raw_string_ostream error_stream(error);
|
||||||
|
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||||
|
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||||
|
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||||
|
return "// error";
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||||
|
if (failed(pm.run(*module))) {
|
||||||
|
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
|
||||||
|
return "// error";
|
||||||
|
}
|
||||||
|
return MlirModuleToString(*module, show_debug_info);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace swig
|
} // namespace swig
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
@ -124,6 +160,7 @@ string ExperimentalConvertSavedModelToMlir(
|
|||||||
%unignore tensorflow::swig;
|
%unignore tensorflow::swig;
|
||||||
%unignore tensorflow::swig::ImportGraphDef;
|
%unignore tensorflow::swig::ImportGraphDef;
|
||||||
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
|
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
|
||||||
|
%unignore tensorflow::swig::ExperimentalRunPassPipeline;
|
||||||
|
|
||||||
// Wrap this function
|
// Wrap this function
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -136,6 +173,11 @@ static string ExperimentalConvertSavedModelToMlir(
|
|||||||
const string &exported_names,
|
const string &exported_names,
|
||||||
bool show_debug_info,
|
bool show_debug_info,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
static string ExperimentalRunPassPipeline(
|
||||||
|
const string &mlir_txt,
|
||||||
|
const string &pass_pipeline,
|
||||||
|
bool show_debug_info,
|
||||||
|
TF_Status* status);
|
||||||
} // namespace swig
|
} // namespace swig
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
@ -151,6 +193,13 @@ def experimental_convert_saved_model_to_mlir(saved_model_path,
|
|||||||
str(exported_names).encode('utf-8'),
|
str(exported_names).encode('utf-8'),
|
||||||
show_debug_info
|
show_debug_info
|
||||||
).decode('utf-8');
|
).decode('utf-8');
|
||||||
|
|
||||||
|
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
|
||||||
|
return ExperimentalRunPassPipeline(
|
||||||
|
mlir_txt.encode('utf-8'),
|
||||||
|
pass_pipeline.encode('utf-8'),
|
||||||
|
show_debug_info
|
||||||
|
).decode('utf-8');
|
||||||
%}
|
%}
|
||||||
|
|
||||||
%unignoreall
|
%unignoreall
|
||||||
|
@ -81,6 +81,11 @@ def do_test(create_module_fn, exported_names=None, show_debug_info=False):
|
|||||||
logging.info('Saved model to: %s', save_model_path)
|
logging.info('Saved model to: %s', save_model_path)
|
||||||
mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
|
mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
|
||||||
save_model_path, ','.join(exported_names), show_debug_info)
|
save_model_path, ','.join(exported_names), show_debug_info)
|
||||||
|
# We don't strictly need this, but it serves as a handy sanity check
|
||||||
|
# for that API, which is otherwise a bit annoying to test.
|
||||||
|
# The canonicalization shouldn't affect these tests in any way.
|
||||||
|
mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
|
||||||
|
mlir, 'canonicalize', show_debug_info)
|
||||||
print(mlir)
|
print(mlir)
|
||||||
|
|
||||||
app.run(app_main)
|
app.run(app_main)
|
||||||
|
Loading…
Reference in New Issue
Block a user