Add 'exported_names' to V1 saved model MLIR converter and update the pybindings.
PiperOrigin-RevId: 359677950 Change-Id: I9922ee1a9b1b91f208b3e12f6300f3b41ace32b2
This commit is contained in:
parent
cc8896a755
commit
d0591b7a26
@ -162,18 +162,21 @@ std::string ExperimentalConvertSavedModelToMlir(
|
||||
}
|
||||
|
||||
std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
|
||||
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||
const std::string &tags, bool upgrade_legacy, bool show_debug_info,
|
||||
TF_Status *status) {
|
||||
std::unordered_set<string> tag_set =
|
||||
absl::StrSplit(tags, ',', absl::SkipEmpty());
|
||||
|
||||
std::vector<string> exported_names =
|
||||
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
|
||||
tensorflow::MLIRImportOptions import_options;
|
||||
import_options.upgrade_legacy = upgrade_legacy;
|
||||
auto module_or = SavedModelSignatureDefsToMlirImportLite(
|
||||
saved_model_path, tag_set, /*exported_names=*/{}, &context,
|
||||
import_options);
|
||||
saved_model_path, tag_set, absl::Span<std::string>(exported_names),
|
||||
&context, import_options);
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
@ -183,9 +186,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
}
|
||||
|
||||
std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
|
||||
TF_Status *status) {
|
||||
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||
const std::string &tags, bool lift_variables, bool upgrade_legacy,
|
||||
bool show_debug_info, TF_Status *status) {
|
||||
// Load the saved model into a SavedModelBundle.
|
||||
|
||||
std::unordered_set<string> tag_set =
|
||||
@ -200,12 +203,14 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
}
|
||||
|
||||
// Convert the SavedModelBundle to an MLIR module.
|
||||
|
||||
std::vector<string> exported_names =
|
||||
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
tensorflow::MLIRImportOptions import_options;
|
||||
import_options.upgrade_legacy = upgrade_legacy;
|
||||
auto module_or =
|
||||
ConvertSavedModelV1ToMlir(bundle, {}, &context, import_options);
|
||||
ConvertSavedModelV1ToMlir(bundle, absl::Span<std::string>(exported_names),
|
||||
&context, import_options);
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
|
@ -69,8 +69,9 @@ std::string ExperimentalConvertSavedModelToMlir(
|
||||
// Returns:
|
||||
// A string of textual MLIR representing the raw imported SavedModel.
|
||||
std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool upgrade_legacy, bool show_debug_info, TF_Status *status);
|
||||
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||
const std::string &tags, bool upgrade_legacy, bool show_debug_info,
|
||||
TF_Status *status);
|
||||
|
||||
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
@ -83,9 +84,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
// Returns:
|
||||
// A string of textual MLIR representing the raw imported SavedModel.
|
||||
std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
|
||||
TF_Status *status);
|
||||
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||
const std::string &tags, bool lift_variables, bool upgrade_legacy,
|
||||
bool show_debug_info, TF_Status *status);
|
||||
|
||||
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
|
||||
const std::string &pass_pipeline,
|
||||
|
@ -102,11 +102,13 @@ def do_test(create_signature,
|
||||
logging.info('Saved model to: %s', save_model_path)
|
||||
# TODO(b/153507667): Set the following boolean flag once the hoisting
|
||||
# variables logic from SavedModel importer is removed.
|
||||
exported_names = ''
|
||||
lift_variables = False
|
||||
upgrade_legacy = True
|
||||
if use_lite:
|
||||
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
|
||||
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
||||
save_model_path, exported_names,
|
||||
','.join([tf.saved_model.tag_constants.SERVING]),
|
||||
upgrade_legacy, show_debug_info)
|
||||
# We don't strictly need this, but it serves as a handy sanity check
|
||||
# for that API, which is otherwise a bit annoying to test.
|
||||
@ -116,7 +118,8 @@ def do_test(create_signature,
|
||||
show_debug_info)
|
||||
else:
|
||||
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
|
||||
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
||||
save_model_path, exported_names,
|
||||
','.join([tf.saved_model.tag_constants.SERVING]),
|
||||
lift_variables, upgrade_legacy, show_debug_info)
|
||||
|
||||
if canonicalize:
|
||||
|
@ -59,27 +59,29 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
|
||||
});
|
||||
|
||||
m.def("ExperimentalConvertSavedModelV1ToMlirLite",
|
||||
[](const std::string &saved_model_path, const std::string &tags,
|
||||
[](const std::string &saved_model_path,
|
||||
const std::string &exported_names_str, const std::string &tags,
|
||||
bool upgrade_legacy, bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output =
|
||||
tensorflow::ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
saved_model_path, tags, upgrade_legacy, show_debug_info,
|
||||
status.get());
|
||||
saved_model_path, exported_names_str, tags, upgrade_legacy,
|
||||
show_debug_info, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ExperimentalConvertSavedModelV1ToMlir",
|
||||
[](const std::string &saved_model_path, const std::string &tags,
|
||||
[](const std::string &saved_model_path,
|
||||
const std::string &exported_names_str, const std::string &tags,
|
||||
bool lift_variables, bool upgrade_legacy, bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output =
|
||||
tensorflow::ExperimentalConvertSavedModelV1ToMlir(
|
||||
saved_model_path, tags, lift_variables, upgrade_legacy,
|
||||
show_debug_info, status.get());
|
||||
saved_model_path, exported_names_str, tags, lift_variables,
|
||||
upgrade_legacy, show_debug_info, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
@ -45,19 +45,23 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
||||
str(exported_names).encode('utf-8'), show_debug_info)
|
||||
|
||||
|
||||
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags,
|
||||
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path,
|
||||
exported_names, tags,
|
||||
upgrade_legacy,
|
||||
show_debug_info):
|
||||
return ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
str(saved_model_path).encode('utf-8'),
|
||||
str(exported_names).encode('utf-8'),
|
||||
str(tags).encode('utf-8'), upgrade_legacy, show_debug_info)
|
||||
|
||||
|
||||
def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
|
||||
def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
|
||||
exported_names, tags,
|
||||
lift_variables, upgrade_legacy,
|
||||
show_debug_info):
|
||||
return ExperimentalConvertSavedModelV1ToMlir(
|
||||
str(saved_model_path).encode('utf-8'),
|
||||
str(exported_names).encode('utf-8'),
|
||||
str(tags).encode('utf-8'), lift_variables, upgrade_legacy,
|
||||
show_debug_info)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user