From a407b1f41f8f4afab54b6ff745da04e5ac376681 Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen <chky@google.com> Date: Thu, 5 Nov 2020 16:02:44 -0800 Subject: [PATCH] Add the entry point for SavedModelSignatureDefImporterLite in tf-mlir-translate and relevant python wrappers. PiperOrigin-RevId: 340945906 Change-Id: I54697b98c18065f829f7f85383512b4c1a460a22 --- tensorflow/compiler/mlir/python/BUILD | 1 + tensorflow/compiler/mlir/python/mlir.cc | 20 ++++++++++++++++ tensorflow/compiler/mlir/python/mlir.h | 15 ++++++++++++ tensorflow/compiler/mlir/tensorflow/BUILD | 2 +- .../tests/tf_saved_model/common_v1.py | 24 +++++++++++++++---- .../tf_saved_model/hash_table_asset_v1.py | 6 ++--- .../mlir/tensorflow/translate/import_model.cc | 8 +++++++ .../mlir/tensorflow/translate/import_model.h | 10 ++++++++ .../tensorflow/translate/tf_mlir_translate.cc | 24 +++++++++++++++++++ .../tensorflow/translate/tf_mlir_translate.h | 10 ++++++++ .../compiler/mlir/tf_mlir_translate_main.cc | 23 +++++++++++++++--- tensorflow/python/mlir_wrapper.cc | 13 ++++++++++ tensorflow/python/pywrap_mlir.py | 14 +++++++---- .../tools/def_file_filter/symbols_pybind.txt | 1 + 14 files changed, 156 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 502695acd40..0bc2acf7e4a 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -40,6 +40,7 @@ cc_library( "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 066726593a7..94fb78b3bd3 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" #include "tensorflow/core/framework/function.h" @@ -148,6 +149,25 @@ std::string ExperimentalConvertSavedModelToMlir( return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info); } +std::string ExperimentalConvertSavedModelV1ToMlirLite( + const std::string &saved_model_path, const std::string &tags, + bool upgrade_legacy, bool show_debug_info, TF_Status *status) { + std::unordered_set<string> tag_set = + absl::StrSplit(tags, ',', absl::SkipEmpty()); + + mlir::MLIRContext context; + + auto module_or = SavedModelSignatureDefsToMlirImportLite( + saved_model_path, tag_set, /*exported_names=*/{}, &context, + upgrade_legacy); + if (!module_or.status().ok()) { + Set_TF_Status_from_Status(status, module_or.status()); + return "// error"; + } + + return MlirModuleToString(*module_or.ValueOrDie(), show_debug_info); +} + std::string ExperimentalConvertSavedModelV1ToMlir( const std::string &saved_model_path, const std::string &tags, bool lift_variables, bool upgrade_legacy, bool show_debug_info, diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h index 6133068a5e8..37560d04bf8 100644 --- a/tensorflow/compiler/mlir/python/mlir.h +++ b/tensorflow/compiler/mlir/python/mlir.h @@ -55,6 +55,21 @@ std::string ExperimentalConvertSavedModelToMlir( const std::string &saved_model_path, const std::string &exported_names_str, bool show_debug_info, TF_Status *status); +// Load a SavedModel V1 and return a textual MLIR string corresponding to it +// without any MLIR graph transformation. +// +// Args: +// saved_model_path: File path from which to load the SavedModel. +// tags: Tags to identify MetaGraphDef that need to be loaded. +// upgrade_legacy: Boolean flag that indicates whether to upgrade legacy +// graphs +// +// Returns: +// A string of textual MLIR representing the raw imported SavedModel. +std::string ExperimentalConvertSavedModelV1ToMlirLite( + const std::string &saved_model_path, const std::string &tags, + bool upgrade_legacy, bool show_debug_info, TF_Status *status); + // Load a SavedModel V1 and return a textual MLIR string corresponding to it. // // Args: diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index d02f0373c79..443c651a8c4 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1509,6 +1509,7 @@ cc_library( ":mangling_util", ":mlir_roundtrip_flags", "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:reader", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", @@ -1523,7 +1524,6 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py index 7a61b4b4f6a..68cb3258b82 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py @@ -46,7 +46,10 @@ def set_tf_options(): # This function needs to take a "create_module_fn", as opposed to just the # module itself, because the creation of the module has to be delayed until # after absl and tensorflow have run various initialization steps. -def do_test(create_signature, canonicalize=False, show_debug_info=False): +def do_test(create_signature, + canonicalize=False, + show_debug_info=False, + use_lite=False): """Runs test. 1. Performs absl and tf "main"-like initialization that must run before almost @@ -65,6 +68,8 @@ def do_test(create_signature, canonicalize=False, show_debug_info=False): MLIR. canonicalize: If true, canonicalizer will be run on the resulting MLIR. show_debug_info: If true, shows debug locations in the resulting MLIR. + use_lite: If true, importer will not do any graph transformation such as + lift variables. """ # Make LOG(ERROR) in C++ code show up on the console. @@ -99,9 +104,20 @@ def do_test(create_signature, canonicalize=False, show_debug_info=False): # variables logic from SavedModel importer is removed. lift_variables = False upgrade_legacy = True - mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir( - save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]), - lift_variables, upgrade_legacy, show_debug_info) + if use_lite: + mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite( + save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]), + upgrade_legacy, show_debug_info) + # We don't strictly need this, but it serves as a handy sanity check + # for that API, which is otherwise a bit annoying to test. + # The canonicalization shouldn't affect these tests in any way. + mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, + 'tf-standard-pipeline', + show_debug_info) + else: + mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir( + save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]), + lift_variables, upgrade_legacy, show_debug_info) if canonicalize: mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize', diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py index 4cb931253b3..3714c610afd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py @@ -57,8 +57,8 @@ def test(): # Incur another bound_input on the asset, but with a different sym_name, i.e., # __tf_saved_model_asset1_tokens.txt vs. __tf_saved_model_asset0_tokens.txt. table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10) - vocab_file_tensor = tf.convert_to_tensor(vocabulary_file, tf.string, - name='asset_filepath') + vocab_file_tensor = tf.convert_to_tensor( + vocabulary_file, tf.string, name='asset_filepath') tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file_tensor) x = tf.placeholder(tf.string, shape=(), name='input') @@ -77,4 +77,4 @@ def test(): if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(test) + common_v1.do_test(test, use_lite=True) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 6b604fd16b1..ddc74fe922a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -3712,6 +3712,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir( context, upgrade_legacy); } +StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite( + const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info, + absl::Span<std::string> exported_names, mlir::MLIRContext* context, + bool upgrade_legacy) { + return SavedModelSignatureDefImporterLite::Convert( + meta_graph_def, debug_info, exported_names, context, upgrade_legacy); +} + std::string MlirModuleToString(mlir::ModuleOp module, mlir::OpPrintingFlags flags) { std::string txt_module; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 46562848df1..fb7e0311486 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -68,6 +68,16 @@ ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, mlir::MLIRContext* context, bool upgrade_legacy = false); +// Given a V1 SavedModel, returns a MLIR module containing the functions, +// expressed with tf_executor dialect. It does not require a session to be +// created and it does not perform any graph transformation. +stream_executor::port::StatusOr<mlir::OwningModuleRef> +ConvertSavedModelV1ToMlirLite(const MetaGraphDef& meta_graph_def, + const GraphDebugInfo& debug_info, + absl::Span<std::string> exported_names, + mlir::MLIRContext* context, + bool upgrade_legacy = false); + // Serialize a MLIR module to a string. std::string MlirModuleToString(mlir::ModuleOp module, mlir::OpPrintingFlags flags); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 58377661a23..04bb87fc88f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" +#include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -191,6 +192,29 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport( return module_or; } +StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite( + absl::string_view saved_model_dir, + const std::unordered_set<std::string>& tags, + absl::Span<std::string> exported_names, mlir::MLIRContext* context, + bool upgrade_legacy) { + MetaGraphDef meta_graph_def; + auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir), + tags, &meta_graph_def); + if (!status.ok()) { + LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir + << "': " << status; + return status; + } + + auto module_or = + ConvertSavedModelV1ToMlirLite(meta_graph_def, /*debug_info=*/{}, + exported_names, context, upgrade_legacy); + if (!module_or.status().ok()) { + LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status(); + } + return module_or; +} + StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, const std::vector<std::string>& input_arrays, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index 0dc49d70192..12d54a747f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -106,6 +106,16 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport( absl::Span<std::string> exported_names, mlir::MLIRContext* context, bool upgrade_legacy = false); +// Converts a TensorFlow V1 SavedModel stored in the directory with the given +// `saved_model_dir` into a MLIR module. Creates MLIR entities into the +// given MLIR `context`. This does not create session internally so it is faster +// and does not perform any graph transformation. +StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite( + absl::string_view saved_model_dir, + const std::unordered_set<std::string>& tags, + absl::Span<std::string> exported_names, mlir::MLIRContext* context, + bool upgrade_legacy = false); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index a60ac4ed222..2871358b733 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -63,6 +63,13 @@ static llvm::cl::opt<bool> import_saved_model_signature_defs( "Import a saved model's SignatureDefs to their MLIR representation"), llvm::cl::value_desc("dir")); +// NOLINTNEXTLINE +static llvm::cl::opt<bool> import_saved_model_signature_defs_lite( + "savedmodel-signaturedefs-to-mlir-lite", + llvm::cl::desc("Import a saved model's SignatureDefs to to their MLIR " + "representation without any graph transformation"), + llvm::cl::value_desc("dir")); + // NOLINTNEXTLINE static llvm::cl::opt<std::string> saved_model_tags( "tf-savedmodel-tags", @@ -87,11 +94,14 @@ int main(int argc, char** argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n"); if (!import_saved_model_object_graph && !import_saved_model_signature_defs && - !requested_translation) { + !import_saved_model_signature_defs_lite && !requested_translation) { llvm::errs() << "error: need to specify one translation to perform\n"; return 1; - } else if (import_saved_model_object_graph && - import_saved_model_signature_defs && requested_translation) { + } else if (import_saved_model_object_graph + + import_saved_model_signature_defs + + import_saved_model_signature_defs_lite + + (requested_translation != nullptr) > + 1) { llvm::errs() << "error: cannot specify more than one translation to perform\n"; return 1; @@ -122,6 +132,13 @@ int main(int argc, char** argv) { input_filename, tags, exported_names, &context, upgrade_legacy); if (!module_or.status().ok()) return 1; + module_or.ConsumeValueOrDie()->print(output->os()); + } else if (import_saved_model_signature_defs_lite) { + mlir::MLIRContext context; + auto module_or = tensorflow::SavedModelSignatureDefsToMlirImportLite( + input_filename, tags, exported_names, &context, upgrade_legacy); + if (!module_or.status().ok()) return 1; + module_or.ConsumeValueOrDie()->print(output->os()); } else { auto input = mlir::openInputFile(input_filename, &error_message); diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc index fa16e5872ee..7feda116984 100644 --- a/tensorflow/python/mlir_wrapper.cc +++ b/tensorflow/python/mlir_wrapper.cc @@ -53,6 +53,19 @@ PYBIND11_MODULE(_pywrap_mlir, m) { return output; }); + m.def("ExperimentalConvertSavedModelV1ToMlirLite", + [](const std::string &saved_model_path, const std::string &tags, + bool upgrade_legacy, bool show_debug_info) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + std::string output = + tensorflow::ExperimentalConvertSavedModelV1ToMlirLite( + saved_model_path, tags, upgrade_legacy, show_debug_info, + status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }); + m.def("ExperimentalConvertSavedModelV1ToMlir", [](const std::string &saved_model_path, const std::string &tags, bool lift_variables, bool upgrade_legacy, bool show_debug_info) { diff --git a/tensorflow/python/pywrap_mlir.py b/tensorflow/python/pywrap_mlir.py index 82048140e16..6db68f0e581 100644 --- a/tensorflow/python/pywrap_mlir.py +++ b/tensorflow/python/pywrap_mlir.py @@ -25,8 +25,7 @@ from tensorflow.python._pywrap_mlir import * def import_graphdef(graphdef, pass_pipeline): return ImportGraphDef( - str(graphdef).encode('utf-8'), - pass_pipeline.encode('utf-8')) + str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8')) def import_function(concrete_function, pass_pipeline): @@ -43,6 +42,14 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names, str(exported_names).encode('utf-8'), show_debug_info) +def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags, + upgrade_legacy, + show_debug_info): + return ExperimentalConvertSavedModelV1ToMlirLite( + str(saved_model_path).encode('utf-8'), + str(tags).encode('utf-8'), upgrade_legacy, show_debug_info) + + def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags, lift_variables, upgrade_legacy, show_debug_info): @@ -54,5 +61,4 @@ def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags, def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info): return ExperimentalRunPassPipeline( - mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'), - show_debug_info) + mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'), show_debug_info) diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index af5b1a104f4..ebe1427ba71 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -222,6 +222,7 @@ tensorflow::EagerContext::WaitForAndCloseRemoteContexts [mlir] # mlir tensorflow::ExperimentalRunPassPipeline +tensorflow::ExperimentalConvertSavedModelV1ToMlirLite tensorflow::ExperimentalConvertSavedModelV1ToMlir tensorflow::ExperimentalConvertSavedModelToMlir tensorflow::ImportGraphDef