From a407b1f41f8f4afab54b6ff745da04e5ac376681 Mon Sep 17 00:00:00 2001
From: Kuangyuan Chen <chky@google.com>
Date: Thu, 5 Nov 2020 16:02:44 -0800
Subject: [PATCH] Add the entry point for SavedModelSignatureDefImporterLite in
 tf-mlir-translate and relevant python wrappers.

PiperOrigin-RevId: 340945906
Change-Id: I54697b98c18065f829f7f85383512b4c1a460a22
---
 tensorflow/compiler/mlir/python/BUILD         |  1 +
 tensorflow/compiler/mlir/python/mlir.cc       | 20 ++++++++++++++++
 tensorflow/compiler/mlir/python/mlir.h        | 15 ++++++++++++
 tensorflow/compiler/mlir/tensorflow/BUILD     |  2 +-
 .../tests/tf_saved_model/common_v1.py         | 24 +++++++++++++++----
 .../tf_saved_model/hash_table_asset_v1.py     |  6 ++---
 .../mlir/tensorflow/translate/import_model.cc |  8 +++++++
 .../mlir/tensorflow/translate/import_model.h  | 10 ++++++++
 .../tensorflow/translate/tf_mlir_translate.cc | 24 +++++++++++++++++++
 .../tensorflow/translate/tf_mlir_translate.h  | 10 ++++++++
 .../compiler/mlir/tf_mlir_translate_main.cc   | 23 +++++++++++++++---
 tensorflow/python/mlir_wrapper.cc             | 13 ++++++++++
 tensorflow/python/pywrap_mlir.py              | 14 +++++++----
 .../tools/def_file_filter/symbols_pybind.txt  |  1 +
 14 files changed, 156 insertions(+), 15 deletions(-)

diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD
index 502695acd40..0bc2acf7e4a 100644
--- a/tensorflow/compiler/mlir/python/BUILD
+++ b/tensorflow/compiler/mlir/python/BUILD
@@ -40,6 +40,7 @@ cc_library(
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
+        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
     ],
diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc
index 066726593a7..94fb78b3bd3 100644
--- a/tensorflow/compiler/mlir/python/mlir.cc
+++ b/tensorflow/compiler/mlir/python/mlir.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
 #include "tensorflow/core/framework/function.h"
@@ -148,6 +149,25 @@ std::string ExperimentalConvertSavedModelToMlir(
   return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
 }
 
+std::string ExperimentalConvertSavedModelV1ToMlirLite(
+    const std::string &saved_model_path, const std::string &tags,
+    bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
+  std::unordered_set<string> tag_set =
+      absl::StrSplit(tags, ',', absl::SkipEmpty());
+
+  mlir::MLIRContext context;
+
+  auto module_or = SavedModelSignatureDefsToMlirImportLite(
+      saved_model_path, tag_set, /*exported_names=*/{}, &context,
+      upgrade_legacy);
+  if (!module_or.status().ok()) {
+    Set_TF_Status_from_Status(status, module_or.status());
+    return "// error";
+  }
+
+  return MlirModuleToString(*module_or.ValueOrDie(), show_debug_info);
+}
+
 std::string ExperimentalConvertSavedModelV1ToMlir(
     const std::string &saved_model_path, const std::string &tags,
     bool lift_variables, bool upgrade_legacy, bool show_debug_info,
diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h
index 6133068a5e8..37560d04bf8 100644
--- a/tensorflow/compiler/mlir/python/mlir.h
+++ b/tensorflow/compiler/mlir/python/mlir.h
@@ -55,6 +55,21 @@ std::string ExperimentalConvertSavedModelToMlir(
     const std::string &saved_model_path, const std::string &exported_names_str,
     bool show_debug_info, TF_Status *status);
 
+// Load a SavedModel V1 and return a textual MLIR string corresponding to it
+// without any MLIR graph transformation.
+//
+// Args:
+//   saved_model_path: File path from which to load the SavedModel.
+//   tags: Tags to identify MetaGraphDef that need to be loaded.
+//   upgrade_legacy: Boolean flag that indicates whether to upgrade legacy
+//                   graphs
+//
+// Returns:
+//   A string of textual MLIR representing the raw imported SavedModel.
+std::string ExperimentalConvertSavedModelV1ToMlirLite(
+    const std::string &saved_model_path, const std::string &tags,
+    bool upgrade_legacy, bool show_debug_info, TF_Status *status);
+
 // Load a SavedModel V1 and return a textual MLIR string corresponding to it.
 //
 // Args:
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index d02f0373c79..443c651a8c4 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1509,6 +1509,7 @@ cc_library(
         ":mangling_util",
         ":mlir_roundtrip_flags",
         "//tensorflow/cc/saved_model:bundle_v2",
+        "//tensorflow/cc/saved_model:reader",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_proto_parsing",
@@ -1523,7 +1524,6 @@ cc_library(
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
     ],
 )
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
index 7a61b4b4f6a..68cb3258b82 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
@@ -46,7 +46,10 @@ def set_tf_options():
 # This function needs to take a "create_module_fn", as opposed to just the
 # module itself, because the creation of the module has to be delayed until
 # after absl and tensorflow have run various initialization steps.
-def do_test(create_signature, canonicalize=False, show_debug_info=False):
+def do_test(create_signature,
+            canonicalize=False,
+            show_debug_info=False,
+            use_lite=False):
   """Runs test.
 
   1. Performs absl and tf "main"-like initialization that must run before almost
@@ -65,6 +68,8 @@ def do_test(create_signature, canonicalize=False, show_debug_info=False):
       MLIR.
     canonicalize: If true, canonicalizer will be run on the resulting MLIR.
     show_debug_info: If true, shows debug locations in the resulting MLIR.
+    use_lite: If true, importer will not do any graph transformation such as
+      lift variables.
   """
 
   # Make LOG(ERROR) in C++ code show up on the console.
@@ -99,9 +104,20 @@ def do_test(create_signature, canonicalize=False, show_debug_info=False):
     #                    variables logic from SavedModel importer is removed.
     lift_variables = False
     upgrade_legacy = True
-    mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
-        save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
-        lift_variables, upgrade_legacy, show_debug_info)
+    if use_lite:
+      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
+          save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
+          upgrade_legacy, show_debug_info)
+      # We don't strictly need this, but it serves as a handy sanity check
+      # for that API, which is otherwise a bit annoying to test.
+      # The canonicalization shouldn't affect these tests in any way.
+      mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
+                                                        'tf-standard-pipeline',
+                                                        show_debug_info)
+    else:
+      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
+          save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
+          lift_variables, upgrade_legacy, show_debug_info)
 
     if canonicalize:
       mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py
