Add 'exported_names' to V1 saved model MLIR converter and update the pybindings.

PiperOrigin-RevId: 359677950
Change-Id: I9922ee1a9b1b91f208b3e12f6300f3b41ace32b2
This commit is contained in:
A. Unique TensorFlower 2021-02-25 20:46:16 -08:00 committed by TensorFlower Gardener
parent cc8896a755
commit d0591b7a26
5 changed files with 39 additions and 24 deletions

View File

@ -162,18 +162,21 @@ std::string ExperimentalConvertSavedModelToMlir(
}
std::string ExperimentalConvertSavedModelV1ToMlirLite(
const std::string &saved_model_path, const std::string &tags,
bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
const std::string &saved_model_path, const std::string &exported_names_str,
const std::string &tags, bool upgrade_legacy, bool show_debug_info,
TF_Status *status) {
std::unordered_set<string> tag_set =
absl::StrSplit(tags, ',', absl::SkipEmpty());
std::vector<string> exported_names =
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
mlir::MLIRContext context;
tensorflow::MLIRImportOptions import_options;
import_options.upgrade_legacy = upgrade_legacy;
auto module_or = SavedModelSignatureDefsToMlirImportLite(
saved_model_path, tag_set, /*exported_names=*/{}, &context,
import_options);
saved_model_path, tag_set, absl::Span<std::string>(exported_names),
&context, import_options);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";
@ -183,9 +186,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
}
std::string ExperimentalConvertSavedModelV1ToMlir(
const std::string &saved_model_path, const std::string &tags,
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
TF_Status *status) {
const std::string &saved_model_path, const std::string &exported_names_str,
const std::string &tags, bool lift_variables, bool upgrade_legacy,
bool show_debug_info, TF_Status *status) {
// Load the saved model into a SavedModelBundle.
std::unordered_set<string> tag_set =
@ -200,12 +203,14 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
}
// Convert the SavedModelBundle to an MLIR module.
std::vector<string> exported_names =
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
mlir::MLIRContext context;
tensorflow::MLIRImportOptions import_options;
import_options.upgrade_legacy = upgrade_legacy;
auto module_or =
ConvertSavedModelV1ToMlir(bundle, {}, &context, import_options);
ConvertSavedModelV1ToMlir(bundle, absl::Span<std::string>(exported_names),
&context, import_options);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";

View File

@ -69,8 +69,9 @@ std::string ExperimentalConvertSavedModelToMlir(
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
std::string ExperimentalConvertSavedModelV1ToMlirLite(
const std::string &saved_model_path, const std::string &tags,
bool upgrade_legacy, bool show_debug_info, TF_Status *status);
const std::string &saved_model_path, const std::string &exported_names_str,
const std::string &tags, bool upgrade_legacy, bool show_debug_info,
TF_Status *status);
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
//
@ -83,9 +84,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
std::string ExperimentalConvertSavedModelV1ToMlir(
const std::string &saved_model_path, const std::string &tags,
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
TF_Status *status);
const std::string &saved_model_path, const std::string &exported_names_str,
const std::string &tags, bool lift_variables, bool upgrade_legacy,
bool show_debug_info, TF_Status *status);
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
const std::string &pass_pipeline,

View File

@ -102,11 +102,13 @@ def do_test(create_signature,
logging.info('Saved model to: %s', save_model_path)
# TODO(b/153507667): Set the following boolean flag once the hoisting
# variables logic from SavedModel importer is removed.
exported_names = ''
lift_variables = False
upgrade_legacy = True
if use_lite:
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
save_model_path, exported_names,
','.join([tf.saved_model.tag_constants.SERVING]),
upgrade_legacy, show_debug_info)
# We don't strictly need this, but it serves as a handy sanity check
# for that API, which is otherwise a bit annoying to test.
@ -116,7 +118,8 @@ def do_test(create_signature,
show_debug_info)
else:
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
save_model_path, exported_names,
','.join([tf.saved_model.tag_constants.SERVING]),
lift_variables, upgrade_legacy, show_debug_info)
if canonicalize:

View File

@ -59,27 +59,29 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
});
m.def("ExperimentalConvertSavedModelV1ToMlirLite",
[](const std::string &saved_model_path, const std::string &tags,
[](const std::string &saved_model_path,
const std::string &exported_names_str, const std::string &tags,
bool upgrade_legacy, bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output =
tensorflow::ExperimentalConvertSavedModelV1ToMlirLite(
saved_model_path, tags, upgrade_legacy, show_debug_info,
status.get());
saved_model_path, exported_names_str, tags, upgrade_legacy,
show_debug_info, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("ExperimentalConvertSavedModelV1ToMlir",
[](const std::string &saved_model_path, const std::string &tags,
[](const std::string &saved_model_path,
const std::string &exported_names_str, const std::string &tags,
bool lift_variables, bool upgrade_legacy, bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output =
tensorflow::ExperimentalConvertSavedModelV1ToMlir(
saved_model_path, tags, lift_variables, upgrade_legacy,
show_debug_info, status.get());
saved_model_path, exported_names_str, tags, lift_variables,
upgrade_legacy, show_debug_info, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});

View File

@ -45,19 +45,23 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
str(exported_names).encode('utf-8'), show_debug_info)
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags,
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path,
exported_names, tags,
upgrade_legacy,
show_debug_info):
return ExperimentalConvertSavedModelV1ToMlirLite(
str(saved_model_path).encode('utf-8'),
str(exported_names).encode('utf-8'),
str(tags).encode('utf-8'), upgrade_legacy, show_debug_info)
def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
exported_names, tags,
lift_variables, upgrade_legacy,
show_debug_info):
return ExperimentalConvertSavedModelV1ToMlir(
str(saved_model_path).encode('utf-8'),
str(exported_names).encode('utf-8'),
str(tags).encode('utf-8'), lift_variables, upgrade_legacy,
show_debug_info)