From d0591b7a269a682d62eedf7e759acd8052a7164e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Feb 2021 20:46:16 -0800 Subject: [PATCH] Add 'exported_names' to V1 saved model MLIR converter and update the pybindings. PiperOrigin-RevId: 359677950 Change-Id: I9922ee1a9b1b91f208b3e12f6300f3b41ace32b2 --- tensorflow/compiler/mlir/python/mlir.cc | 23 +++++++++++-------- tensorflow/compiler/mlir/python/mlir.h | 11 +++++---- .../tests/tf_saved_model/common_v1.py | 7 ++++-- tensorflow/python/mlir_wrapper.cc | 14 ++++++----- tensorflow/python/pywrap_mlir.py | 8 +++++-- 5 files changed, 39 insertions(+), 24 deletions(-) diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index c0d80c64e04..b39caa285c4 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -162,18 +162,21 @@ std::string ExperimentalConvertSavedModelToMlir( } std::string ExperimentalConvertSavedModelV1ToMlirLite( - const std::string &saved_model_path, const std::string &tags, - bool upgrade_legacy, bool show_debug_info, TF_Status *status) { + const std::string &saved_model_path, const std::string &exported_names_str, + const std::string &tags, bool upgrade_legacy, bool show_debug_info, + TF_Status *status) { std::unordered_set tag_set = absl::StrSplit(tags, ',', absl::SkipEmpty()); + std::vector exported_names = + absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::MLIRContext context; tensorflow::MLIRImportOptions import_options; import_options.upgrade_legacy = upgrade_legacy; auto module_or = SavedModelSignatureDefsToMlirImportLite( - saved_model_path, tag_set, /*exported_names=*/{}, &context, - import_options); + saved_model_path, tag_set, absl::Span(exported_names), + &context, import_options); if (!module_or.status().ok()) { Set_TF_Status_from_Status(status, module_or.status()); return "// error"; @@ -183,9 +186,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite( } std::string ExperimentalConvertSavedModelV1ToMlir( - const std::string &saved_model_path, const std::string &tags, - bool lift_variables, bool upgrade_legacy, bool show_debug_info, - TF_Status *status) { + const std::string &saved_model_path, const std::string &exported_names_str, + const std::string &tags, bool lift_variables, bool upgrade_legacy, + bool show_debug_info, TF_Status *status) { // Load the saved model into a SavedModelBundle. std::unordered_set tag_set = @@ -200,12 +203,14 @@ std::string ExperimentalConvertSavedModelV1ToMlir( } // Convert the SavedModelBundle to an MLIR module. - + std::vector exported_names = + absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::MLIRContext context; tensorflow::MLIRImportOptions import_options; import_options.upgrade_legacy = upgrade_legacy; auto module_or = - ConvertSavedModelV1ToMlir(bundle, {}, &context, import_options); + ConvertSavedModelV1ToMlir(bundle, absl::Span(exported_names), + &context, import_options); if (!module_or.status().ok()) { Set_TF_Status_from_Status(status, module_or.status()); return "// error"; diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h index 54a0b96cb16..af443cc6593 100644 --- a/tensorflow/compiler/mlir/python/mlir.h +++ b/tensorflow/compiler/mlir/python/mlir.h @@ -69,8 +69,9 @@ std::string ExperimentalConvertSavedModelToMlir( // Returns: // A string of textual MLIR representing the raw imported SavedModel. std::string ExperimentalConvertSavedModelV1ToMlirLite( - const std::string &saved_model_path, const std::string &tags, - bool upgrade_legacy, bool show_debug_info, TF_Status *status); + const std::string &saved_model_path, const std::string &exported_names_str, + const std::string &tags, bool upgrade_legacy, bool show_debug_info, + TF_Status *status); // Load a SavedModel V1 and return a textual MLIR string corresponding to it. // @@ -83,9 +84,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite( // Returns: // A string of textual MLIR representing the raw imported SavedModel. std::string ExperimentalConvertSavedModelV1ToMlir( - const std::string &saved_model_path, const std::string &tags, - bool lift_variables, bool upgrade_legacy, bool show_debug_info, - TF_Status *status); + const std::string &saved_model_path, const std::string &exported_names_str, + const std::string &tags, bool lift_variables, bool upgrade_legacy, + bool show_debug_info, TF_Status *status); std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, const std::string &pass_pipeline, diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py index 68cb3258b82..504d22c4541 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py @@ -102,11 +102,13 @@ def do_test(create_signature, logging.info('Saved model to: %s', save_model_path) # TODO(b/153507667): Set the following boolean flag once the hoisting # variables logic from SavedModel importer is removed. + exported_names = '' lift_variables = False upgrade_legacy = True if use_lite: mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite( - save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]), + save_model_path, exported_names, + ','.join([tf.saved_model.tag_constants.SERVING]), upgrade_legacy, show_debug_info) # We don't strictly need this, but it serves as a handy sanity check # for that API, which is otherwise a bit annoying to test. @@ -116,7 +118,8 @@ def do_test(create_signature, show_debug_info) else: mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir( - save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]), + save_model_path, exported_names, + ','.join([tf.saved_model.tag_constants.SERVING]), lift_variables, upgrade_legacy, show_debug_info) if canonicalize: diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc index 91ac33956a2..96de0ce47ad 100644 --- a/tensorflow/python/mlir_wrapper.cc +++ b/tensorflow/python/mlir_wrapper.cc @@ -59,27 +59,29 @@ PYBIND11_MODULE(_pywrap_mlir, m) { }); m.def("ExperimentalConvertSavedModelV1ToMlirLite", - [](const std::string &saved_model_path, const std::string &tags, + [](const std::string &saved_model_path, + const std::string &exported_names_str, const std::string &tags, bool upgrade_legacy, bool show_debug_info) { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); std::string output = tensorflow::ExperimentalConvertSavedModelV1ToMlirLite( - saved_model_path, tags, upgrade_legacy, show_debug_info, - status.get()); + saved_model_path, exported_names_str, tags, upgrade_legacy, + show_debug_info, status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); return output; }); m.def("ExperimentalConvertSavedModelV1ToMlir", - [](const std::string &saved_model_path, const std::string &tags, + [](const std::string &saved_model_path, + const std::string &exported_names_str, const std::string &tags, bool lift_variables, bool upgrade_legacy, bool show_debug_info) { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); std::string output = tensorflow::ExperimentalConvertSavedModelV1ToMlir( - saved_model_path, tags, lift_variables, upgrade_legacy, - show_debug_info, status.get()); + saved_model_path, exported_names_str, tags, lift_variables, + upgrade_legacy, show_debug_info, status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); return output; }); diff --git a/tensorflow/python/pywrap_mlir.py b/tensorflow/python/pywrap_mlir.py index 3798c9a1670..ce6076a0ea5 100644 --- a/tensorflow/python/pywrap_mlir.py +++ b/tensorflow/python/pywrap_mlir.py @@ -45,19 +45,23 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names, str(exported_names).encode('utf-8'), show_debug_info) -def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags, +def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, + exported_names, tags, upgrade_legacy, show_debug_info): return ExperimentalConvertSavedModelV1ToMlirLite( str(saved_model_path).encode('utf-8'), + str(exported_names).encode('utf-8'), str(tags).encode('utf-8'), upgrade_legacy, show_debug_info) -def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags, +def experimental_convert_saved_model_v1_to_mlir(saved_model_path, + exported_names, tags, lift_variables, upgrade_legacy, show_debug_info): return ExperimentalConvertSavedModelV1ToMlir( str(saved_model_path).encode('utf-8'), + str(exported_names).encode('utf-8'), str(tags).encode('utf-8'), lift_variables, upgrade_legacy, show_debug_info)