Add the entry point for SavedModelSignatureDefImporterLite in tf-mlir-translate
and relevant python wrappers. PiperOrigin-RevId: 340945906 Change-Id: I54697b98c18065f829f7f85383512b4c1a460a22
This commit is contained in:
parent
7540f9ff5a
commit
a407b1f41f
@ -40,6 +40,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -148,6 +149,25 @@ std::string ExperimentalConvertSavedModelToMlir(
|
||||
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||
}
|
||||
|
||||
std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
|
||||
std::unordered_set<string> tag_set =
|
||||
absl::StrSplit(tags, ',', absl::SkipEmpty());
|
||||
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module_or = SavedModelSignatureDefsToMlirImportLite(
|
||||
saved_model_path, tag_set, /*exported_names=*/{}, &context,
|
||||
upgrade_legacy);
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return MlirModuleToString(*module_or.ValueOrDie(), show_debug_info);
|
||||
}
|
||||
|
||||
std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
|
||||
|
@ -55,6 +55,21 @@ std::string ExperimentalConvertSavedModelToMlir(
|
||||
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||
bool show_debug_info, TF_Status *status);
|
||||
|
||||
// Load a SavedModel V1 and return a textual MLIR string corresponding to it
|
||||
// without any MLIR graph transformation.
|
||||
//
|
||||
// Args:
|
||||
// saved_model_path: File path from which to load the SavedModel.
|
||||
// tags: Tags to identify MetaGraphDef that need to be loaded.
|
||||
// upgrade_legacy: Boolean flag that indicates whether to upgrade legacy
|
||||
// graphs
|
||||
//
|
||||
// Returns:
|
||||
// A string of textual MLIR representing the raw imported SavedModel.
|
||||
std::string ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool upgrade_legacy, bool show_debug_info, TF_Status *status);
|
||||
|
||||
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
// Args:
|
||||
|
@ -1509,6 +1509,7 @@ cc_library(
|
||||
":mangling_util",
|
||||
":mlir_roundtrip_flags",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/cc/saved_model:reader",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
@ -1523,7 +1524,6 @@ cc_library(
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -46,7 +46,10 @@ def set_tf_options():
|
||||
# This function needs to take a "create_module_fn", as opposed to just the
|
||||
# module itself, because the creation of the module has to be delayed until
|
||||
# after absl and tensorflow have run various initialization steps.
|
||||
def do_test(create_signature, canonicalize=False, show_debug_info=False):
|
||||
def do_test(create_signature,
|
||||
canonicalize=False,
|
||||
show_debug_info=False,
|
||||
use_lite=False):
|
||||
"""Runs test.
|
||||
|
||||
1. Performs absl and tf "main"-like initialization that must run before almost
|
||||
@ -65,6 +68,8 @@ def do_test(create_signature, canonicalize=False, show_debug_info=False):
|
||||
MLIR.
|
||||
canonicalize: If true, canonicalizer will be run on the resulting MLIR.
|
||||
show_debug_info: If true, shows debug locations in the resulting MLIR.
|
||||
use_lite: If true, importer will not do any graph transformation such as
|
||||
lift variables.
|
||||
"""
|
||||
|
||||
# Make LOG(ERROR) in C++ code show up on the console.
|
||||
@ -99,9 +104,20 @@ def do_test(create_signature, canonicalize=False, show_debug_info=False):
|
||||
# variables logic from SavedModel importer is removed.
|
||||
lift_variables = False
|
||||
upgrade_legacy = True
|
||||
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
|
||||
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
||||
lift_variables, upgrade_legacy, show_debug_info)
|
||||
if use_lite:
|
||||
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
|
||||
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
||||
upgrade_legacy, show_debug_info)
|
||||
# We don't strictly need this, but it serves as a handy sanity check
|
||||
# for that API, which is otherwise a bit annoying to test.
|
||||
# The canonicalization shouldn't affect these tests in any way.
|
||||
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
|
||||
'tf-standard-pipeline',
|
||||
show_debug_info)
|
||||
else:
|
||||
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
|
||||
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
||||
lift_variables, upgrade_legacy, show_debug_info)
|
||||
|
||||
if canonicalize:
|
||||
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
|
||||
|
@ -57,8 +57,8 @@ def test():
|
||||
# Incur another bound_input on the asset, but with a different sym_name, i.e.,
|
||||
# __tf_saved_model_asset1_tokens.txt vs. __tf_saved_model_asset0_tokens.txt.
|
||||
table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10)
|
||||
vocab_file_tensor = tf.convert_to_tensor(vocabulary_file, tf.string,
|
||||
name='asset_filepath')
|
||||
vocab_file_tensor = tf.convert_to_tensor(
|
||||
vocabulary_file, tf.string, name='asset_filepath')
|
||||
tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file_tensor)
|
||||
|
||||
x = tf.placeholder(tf.string, shape=(), name='input')
|
||||
@ -77,4 +77,4 @@ def test():
|
||||
|
||||
if __name__ == '__main__':
|
||||
common_v1.set_tf_options()
|
||||
common_v1.do_test(test)
|
||||
common_v1.do_test(test, use_lite=True)
|
||||
|
@ -3712,6 +3712,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
|
||||
context, upgrade_legacy);
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
|
||||
const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
bool upgrade_legacy) {
|
||||
return SavedModelSignatureDefImporterLite::Convert(
|
||||
meta_graph_def, debug_info, exported_names, context, upgrade_legacy);
|
||||
}
|
||||
|
||||
std::string MlirModuleToString(mlir::ModuleOp module,
|
||||
mlir::OpPrintingFlags flags) {
|
||||
std::string txt_module;
|
||||
|
@ -68,6 +68,16 @@ ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
|
||||
mlir::MLIRContext* context,
|
||||
bool upgrade_legacy = false);
|
||||
|
||||
// Given a V1 SavedModel, returns a MLIR module containing the functions,
|
||||
// expressed with tf_executor dialect. It does not require a session to be
|
||||
// created and it does not perform any graph transformation.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef>
|
||||
ConvertSavedModelV1ToMlirLite(const MetaGraphDef& meta_graph_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context,
|
||||
bool upgrade_legacy = false);
|
||||
|
||||
// Serialize a MLIR module to a string.
|
||||
std::string MlirModuleToString(mlir::ModuleOp module,
|
||||
mlir::OpPrintingFlags flags);
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
@ -191,6 +192,29 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
|
||||
return module_or;
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
bool upgrade_legacy) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
|
||||
tags, &meta_graph_def);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
|
||||
<< "': " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
auto module_or =
|
||||
ConvertSavedModelV1ToMlirLite(meta_graph_def, /*debug_info=*/{},
|
||||
exported_names, context, upgrade_legacy);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
|
||||
}
|
||||
return module_or;
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
const std::vector<std::string>& input_arrays,
|
||||
|
@ -106,6 +106,16 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
bool upgrade_legacy = false);
|
||||
|
||||
// Converts a TensorFlow V1 SavedModel stored in the directory with the given
|
||||
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||
// given MLIR `context`. This does not create session internally so it is faster
|
||||
// and does not perform any graph transformation.
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
bool upgrade_legacy = false);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_
|
||||
|
@ -63,6 +63,13 @@ static llvm::cl::opt<bool> import_saved_model_signature_defs(
|
||||
"Import a saved model's SignatureDefs to their MLIR representation"),
|
||||
llvm::cl::value_desc("dir"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> import_saved_model_signature_defs_lite(
|
||||
"savedmodel-signaturedefs-to-mlir-lite",
|
||||
llvm::cl::desc("Import a saved model's SignatureDefs to to their MLIR "
|
||||
"representation without any graph transformation"),
|
||||
llvm::cl::value_desc("dir"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<std::string> saved_model_tags(
|
||||
"tf-savedmodel-tags",
|
||||
@ -87,11 +94,14 @@ int main(int argc, char** argv) {
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n");
|
||||
|
||||
if (!import_saved_model_object_graph && !import_saved_model_signature_defs &&
|
||||
!requested_translation) {
|
||||
!import_saved_model_signature_defs_lite && !requested_translation) {
|
||||
llvm::errs() << "error: need to specify one translation to perform\n";
|
||||
return 1;
|
||||
} else if (import_saved_model_object_graph &&
|
||||
import_saved_model_signature_defs && requested_translation) {
|
||||
} else if (import_saved_model_object_graph +
|
||||
import_saved_model_signature_defs +
|
||||
import_saved_model_signature_defs_lite +
|
||||
(requested_translation != nullptr) >
|
||||
1) {
|
||||
llvm::errs()
|
||||
<< "error: cannot specify more than one translation to perform\n";
|
||||
return 1;
|
||||
@ -122,6 +132,13 @@ int main(int argc, char** argv) {
|
||||
input_filename, tags, exported_names, &context, upgrade_legacy);
|
||||
if (!module_or.status().ok()) return 1;
|
||||
|
||||
module_or.ConsumeValueOrDie()->print(output->os());
|
||||
} else if (import_saved_model_signature_defs_lite) {
|
||||
mlir::MLIRContext context;
|
||||
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImportLite(
|
||||
input_filename, tags, exported_names, &context, upgrade_legacy);
|
||||
if (!module_or.status().ok()) return 1;
|
||||
|
||||
module_or.ConsumeValueOrDie()->print(output->os());
|
||||
} else {
|
||||
auto input = mlir::openInputFile(input_filename, &error_message);
|
||||
|
@ -53,6 +53,19 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ExperimentalConvertSavedModelV1ToMlirLite",
|
||||
[](const std::string &saved_model_path, const std::string &tags,
|
||||
bool upgrade_legacy, bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output =
|
||||
tensorflow::ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
saved_model_path, tags, upgrade_legacy, show_debug_info,
|
||||
status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ExperimentalConvertSavedModelV1ToMlir",
|
||||
[](const std::string &saved_model_path, const std::string &tags,
|
||||
bool lift_variables, bool upgrade_legacy, bool show_debug_info) {
|
||||
|
@ -25,8 +25,7 @@ from tensorflow.python._pywrap_mlir import *
|
||||
|
||||
def import_graphdef(graphdef, pass_pipeline):
|
||||
return ImportGraphDef(
|
||||
str(graphdef).encode('utf-8'),
|
||||
pass_pipeline.encode('utf-8'))
|
||||
str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8'))
|
||||
|
||||
|
||||
def import_function(concrete_function, pass_pipeline):
|
||||
@ -43,6 +42,14 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
||||
str(exported_names).encode('utf-8'), show_debug_info)
|
||||
|
||||
|
||||
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags,
|
||||
upgrade_legacy,
|
||||
show_debug_info):
|
||||
return ExperimentalConvertSavedModelV1ToMlirLite(
|
||||
str(saved_model_path).encode('utf-8'),
|
||||
str(tags).encode('utf-8'), upgrade_legacy, show_debug_info)
|
||||
|
||||
|
||||
def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
|
||||
lift_variables, upgrade_legacy,
|
||||
show_debug_info):
|
||||
@ -54,5 +61,4 @@ def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
|
||||
|
||||
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
|
||||
return ExperimentalRunPassPipeline(
|
||||
mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'),
|
||||
show_debug_info)
|
||||
mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'), show_debug_info)
|
||||
|
@ -222,6 +222,7 @@ tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||
|
||||
[mlir] # mlir
|
||||
tensorflow::ExperimentalRunPassPipeline
|
||||
tensorflow::ExperimentalConvertSavedModelV1ToMlirLite
|
||||
tensorflow::ExperimentalConvertSavedModelV1ToMlir
|
||||
tensorflow::ExperimentalConvertSavedModelToMlir
|
||||
tensorflow::ImportGraphDef
|
||||
|
Loading…
Reference in New Issue
Block a user