Add 'exported_names' to V1 saved model MLIR converter and update the pybindings.
PiperOrigin-RevId: 359677950 Change-Id: I9922ee1a9b1b91f208b3e12f6300f3b41ace32b2
This commit is contained in:
parent
cc8896a755
commit
d0591b7a26
@ -162,18 +162,21 @@ std::string ExperimentalConvertSavedModelToMlir(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
||||||
const std::string &saved_model_path, const std::string &tags,
|
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||||
bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
|
const std::string &tags, bool upgrade_legacy, bool show_debug_info,
|
||||||
|
TF_Status *status) {
|
||||||
std::unordered_set<string> tag_set =
|
std::unordered_set<string> tag_set =
|
||||||
absl::StrSplit(tags, ',', absl::SkipEmpty());
|
absl::StrSplit(tags, ',', absl::SkipEmpty());
|
||||||
|
|
||||||
|
std::vector<string> exported_names =
|
||||||
|
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
|
||||||
tensorflow::MLIRImportOptions import_options;
|
tensorflow::MLIRImportOptions import_options;
|
||||||
import_options.upgrade_legacy = upgrade_legacy;
|
import_options.upgrade_legacy = upgrade_legacy;
|
||||||
auto module_or = SavedModelSignatureDefsToMlirImportLite(
|
auto module_or = SavedModelSignatureDefsToMlirImportLite(
|
||||||
saved_model_path, tag_set, /*exported_names=*/{}, &context,
|
saved_model_path, tag_set, absl::Span<std::string>(exported_names),
|
||||||
import_options);
|
&context, import_options);
|
||||||
if (!module_or.status().ok()) {
|
if (!module_or.status().ok()) {
|
||||||
Set_TF_Status_from_Status(status, module_or.status());
|
Set_TF_Status_from_Status(status, module_or.status());
|
||||||
return "// error";
|
return "// error";
|
||||||
@ -183,9 +186,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string ExperimentalConvertSavedModelV1ToMlir(
|
std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||||
const std::string &saved_model_path, const std::string &tags,
|
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||||
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
|
const std::string &tags, bool lift_variables, bool upgrade_legacy,
|
||||||
TF_Status *status) {
|
bool show_debug_info, TF_Status *status) {
|
||||||
// Load the saved model into a SavedModelBundle.
|
// Load the saved model into a SavedModelBundle.
|
||||||
|
|
||||||
std::unordered_set<string> tag_set =
|
std::unordered_set<string> tag_set =
|
||||||
@ -200,12 +203,14 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Convert the SavedModelBundle to an MLIR module.
|
// Convert the SavedModelBundle to an MLIR module.
|
||||||
|
std::vector<string> exported_names =
|
||||||
|
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
tensorflow::MLIRImportOptions import_options;
|
tensorflow::MLIRImportOptions import_options;
|
||||||
import_options.upgrade_legacy = upgrade_legacy;
|
import_options.upgrade_legacy = upgrade_legacy;
|
||||||
auto module_or =
|
auto module_or =
|
||||||
ConvertSavedModelV1ToMlir(bundle, {}, &context, import_options);
|
ConvertSavedModelV1ToMlir(bundle, absl::Span<std::string>(exported_names),
|
||||||
|
&context, import_options);
|
||||||
if (!module_or.status().ok()) {
|
if (!module_or.status().ok()) {
|
||||||
Set_TF_Status_from_Status(status, module_or.status());
|
Set_TF_Status_from_Status(status, module_or.status());
|
||||||
return "// error";
|
return "// error";
|
||||||
|
@ -69,8 +69,9 @@ std::string ExperimentalConvertSavedModelToMlir(
|
|||||||
// Returns:
|
// Returns:
|
||||||
// A string of textual MLIR representing the raw imported SavedModel.
|
// A string of textual MLIR representing the raw imported SavedModel.
|
||||||
std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
||||||
const std::string &saved_model_path, const std::string &tags,
|
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||||
bool upgrade_legacy, bool show_debug_info, TF_Status *status);
|
const std::string &tags, bool upgrade_legacy, bool show_debug_info,
|
||||||
|
TF_Status *status);
|
||||||
|
|
||||||
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
|
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
|
||||||
//
|
//
|
||||||
@ -83,9 +84,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
|||||||
// Returns:
|
// Returns:
|
||||||
// A string of textual MLIR representing the raw imported SavedModel.
|
// A string of textual MLIR representing the raw imported SavedModel.
|
||||||
std::string ExperimentalConvertSavedModelV1ToMlir(
|
std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||||
const std::string &saved_model_path, const std::string &tags,
|
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||||
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
|
const std::string &tags, bool lift_variables, bool upgrade_legacy,
|
||||||
TF_Status *status);
|
bool show_debug_info, TF_Status *status);
|
||||||
|
|
||||||
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
|
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
|
||||||
const std::string &pass_pipeline,
|
const std::string &pass_pipeline,
|
||||||
|
@ -102,11 +102,13 @@ def do_test(create_signature,
|
|||||||
logging.info('Saved model to: %s', save_model_path)
|
logging.info('Saved model to: %s', save_model_path)
|
||||||
# TODO(b/153507667): Set the following boolean flag once the hoisting
|
# TODO(b/153507667): Set the following boolean flag once the hoisting
|
||||||
# variables logic from SavedModel importer is removed.
|
# variables logic from SavedModel importer is removed.
|
||||||
|
exported_names = ''
|
||||||
lift_variables = False
|
lift_variables = False
|
||||||
upgrade_legacy = True
|
upgrade_legacy = True
|
||||||
if use_lite:
|
if use_lite:
|
||||||
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
|
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
|
||||||
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
save_model_path, exported_names,
|
||||||
|
','.join([tf.saved_model.tag_constants.SERVING]),
|
||||||
upgrade_legacy, show_debug_info)
|
upgrade_legacy, show_debug_info)
|
||||||
# We don't strictly need this, but it serves as a handy sanity check
|
# We don't strictly need this, but it serves as a handy sanity check
|
||||||
# for that API, which is otherwise a bit annoying to test.
|
# for that API, which is otherwise a bit annoying to test.
|
||||||
@ -116,7 +118,8 @@ def do_test(create_signature,
|
|||||||
show_debug_info)
|
show_debug_info)
|
||||||
else:
|
else:
|
||||||
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
|
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
|
||||||
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
save_model_path, exported_names,
|
||||||
|
','.join([tf.saved_model.tag_constants.SERVING]),
|
||||||
lift_variables, upgrade_legacy, show_debug_info)
|
lift_variables, upgrade_legacy, show_debug_info)
|
||||||
|
|
||||||
if canonicalize:
|
if canonicalize:
|
||||||
|
@ -59,27 +59,29 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
m.def("ExperimentalConvertSavedModelV1ToMlirLite",
|
m.def("ExperimentalConvertSavedModelV1ToMlirLite",
|
||||||
[](const std::string &saved_model_path, const std::string &tags,
|
[](const std::string &saved_model_path,
|
||||||
|
const std::string &exported_names_str, const std::string &tags,
|
||||||
bool upgrade_legacy, bool show_debug_info) {
|
bool upgrade_legacy, bool show_debug_info) {
|
||||||
tensorflow::Safe_TF_StatusPtr status =
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
tensorflow::make_safe(TF_NewStatus());
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
std::string output =
|
std::string output =
|
||||||
tensorflow::ExperimentalConvertSavedModelV1ToMlirLite(
|
tensorflow::ExperimentalConvertSavedModelV1ToMlirLite(
|
||||||
saved_model_path, tags, upgrade_legacy, show_debug_info,
|
saved_model_path, exported_names_str, tags, upgrade_legacy,
|
||||||
status.get());
|
show_debug_info, status.get());
|
||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
return output;
|
return output;
|
||||||
});
|
});
|
||||||
|
|
||||||
m.def("ExperimentalConvertSavedModelV1ToMlir",
|
m.def("ExperimentalConvertSavedModelV1ToMlir",
|
||||||
[](const std::string &saved_model_path, const std::string &tags,
|
[](const std::string &saved_model_path,
|
||||||
|
const std::string &exported_names_str, const std::string &tags,
|
||||||
bool lift_variables, bool upgrade_legacy, bool show_debug_info) {
|
bool lift_variables, bool upgrade_legacy, bool show_debug_info) {
|
||||||
tensorflow::Safe_TF_StatusPtr status =
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
tensorflow::make_safe(TF_NewStatus());
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
std::string output =
|
std::string output =
|
||||||
tensorflow::ExperimentalConvertSavedModelV1ToMlir(
|
tensorflow::ExperimentalConvertSavedModelV1ToMlir(
|
||||||
saved_model_path, tags, lift_variables, upgrade_legacy,
|
saved_model_path, exported_names_str, tags, lift_variables,
|
||||||
show_debug_info, status.get());
|
upgrade_legacy, show_debug_info, status.get());
|
||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
return output;
|
return output;
|
||||||
});
|
});
|
||||||
|
@ -45,19 +45,23 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
|||||||
str(exported_names).encode('utf-8'), show_debug_info)
|
str(exported_names).encode('utf-8'), show_debug_info)
|
||||||
|
|
||||||
|
|
||||||
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags,
|
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path,
|
||||||
|
exported_names, tags,
|
||||||
upgrade_legacy,
|
upgrade_legacy,
|
||||||
show_debug_info):
|
show_debug_info):
|
||||||
return ExperimentalConvertSavedModelV1ToMlirLite(
|
return ExperimentalConvertSavedModelV1ToMlirLite(
|
||||||
str(saved_model_path).encode('utf-8'),
|
str(saved_model_path).encode('utf-8'),
|
||||||
|
str(exported_names).encode('utf-8'),
|
||||||
str(tags).encode('utf-8'), upgrade_legacy, show_debug_info)
|
str(tags).encode('utf-8'), upgrade_legacy, show_debug_info)
|
||||||
|
|
||||||
|
|
||||||
def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
|
def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
|
||||||
|
exported_names, tags,
|
||||||
lift_variables, upgrade_legacy,
|
lift_variables, upgrade_legacy,
|
||||||
show_debug_info):
|
show_debug_info):
|
||||||
return ExperimentalConvertSavedModelV1ToMlir(
|
return ExperimentalConvertSavedModelV1ToMlir(
|
||||||
str(saved_model_path).encode('utf-8'),
|
str(saved_model_path).encode('utf-8'),
|
||||||
|
str(exported_names).encode('utf-8'),
|
||||||
str(tags).encode('utf-8'), lift_variables, upgrade_legacy,
|
str(tags).encode('utf-8'), lift_variables, upgrade_legacy,
|
||||||
show_debug_info)
|
show_debug_info)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user