index 4cb931253b3..3714c610afd 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py
@@ -57,8 +57,8 @@ def test():
   # Incur another bound_input on the asset, but with a different sym_name, i.e.,
   # __tf_saved_model_asset1_tokens.txt vs. __tf_saved_model_asset0_tokens.txt.
   table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10)
-  vocab_file_tensor = tf.convert_to_tensor(vocabulary_file, tf.string,
-                                           name='asset_filepath')
+  vocab_file_tensor = tf.convert_to_tensor(
+      vocabulary_file, tf.string, name='asset_filepath')
   tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file_tensor)
 
   x = tf.placeholder(tf.string, shape=(), name='input')
@@ -77,4 +77,4 @@ def test():
 
 if __name__ == '__main__':
   common_v1.set_tf_options()
-  common_v1.do_test(test)
+  common_v1.do_test(test, use_lite=True)
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 6b604fd16b1..ddc74fe922a 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -3712,6 +3712,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
                                                  context, upgrade_legacy);
 }
 
+StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
+    const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    bool upgrade_legacy) {
+  return SavedModelSignatureDefImporterLite::Convert(
+      meta_graph_def, debug_info, exported_names, context, upgrade_legacy);
+}
+
 std::string MlirModuleToString(mlir::ModuleOp module,
                                mlir::OpPrintingFlags flags) {
   std::string txt_module;
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
index 46562848df1..fb7e0311486 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
@@ -68,6 +68,16 @@ ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
                           mlir::MLIRContext* context,
                           bool upgrade_legacy = false);
 
+// Given a V1 SavedModel, returns a MLIR module containing the functions,
+// expressed with tf_executor dialect. It does not require a session to be
+// created and it does not perform any graph transformation.
+stream_executor::port::StatusOr<mlir::OwningModuleRef>
+ConvertSavedModelV1ToMlirLite(const MetaGraphDef& meta_graph_def,
+                              const GraphDebugInfo& debug_info,
+                              absl::Span<std::string> exported_names,
+                              mlir::MLIRContext* context,
+                              bool upgrade_legacy = false);
+
 // Serialize a MLIR module to a string.
 std::string MlirModuleToString(mlir::ModuleOp module,
                                mlir::OpPrintingFlags flags);
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
index 58377661a23..04bb87fc88f 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "tensorflow/cc/saved_model/bundle_v2.h"
+#include "tensorflow/cc/saved_model/reader.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
@@ -191,6 +192,29 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
   return module_or;
 }
 
+StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
+    absl::string_view saved_model_dir,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    bool upgrade_legacy) {
+  MetaGraphDef meta_graph_def;
+  auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
+                                               tags, &meta_graph_def);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
+               << "': " << status;
+    return status;
+  }
+
+  auto module_or =
+      ConvertSavedModelV1ToMlirLite(meta_graph_def, /*debug_info=*/{},
+                                    exported_names, context, upgrade_legacy);
+  if (!module_or.status().ok()) {
+    LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
+  }
+  return module_or;
+}
+
 StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
     llvm::StringRef input, absl::string_view debug_info_file,
     const std::vector<std::string>& input_arrays,
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
index 0dc49d70192..12d54a747f3 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
@@ -106,6 +106,16 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
     bool upgrade_legacy = false);
 
