Export the single function of a saved model only when it matches the exported_names argument.

PiperOrigin-RevId: 310375397
Change-Id: I93fb94f1c2e2d77e39dc4269206438f48cad0e46
This commit is contained in:
Jaesung Chung 2020-05-07 09:26:23 -07:00 committed by TensorFlower Gardener
parent 96f4a930db
commit 1d4b4a6706
11 changed files with 59 additions and 37 deletions

View File

@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
saved_model_exported_names.begin(), saved_model_exported_names.end()); saved_model_exported_names.begin(), saved_model_exported_names.end());
absl::Span<std::string> exported_names(exported_names_in_vector); absl::Span<std::string> exported_names(exported_names_in_vector);
if (exported_names.size() != 1) {
return errors::Unimplemented("Only support a single exported name.");
}
TF_ASSIGN_OR_RETURN(auto module, TF_ASSIGN_OR_RETURN(auto module,
ImportSavedModel(model_flags.saved_model_dir(), ImportSavedModel(model_flags.saved_model_dir(),
model_flags.saved_model_version(), tags, model_flags.saved_model_version(), tags,

View File

@ -160,6 +160,11 @@ int main(int argc, char **argv) {
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector); absl::Span<std::string> exported_names(exported_names_vector);
if (exported_names.size() != 1) {
llvm::errs() << "There should be only one exported name";
return kTrFailure;
}
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version, module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
tags, exported_names, &context); tags, exported_names, &context);
} else { } else {

View File

@ -174,7 +174,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
return module; return module;
} else if (saved_model_version == 1) { } else if (saved_model_version == 1) {
auto module = tensorflow::SavedModelSignatureDefsToMlirImport( auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, context); input_filename, tags, exported_names, context);
if (!module) if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file"); return tensorflow::errors::InvalidArgument("fail to open input file");

View File

@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
// Convert the SavedModelBundle to an MLIR module. // Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context; mlir::MLIRContext context;
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context); auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context);
if (!module_or.status().ok()) { if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status()); Set_TF_Status_from_Status(status, module_or.status());
return "// error"; return "// error";

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h" #include "llvm/ADT/Twine.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
@ -57,6 +58,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
@ -65,6 +67,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
@ -2428,8 +2431,8 @@ class SavedModelObjectGraphImporter : public ImporterBase {
// Main entry point: converts all functions in the given meta graph to an MLIR // Main entry point: converts all functions in the given meta graph to an MLIR
// Module. // Module.
static StatusOr<mlir::OwningModuleRef> Convert( static StatusOr<mlir::OwningModuleRef> Convert(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
absl::Span<std::string> exported_names, bool add_default_attributes); mlir::MLIRContext* context, bool add_default_attributes);
private: private:
explicit SavedModelObjectGraphImporter( explicit SavedModelObjectGraphImporter(
@ -3129,8 +3132,8 @@ Status CreateSavedModelIR(
} }
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert( StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
absl::Span<std::string> exported_names, bool add_default_attributes) { mlir::MLIRContext* context, bool add_default_attributes) {
GraphDebugInfo dummy_debug_info; GraphDebugInfo dummy_debug_info;
const GraphDebugInfo& debug_info = const GraphDebugInfo& debug_info =
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info; saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
@ -3207,17 +3210,20 @@ class SavedModelSignatureDefImporter {
public: public:
// Main entry point: converts all functions (specified by SignatureDefs) in // Main entry point: converts all functions (specified by SignatureDefs) in
// the given meta graph to an MLIR Module. // the given meta graph to an MLIR Module.
static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle, static StatusOr<mlir::OwningModuleRef> Convert(
const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
mlir::MLIRContext* context) { mlir::MLIRContext* context) {
SavedModelSignatureDefImporter importer(bundle, context); SavedModelSignatureDefImporter importer(bundle, exported_names, context);
return importer.ConvertSignatures(); return importer.ConvertSignatures();
} }
private: private:
SavedModelSignatureDefImporter(const SavedModelBundle& bundle, SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
absl::Span<std::string> exported_names,
mlir::MLIRContext* context) mlir::MLIRContext* context)
: bundle_(bundle), : bundle_(bundle),
exported_names_(exported_names),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
@ -3250,6 +3256,7 @@ class SavedModelSignatureDefImporter {
const std::vector<std::pair<std::string, TensorInfo>>& inputs); const std::vector<std::pair<std::string, TensorInfo>>& inputs);
const SavedModelBundle& bundle_; const SavedModelBundle& bundle_;
absl::Span<std::string> exported_names_;
mlir::OwningModuleRef module_; mlir::OwningModuleRef module_;
}; };
@ -3265,6 +3272,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
GraphDebugInfo debug_info; GraphDebugInfo debug_info;
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info; if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
llvm::StringSet<> exported_name_set;
exported_name_set.insert(exported_names_.begin(), exported_names_.end());
for (const auto& key_and_signature_def : signatures) { for (const auto& key_and_signature_def : signatures) {
const std::string& sig_def_key = key_and_signature_def.first; const std::string& sig_def_key = key_and_signature_def.first;
const SignatureDef& signature_def = key_and_signature_def.second; const SignatureDef& signature_def = key_and_signature_def.second;
@ -3274,6 +3284,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
if (sig_def_key == "__saved_model_init_op") { if (sig_def_key == "__saved_model_init_op") {
continue; continue;
} }
if (!exported_name_set.empty() &&
exported_name_set.count(sig_def_key) == 0) {
continue;
}
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def, TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
debug_info, flib_def)); debug_info, flib_def));
@ -3556,12 +3570,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes) { absl::Span<std::string> exported_names, bool add_default_attributes) {
return SavedModelObjectGraphImporter::Convert( return SavedModelObjectGraphImporter::Convert(
saved_model, context, exported_names, add_default_attributes); saved_model, exported_names, context, add_default_attributes);
} }
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir( StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
const SavedModelBundle& saved_model, mlir::MLIRContext* context) { const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
return SavedModelSignatureDefImporter::Convert(saved_model, context); mlir::MLIRContext* context) {
return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
context);
} }
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {

View File

@ -55,6 +55,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
// expressed with tf_executor dialect. // expressed with tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef> stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
absl::Span<std::string> exported_names,
mlir::MLIRContext* context); mlir::MLIRContext* context);
// Serialize a MLIR module to a string. // Serialize a MLIR module to a string.

