From 1d4b4a6706b263377ccab18d94161c9ef6ca0133 Mon Sep 17 00:00:00 2001
From: Jaesung Chung
Date: Thu, 7 May 2020 09:26:23 -0700
Subject: [PATCH] Export only the function of a saved model that matches the
 exported_names argument.

PiperOrigin-RevId: 310375397
Change-Id: I93fb94f1c2e2d77e39dc4269206438f48cad0e46
---
 .../python/saved_model_to_tfl_flatbuffer.cc   |  4 +++
 .../compiler/mlir/lite/tf_tfl_translate.cc    |  5 +++
 .../mlir/lite/tf_to_tfl_flatbuffer.cc         |  2 +-
 tensorflow/compiler/mlir/python/mlir.cc       |  2 +-
 .../mlir/tensorflow/translate/import_model.cc | 36 +++++++++++++------
 .../mlir/tensorflow/translate/import_model.h  |  1 +
 .../tensorflow/translate/tf_mlir_translate.cc |  5 +--
 .../tensorflow/translate/tf_mlir_translate.h  |  3 +-
 .../compiler/mlir/tf_mlir_translate_main.cc   | 16 ++++-----
 tensorflow/lite/python/lite.py                |  6 +++-
 tensorflow/lite/python/lite_v2_test.py        | 16 +++------
 11 files changed, 59 insertions(+), 37 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index c338b723a4a..51fcbb97360 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
       saved_model_exported_names.begin(), saved_model_exported_names.end());
   absl::Span<std::string> exported_names(exported_names_in_vector);
 
+  if (exported_names.size() != 1) {
+    return errors::Unimplemented("Only support a single exported name.");
+  }
+
   TF_ASSIGN_OR_RETURN(auto module,
                       ImportSavedModel(model_flags.saved_model_dir(),
                                        model_flags.saved_model_version(), tags,
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index 4bc9d9e0c2d..fce1333a491 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -160,6 +160,11 @@ int main(int argc, char **argv) {
         absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
     absl::Span<std::string> exported_names(exported_names_vector);
 
+    if (exported_names.size() != 1) {
+      llvm::errs() << "There should be only one exported name";
+      return kTrFailure;
+    }
+
     module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
                                           tags, exported_names, &context);
   } else {
diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
index b9ec67736d9..62f64ab63b4 100644
--- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -174,7 +174,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
     return module;
   } else if (saved_model_version == 1) {
     auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
-        input_filename, tags, context);
+        input_filename, tags, exported_names, context);
 
     if (!module)
       return tensorflow::errors::InvalidArgument("fail to open input file");
diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc
index d0f6e015922..f22fb519a64 100644
--- a/tensorflow/compiler/mlir/python/mlir.cc
+++ b/tensorflow/compiler/mlir/python/mlir.cc
@@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
   // Convert the SavedModelBundle to an MLIR module.
 
   mlir::MLIRContext context;
-  auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
+  auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context);
   if (!module_or.status().ok()) {
     Set_TF_Status_from_Status(status, module_or.status());
     return "// error";
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 49be3da912a..3bb1446213b 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -40,6 +40,7 @@ limitations under the License.
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
@@ -57,6 +58,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Verifier.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
@@ -65,6 +67,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
@@ -2428,8 +2431,8 @@ class SavedModelObjectGraphImporter : public ImporterBase {
   // Main entry point: converts all functions in the given meta graph to an MLIR
   // Module.
   static StatusOr<mlir::OwningModuleRef> Convert(
-      SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
-      absl::Span<std::string> exported_names, bool add_default_attributes);
+      SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
+      mlir::MLIRContext* context, bool add_default_attributes);
 
  private:
   explicit SavedModelObjectGraphImporter(
@@ -3129,8 +3132,8 @@ Status CreateSavedModelIR(
 }
 
 StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
-    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
-    absl::Span<std::string> exported_names, bool add_default_attributes) {
+    SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
+    mlir::MLIRContext* context, bool add_default_attributes) {
   GraphDebugInfo dummy_debug_info;
   const GraphDebugInfo& debug_info =
       saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
@@ -3207,17 +3210,20 @@ class SavedModelSignatureDefImporter {
  public:
   // Main entry point: converts all functions (specified by SignatureDefs) in
   // the given meta graph to an MLIR Module.
-  static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
-                                                 mlir::MLIRContext* context) {
-    SavedModelSignatureDefImporter importer(bundle, context);
+  static StatusOr<mlir::OwningModuleRef> Convert(
+      const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
+      mlir::MLIRContext* context) {
+    SavedModelSignatureDefImporter importer(bundle, exported_names, context);
 
     return importer.ConvertSignatures();
   }
 
  private:
   SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
+                                 absl::Span<std::string> exported_names,
                                  mlir::MLIRContext* context)
       : bundle_(bundle),
+        exported_names_(exported_names),
         module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
 
   // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
@@ -3250,6 +3256,7 @@ class SavedModelSignatureDefImporter {
       const std::vector<std::pair<std::string, TensorInfo>>& inputs);
 
   const SavedModelBundle& bundle_;
+  absl::Span<std::string> exported_names_;
   mlir::OwningModuleRef module_;
 };
 
@@ -3265,6 +3272,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
   GraphDebugInfo debug_info;
   if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
 
+  llvm::StringSet<> exported_name_set;
+  exported_name_set.insert(exported_names_.begin(), exported_names_.end());
+
   for (const auto& key_and_signature_def : signatures) {
     const std::string& sig_def_key = key_and_signature_def.first;
     const SignatureDef& signature_def = key_and_signature_def.second;
@@ -3274,6 +3284,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
     if (sig_def_key == "__saved_model_init_op") {
       continue;
     }
+    if (!exported_name_set.empty() &&
+        exported_name_set.count(sig_def_key) == 0) {
+      continue;
+    }
 
     TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
                                         debug_info, flib_def));
@@ -3556,12 +3570,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
     SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
     absl::Span<std::string> exported_names, bool add_default_attributes) {
   return SavedModelObjectGraphImporter::Convert(
-      saved_model, context, exported_names, add_default_attributes);
+      saved_model, exported_names, context, add_default_attributes);
 }
 
 StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
-    const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
-  return SavedModelSignatureDefImporter::Convert(saved_model, context);
+    const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
+    mlir::MLIRContext* context) {
+  return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
+                                                 context);
 }
 
 std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
index 8603eadb487..bdb72345201 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
@@ -55,6 +55,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
 // expressed with tf_executor dialect.
 stream_executor::port::StatusOr<mlir::OwningModuleRef>
 ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
+                          absl::Span<std::string> exported_names,
                           mlir::MLIRContext* context);
 
 // Serialize a MLIR module to a string.
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
index 2c7f84d8268..6ada0fec4e2 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
@@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
 
 mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
     absl::string_view saved_model_dir,
-    const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
   tensorflow::SavedModelBundle bundle;
   tensorflow::SessionOptions session_options;
   // Force saved model states to be restored to CPU.
@@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
     return nullptr;
   }
 
-  auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
+  auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
     return nullptr;
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
index f498864c8aa..490b7c7d8f0 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h
@@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
 // given MLIR `context`.
 mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
     absl::string_view saved_model_dir,
-    const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
index 62b862f5e21..2e1528e0d60 100644
--- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
@@ -104,26 +104,24 @@ int main(int argc, char** argv) {
     return 1;
   }
 
+  std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
+  std::vector<std::string> exported_names_vector =
+      absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
+  absl::Span<std::string> exported_names(exported_names_vector);
+
   if (import_saved_model_object_graph) {
-    std::unordered_set<std::string> tags =
-        absl::StrSplit(saved_model_tags, ',');
-    std::vector<std::string> exported_names =
-        absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
     mlir::MLIRContext context;
 
     auto module = tensorflow::SavedModelObjectGraphToMlirImport(
-        input_filename, tags, absl::Span<std::string>(exported_names),
-        &context);
+        input_filename, tags, exported_names, &context);
     if (!module) return 1;
 
     module->print(output->os());
   } else if (import_saved_model_signature_defs) {
-    std::unordered_set<std::string> tags =
-        absl::StrSplit(saved_model_tags, ',');
     mlir::MLIRContext context;
 
     auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
-        input_filename, tags, &context);
+        input_filename, tags, exported_names, &context);
     if (!module) return 1;
 
     module->print(output->os());
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index b2d58ec8746..61daa699f5a 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -401,7 +401,8 @@ class TFLiteConverterBase(object):
       if not self._contains_function_with_implements_attr(saved_model_proto):
         self.saved_model_dir = None
       else:
-        self._saved_model_exported_names = []
+        if not self._saved_model_exported_names:
+          self._saved_model_exported_names = []
         self._saved_model_version = saved_model_proto.saved_model_schema_version
         if self._saved_model_version not in [1, 2]:
           raise ValueError(
@@ -761,6 +762,9 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
     if not signature_keys:
       signature_keys = saved_model.signatures
 
+    if len(signature_keys) != 1:
+      raise ValueError("Only support a single signature key.")
+
     funcs = []
     for key in signature_keys:
       if key not in saved_model.signatures:
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 59f326d4b9f..5470e332b3d 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -469,15 +469,10 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
     save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
     save(root, save_dir, {'add': add_func, 'sub': sub_func})
 
-    # Ensure the converter generates.
-    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
-    self.assertLen(converter._funcs, 2)
-
     # Try converting multiple functions.
    with self.assertRaises(ValueError) as error:
-      _ = converter.convert()
-    self.assertIn('This converter can only convert a single ConcreteFunction',
-                  str(error.exception))
+      _ = lite.TFLiteConverterV2.from_saved_model(save_dir)
+    self.assertIn('Only support a single signature key.', str(error.exception))
 
   @test_util.run_v2_only
   def testNoConcreteFunctionModel(self):
@@ -487,12 +482,9 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
     save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
     save(root, save_dir)
 
-    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
-    self.assertLen(converter._funcs, 0)
-
     with self.assertRaises(ValueError) as error:
-      _ = converter.convert()
-    self.assertIn('No ConcreteFunction is specified.', str(error.exception))
+      _ = lite.TFLiteConverterV2.from_saved_model(save_dir)
+    self.assertIn('Only support a single signature key.', str(error.exception))
 
   @test_util.run_v2_only
   def testKerasSequentialModel(self):
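
Note (not part of the patch): a minimal usage sketch of the converter behavior this change introduces, assuming TensorFlow 2.x with this commit applied. The SavedModel directory path and the 'add'/'sub' signature keys below are hypothetical placeholders mirroring the updated unit test.

    # Sketch: after this change, TFLiteConverterV2.from_saved_model expects
    # exactly one signature key; a SavedModel exposing several signatures must
    # have one selected explicitly, otherwise it raises
    # ValueError("Only support a single signature key.").
    import tensorflow as tf

    saved_model_dir = '/tmp/add_sub_model'  # hypothetical path to a SavedModel
                                            # with 'add' and 'sub' signatures

    # Select a single signature so conversion can proceed.
    converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir, signature_keys=['add'])
    tflite_model = converter.convert()

On the C++ side, callers that do not filter (such as mlir.cc, which passes `{}`) keep the previous import-everything behavior, because SavedModelSignatureDefImporter only skips signatures when the exported-name set is non-empty.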