+// Converts a TensorFlow V1 SavedModel stored in the directory with the given
+// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
+// given MLIR `context`. This does not create session internally so it is faster
+// and does not perform any graph transformation.
+StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
+    absl::string_view saved_model_dir,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    bool upgrade_legacy = false);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_
diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
index a60ac4ed222..2871358b733 100644
--- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
@@ -63,6 +63,13 @@ static llvm::cl::opt<bool> import_saved_model_signature_defs(
         "Import a saved model's SignatureDefs to their MLIR representation"),
     llvm::cl::value_desc("dir"));
 
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> import_saved_model_signature_defs_lite(
+    "savedmodel-signaturedefs-to-mlir-lite",
+    llvm::cl::desc("Import a saved model's SignatureDefs to to their MLIR "
+                   "representation without any graph transformation"),
+    llvm::cl::value_desc("dir"));
+
 // NOLINTNEXTLINE
 static llvm::cl::opt<std::string> saved_model_tags(
     "tf-savedmodel-tags",
@@ -87,11 +94,14 @@ int main(int argc, char** argv) {
   llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n");
 
   if (!import_saved_model_object_graph && !import_saved_model_signature_defs &&
-      !requested_translation) {
+      !import_saved_model_signature_defs_lite && !requested_translation) {
     llvm::errs() << "error: need to specify one translation to perform\n";
     return 1;
-  } else if (import_saved_model_object_graph &&
-             import_saved_model_signature_defs && requested_translation) {
+  } else if (import_saved_model_object_graph +
+                 import_saved_model_signature_defs +
+                 import_saved_model_signature_defs_lite +
+                 (requested_translation != nullptr) >
+             1) {
     llvm::errs()
         << "error: cannot specify more than one translation to perform\n";
     return 1;
@@ -122,6 +132,13 @@ int main(int argc, char** argv) {
         input_filename, tags, exported_names, &context, upgrade_legacy);
     if (!module_or.status().ok()) return 1;
 
+    module_or.ConsumeValueOrDie()->print(output->os());
+  } else if (import_saved_model_signature_defs_lite) {
+    mlir::MLIRContext context;
+    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImportLite(
+        input_filename, tags, exported_names, &context, upgrade_legacy);
+    if (!module_or.status().ok()) return 1;
+
     module_or.ConsumeValueOrDie()->print(output->os());
   } else {
     auto input = mlir::openInputFile(input_filename, &error_message);
diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc
index fa16e5872ee..7feda116984 100644
--- a/tensorflow/python/mlir_wrapper.cc
+++ b/tensorflow/python/mlir_wrapper.cc
@@ -53,6 +53,19 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
           return output;
         });
 
+  m.def("ExperimentalConvertSavedModelV1ToMlirLite",
+        [](const std::string &saved_model_path, const std::string &tags,
+           bool upgrade_legacy, bool show_debug_info) {
+          tensorflow::Safe_TF_StatusPtr status =
+              tensorflow::make_safe(TF_NewStatus());
+          std::string output =
+              tensorflow::ExperimentalConvertSavedModelV1ToMlirLite(
+                  saved_model_path, tags, upgrade_legacy, show_debug_info,
+                  status.get());
+          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+          return output;
+        });
+
   m.def("ExperimentalConvertSavedModelV1ToMlir",
         [](const std::string &saved_model_path, const std::string &tags,
            bool lift_variables, bool upgrade_legacy, bool show_debug_info) {
diff --git a/tensorflow/python/pywrap_mlir.py b/tensorflow/python/pywrap_mlir.py
index 82048140e16..6db68f0e581 100644
--- a/tensorflow/python/pywrap_mlir.py
+++ b/tensorflow/python/pywrap_mlir.py
@@ -25,8 +25,7 @@ from tensorflow.python._pywrap_mlir import *
 
 def import_graphdef(graphdef, pass_pipeline):
   return ImportGraphDef(
-      str(graphdef).encode('utf-8'),
-      pass_pipeline.encode('utf-8'))
+      str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8'))
 
 
 def import_function(concrete_function, pass_pipeline):
@@ -43,6 +42,14 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
       str(exported_names).encode('utf-8'), show_debug_info)
 
 
+def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags,
+                                                     upgrade_legacy,
+                                                     show_debug_info):
+  return ExperimentalConvertSavedModelV1ToMlirLite(
+      str(saved_model_path).encode('utf-8'),
+      str(tags).encode('utf-8'), upgrade_legacy, show_debug_info)
+
+
 def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
                                                 lift_variables, upgrade_legacy,
                                                 show_debug_info):
@@ -54,5 +61,4 @@ def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
 
 def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
   return ExperimentalRunPassPipeline(
-      mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'),
-      show_debug_info)
+      mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'), show_debug_info)
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index af5b1a104f4..ebe1427ba71 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -222,6 +222,7 @@ tensorflow::EagerContext::WaitForAndCloseRemoteContexts
 
 [mlir] # mlir
 tensorflow::ExperimentalRunPassPipeline
+tensorflow::ExperimentalConvertSavedModelV1ToMlirLite
 tensorflow::ExperimentalConvertSavedModelV1ToMlir
 tensorflow::ExperimentalConvertSavedModelToMlir
 tensorflow::ImportGraphDef