View File

@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
absl::string_view saved_model_dir, absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) { const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
tensorflow::SavedModelBundle bundle; tensorflow::SavedModelBundle bundle;
tensorflow::SessionOptions session_options; tensorflow::SessionOptions session_options;
// Force saved model states to be restored to CPU. // Force saved model states to be restored to CPU.
@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
return nullptr; return nullptr;
} }
auto module_or = ConvertSavedModelV1ToMlir(bundle, context); auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context);
if (!module_or.status().ok()) { if (!module_or.status().ok()) {
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status(); LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
return nullptr; return nullptr;

View File

@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
// given MLIR `context`. // given MLIR `context`.
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
absl::string_view saved_model_dir, absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context); const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
} // namespace tensorflow } // namespace tensorflow

View File

@ -104,26 +104,24 @@ int main(int argc, char** argv) {
return 1; return 1;
} }
if (import_saved_model_object_graph) { std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
std::unordered_set<std::string> tags = std::vector<std::string> exported_names_vector =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector);
if (import_saved_model_object_graph) {
mlir::MLIRContext context; mlir::MLIRContext context;
auto module = tensorflow::SavedModelObjectGraphToMlirImport( auto module = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names), input_filename, tags, exported_names, &context);
&context);
if (!module) return 1; if (!module) return 1;
module->print(output->os()); module->print(output->os());
} else if (import_saved_model_signature_defs) { } else if (import_saved_model_signature_defs) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
mlir::MLIRContext context; mlir::MLIRContext context;
auto module = tensorflow::SavedModelSignatureDefsToMlirImport( auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, &context); input_filename, tags, exported_names, &context);
if (!module) return 1; if (!module) return 1;
module->print(output->os()); module->print(output->os());

View File

@ -401,6 +401,7 @@ class TFLiteConverterBase(object):
if not self._contains_function_with_implements_attr(saved_model_proto): if not self._contains_function_with_implements_attr(saved_model_proto):
self.saved_model_dir = None self.saved_model_dir = None
else: else:
if not self._saved_model_exported_names:
self._saved_model_exported_names = [] self._saved_model_exported_names = []
self._saved_model_version = saved_model_proto.saved_model_schema_version self._saved_model_version = saved_model_proto.saved_model_schema_version
if self._saved_model_version not in [1, 2]: if self._saved_model_version not in [1, 2]:
@ -761,6 +762,9 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
if not signature_keys: if not signature_keys:
signature_keys = saved_model.signatures signature_keys = saved_model.signatures
if len(signature_keys) != 1:
raise ValueError("Only support a single signature key.")
funcs = [] funcs = []
for key in signature_keys: for key in signature_keys:
if key not in saved_model.signatures: if key not in saved_model.signatures:

View File

@ -469,15 +469,10 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
save_dir = os.path.join(self.get_temp_dir(), 'saved_model') save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(root, save_dir, {'add': add_func, 'sub': sub_func}) save(root, save_dir, {'add': add_func, 'sub': sub_func})
# Ensure the converter generates.
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertLen(converter._funcs, 2)
# Try converting multiple functions. # Try converting multiple functions.
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
_ = converter.convert() _ = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertIn('This converter can only convert a single ConcreteFunction', self.assertIn('Only support a single signature key.', str(error.exception))
str(error.exception))
@test_util.run_v2_only @test_util.run_v2_only
def testNoConcreteFunctionModel(self): def testNoConcreteFunctionModel(self):
@ -487,12 +482,9 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
save_dir = os.path.join(self.get_temp_dir(), 'saved_model') save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(root, save_dir) save(root, save_dir)
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertLen(converter._funcs, 0)
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
_ = converter.convert() _ = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertIn('No ConcreteFunction is specified.', str(error.exception)) self.assertIn('Only support a single signature key.', str(error.exception))
@test_util.run_v2_only @test_util.run_v2_only
def testKerasSequentialModel(self): def testKerasSequentialModel(self):