Add the entry point for SavedModelSignatureDefImporterLite in tf-mlir-translate, along with the relevant Python wrappers.

PiperOrigin-RevId: 340945906
Change-Id: I54697b98c18065f829f7f85383512b4c1a460a22
Author: Kuangyuan Chen (2020-11-05 16:02:44 -08:00), committed by TensorFlower Gardener
Parent: 7540f9ff5a
Commit: a407b1f41f
14 changed files with 156 additions and 15 deletions

@@ -40,6 +40,7 @@ cc_library(
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
+        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
     ],

@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
 #include "tensorflow/core/framework/function.h"
@@ -148,6 +149,25 @@ std::string ExperimentalConvertSavedModelToMlir(
   return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
 }
 
+std::string ExperimentalConvertSavedModelV1ToMlirLite(
+    const std::string &saved_model_path, const std::string &tags,
+    bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
+  std::unordered_set<string> tag_set =
+      absl::StrSplit(tags, ',', absl::SkipEmpty());
+
+  mlir::MLIRContext context;
+  auto module_or = SavedModelSignatureDefsToMlirImportLite(
+      saved_model_path, tag_set, /*exported_names=*/{}, &context,
+      upgrade_legacy);
+  if (!module_or.status().ok()) {
+    Set_TF_Status_from_Status(status, module_or.status());
+    return "// error";
+  }
+
+  return MlirModuleToString(*module_or.ValueOrDie(), show_debug_info);
+}
+
 std::string ExperimentalConvertSavedModelV1ToMlir(
     const std::string &saved_model_path, const std::string &tags,
     bool lift_variables, bool upgrade_legacy, bool show_debug_info,

@@ -55,6 +55,21 @@ std::string ExperimentalConvertSavedModelToMlir(
     const std::string &saved_model_path, const std::string &exported_names_str,
     bool show_debug_info, TF_Status *status);
 
+// Load a SavedModel V1 and return a textual MLIR string corresponding to it,
+// without any MLIR graph transformation.
+//
+// Args:
+//   saved_model_path: File path from which to load the SavedModel.
+//   tags: Tags to identify the MetaGraphDef that needs to be loaded.
+//   upgrade_legacy: Boolean flag that indicates whether to upgrade legacy
+//     graphs.
+//
+// Returns:
+//   A string of textual MLIR representing the raw imported SavedModel.
+std::string ExperimentalConvertSavedModelV1ToMlirLite(
+    const std::string &saved_model_path, const std::string &tags,
+    bool upgrade_legacy, bool show_debug_info, TF_Status *status);
+
 // Load a SavedModel V1 and return a textual MLIR string corresponding to it.
 //
 // Args:

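(Note: on failure this entry point returns the literal string "// error" and records the details in the passed TF_Status; the pybind11 binding added later in this commit converts that status into a Python exception via MaybeRaiseRegisteredFromTFStatus, so Python callers see an exception rather than the sentinel string.)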
@@ -1509,6 +1509,7 @@ cc_library(
         ":mangling_util",
         ":mlir_roundtrip_flags",
         "//tensorflow/cc/saved_model:bundle_v2",
+        "//tensorflow/cc/saved_model:reader",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_proto_parsing",
@@ -1523,7 +1524,6 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
     ],
 )

@@ -46,7 +46,10 @@ def set_tf_options():
 # This function needs to take a "create_module_fn", as opposed to just the
 # module itself, because the creation of the module has to be delayed until
 # after absl and tensorflow have run various initialization steps.
-def do_test(create_signature, canonicalize=False, show_debug_info=False):
+def do_test(create_signature,
+            canonicalize=False,
+            show_debug_info=False,
+            use_lite=False):
   """Runs test.
 
   1. Performs absl and tf "main"-like initialization that must run before almost
@@ -65,6 +68,8 @@ def do_test(create_signature, canonicalize=False, show_debug_info=False):
       MLIR.
     canonicalize: If true, canonicalizer will be run on the resulting MLIR.
     show_debug_info: If true, shows debug locations in the resulting MLIR.
+    use_lite: If true, the importer will not run any graph transformations,
+      such as lifting variables.
   """
 
   # Make LOG(ERROR) in C++ code show up on the console.
@@ -99,9 +104,20 @@ def do_test(create_signature, canonicalize=False, show_debug_info=False):
    # variables logic from SavedModel importer is removed.
    lift_variables = False
    upgrade_legacy = True
-    mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
-        save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
-        lift_variables, upgrade_legacy, show_debug_info)
+    if use_lite:
+      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
+          save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
+          upgrade_legacy, show_debug_info)
+      # We don't strictly need this, but it serves as a handy sanity check
+      # for that API, which is otherwise a bit annoying to test.
+      # The canonicalization shouldn't affect these tests in any way.
+      mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
+                                                        'tf-standard-pipeline',
+                                                        show_debug_info)
+    else:
+      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
+          save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
+          lift_variables, upgrade_legacy, show_debug_info)
 
    if canonicalize:
      mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',

@@ -57,8 +57,8 @@ def test():
   # Incur another bound_input on the asset, but with a different sym_name, i.e.,
   # __tf_saved_model_asset1_tokens.txt vs. __tf_saved_model_asset0_tokens.txt.
   table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10)
-  vocab_file_tensor = tf.convert_to_tensor(vocabulary_file, tf.string,
-                                           name='asset_filepath')
+  vocab_file_tensor = tf.convert_to_tensor(
+      vocabulary_file, tf.string, name='asset_filepath')
   tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file_tensor)
 
   x = tf.placeholder(tf.string, shape=(), name='input')
@@ -77,4 +77,4 @@ def test():
 
 if __name__ == '__main__':
   common_v1.set_tf_options()
-  common_v1.do_test(test)
+  common_v1.do_test(test, use_lite=True)

@@ -3712,6 +3712,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
                                              context, upgrade_legacy);
 }
 
+StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
+    const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    bool upgrade_legacy) {
+  return SavedModelSignatureDefImporterLite::Convert(
+      meta_graph_def, debug_info, exported_names, context, upgrade_legacy);
+}
+
 std::string MlirModuleToString(mlir::ModuleOp module,
                                mlir::OpPrintingFlags flags) {
   std::string txt_module;

@@ -68,6 +68,16 @@ ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
                           mlir::MLIRContext* context,
                           bool upgrade_legacy = false);
 
+// Given a V1 SavedModel, returns an MLIR module containing the functions,
+// expressed with tf_executor dialect. It does not require a session to be
+// created and it does not perform any graph transformation.
+stream_executor::port::StatusOr<mlir::OwningModuleRef>
+ConvertSavedModelV1ToMlirLite(const MetaGraphDef& meta_graph_def,
+                              const GraphDebugInfo& debug_info,
+                              absl::Span<std::string> exported_names,
+                              mlir::MLIRContext* context,
+                              bool upgrade_legacy = false);
+
 // Serialize a MLIR module to a string.
 std::string MlirModuleToString(mlir::ModuleOp module,
                                mlir::OpPrintingFlags flags);

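(Note: ConvertSavedModelV1ToMlirLite is the MetaGraphDef-level entry point; the directory-level wrapper SavedModelSignatureDefsToMlirImportLite, added in the next file, reads the MetaGraphDef from disk and then delegates to it.)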
@@ -26,6 +26,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "tensorflow/cc/saved_model/bundle_v2.h"
+#include "tensorflow/cc/saved_model/reader.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
@@ -191,6 +192,29 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
   return module_or;
 }
 
+StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
+    absl::string_view saved_model_dir,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    bool upgrade_legacy) {
+  MetaGraphDef meta_graph_def;
+  auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
+                                               tags, &meta_graph_def);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
+               << "': " << status;
+    return status;
+  }
+
+  auto module_or =
+      ConvertSavedModelV1ToMlirLite(meta_graph_def, /*debug_info=*/{},
+                                    exported_names, context, upgrade_legacy);
+  if (!module_or.status().ok()) {
+    LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
+  }
+  return module_or;
+}
+
 StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
     llvm::StringRef input, absl::string_view debug_info_file,
     const std::vector<std::string>& input_arrays,

@@ -106,6 +106,16 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
     bool upgrade_legacy = false);
 
+// Converts a TensorFlow V1 SavedModel stored in the directory with the given
+// `saved_model_dir` into an MLIR module. Creates MLIR entities into the given
+// MLIR `context`. This does not create a session internally, so it is faster,
+// and it does not perform any graph transformation.
+StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
+    absl::string_view saved_model_dir,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    bool upgrade_legacy = false);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_

@@ -63,6 +63,13 @@ static llvm::cl::opt<bool> import_saved_model_signature_defs(
         "Import a saved model's SignatureDefs to their MLIR representation"),
     llvm::cl::value_desc("dir"));
 
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> import_saved_model_signature_defs_lite(
+    "savedmodel-signaturedefs-to-mlir-lite",
+    llvm::cl::desc("Import a saved model's SignatureDefs to their MLIR "
+                   "representation without any graph transformation"),
+    llvm::cl::value_desc("dir"));
+
 // NOLINTNEXTLINE
 static llvm::cl::opt<std::string> saved_model_tags(
     "tf-savedmodel-tags",
@@ -87,11 +94,14 @@ int main(int argc, char** argv) {
   llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n");
 
   if (!import_saved_model_object_graph && !import_saved_model_signature_defs &&
-      !requested_translation) {
+      !import_saved_model_signature_defs_lite && !requested_translation) {
     llvm::errs() << "error: need to specify one translation to perform\n";
     return 1;
-  } else if (import_saved_model_object_graph &&
-             import_saved_model_signature_defs && requested_translation) {
+  } else if (import_saved_model_object_graph +
+                 import_saved_model_signature_defs +
+                 import_saved_model_signature_defs_lite +
+                 (requested_translation != nullptr) >
+             1) {
     llvm::errs()
         << "error: cannot specify more than one translation to perform\n";
     return 1;
@@ -122,6 +132,13 @@ int main(int argc, char** argv) {
         input_filename, tags, exported_names, &context, upgrade_legacy);
     if (!module_or.status().ok()) return 1;
     module_or.ConsumeValueOrDie()->print(output->os());
+  } else if (import_saved_model_signature_defs_lite) {
+    mlir::MLIRContext context;
+
+    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImportLite(
+        input_filename, tags, exported_names, &context, upgrade_legacy);
+    if (!module_or.status().ok()) return 1;
+    module_or.ConsumeValueOrDie()->print(output->os());
   } else {
     auto input = mlir::openInputFile(input_filename, &error_message);

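(Note: with the new flag in place, the lite import path can be exercised directly from the command line; a hypothetical invocation, with an illustrative path, would look something like "tf-mlir-translate --savedmodel-signaturedefs-to-mlir-lite --tf-savedmodel-tags=serve /path/to/saved_model".)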
@@ -53,6 +53,19 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
           return output;
         });
 
+  m.def("ExperimentalConvertSavedModelV1ToMlirLite",
+        [](const std::string &saved_model_path, const std::string &tags,
+           bool upgrade_legacy, bool show_debug_info) {
+          tensorflow::Safe_TF_StatusPtr status =
+              tensorflow::make_safe(TF_NewStatus());
+          std::string output =
+              tensorflow::ExperimentalConvertSavedModelV1ToMlirLite(
+                  saved_model_path, tags, upgrade_legacy, show_debug_info,
+                  status.get());
+          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+          return output;
+        });
+
   m.def("ExperimentalConvertSavedModelV1ToMlir",
         [](const std::string &saved_model_path, const std::string &tags,
            bool lift_variables, bool upgrade_legacy, bool show_debug_info) {

@@ -25,8 +25,7 @@ from tensorflow.python._pywrap_mlir import *
 
 def import_graphdef(graphdef, pass_pipeline):
   return ImportGraphDef(
-      str(graphdef).encode('utf-8'),
-      pass_pipeline.encode('utf-8'))
+      str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8'))
 
 
 def import_function(concrete_function, pass_pipeline):
@@ -43,6 +42,14 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
       str(exported_names).encode('utf-8'), show_debug_info)
 
 
+def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags,
+                                                     upgrade_legacy,
+                                                     show_debug_info):
+  return ExperimentalConvertSavedModelV1ToMlirLite(
+      str(saved_model_path).encode('utf-8'),
+      str(tags).encode('utf-8'), upgrade_legacy, show_debug_info)
+
+
 def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
                                                 lift_variables, upgrade_legacy,
                                                 show_debug_info):
@@ -54,5 +61,4 @@ def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
 
 def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
   return ExperimentalRunPassPipeline(
-      mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'),
-      show_debug_info)
+      mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'), show_debug_info)

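(Note: a minimal usage sketch of the new wrapper, assuming a V1 SavedModel at a hypothetical path; the tags argument is a comma-separated string, which the C++ side splits with absl::StrSplit:)

    from tensorflow.python import pywrap_mlir

    # Import without lifting variables or any other graph transformation.
    mlir_text = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
        '/tmp/my_saved_model',  # hypothetical SavedModel directory
        'serve',                # comma-separated MetaGraphDef tags
        True,                   # upgrade_legacy
        False)                  # show_debug_info

    # The result is textual MLIR; it can be fed back through a pass pipeline,
    # mirroring what common_v1.py does above.
    mlir_text = pywrap_mlir.experimental_run_pass_pipeline(
        mlir_text, 'canonicalize', False)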
@@ -222,6 +222,7 @@ tensorflow::EagerContext::WaitForAndCloseRemoteContexts
 [mlir] # mlir
 tensorflow::ExperimentalRunPassPipeline
+tensorflow::ExperimentalConvertSavedModelV1ToMlirLite
 tensorflow::ExperimentalConvertSavedModelV1ToMlir
 tensorflow::ExperimentalConvertSavedModelToMlir
 tensorflow::ImportGraphDef