Export the only one function of a saved model only when it matches with exported_names argument.

PiperOrigin-RevId: 310375397
Change-Id: I93fb94f1c2e2d77e39dc4269206438f48cad0e46
This commit is contained in:
Jaesung Chung 2020-05-07 09:26:23 -07:00 committed by TensorFlower Gardener
parent 96f4a930db
commit 1d4b4a6706
11 changed files with 59 additions and 37 deletions

View File

@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
saved_model_exported_names.begin(), saved_model_exported_names.end());
absl::Span<std::string> exported_names(exported_names_in_vector);
if (exported_names.size() != 1) {
return errors::Unimplemented("Only support a single exported name.");
}
TF_ASSIGN_OR_RETURN(auto module,
ImportSavedModel(model_flags.saved_model_dir(),
model_flags.saved_model_version(), tags,

View File

@ -160,6 +160,11 @@ int main(int argc, char **argv) {
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector);
if (exported_names.size() != 1) {
llvm::errs() << "There should be only one exported name";
return kTrFailure;
}
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
tags, exported_names, &context);
} else {

View File

@ -174,7 +174,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
return module;
} else if (saved_model_version == 1) {
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, context);
input_filename, tags, exported_names, context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");

View File

@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context;
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
@ -57,6 +58,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
@ -65,6 +67,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
@ -2428,8 +2431,8 @@ class SavedModelObjectGraphImporter : public ImporterBase {
// Main entry point: converts all functions in the given meta graph to an MLIR
// Module.
static StatusOr<mlir::OwningModuleRef> Convert(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes);
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, bool add_default_attributes);
private:
explicit SavedModelObjectGraphImporter(
@ -3129,8 +3132,8 @@ Status CreateSavedModelIR(
}
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes) {
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, bool add_default_attributes) {
GraphDebugInfo dummy_debug_info;
const GraphDebugInfo& debug_info =
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
@ -3207,17 +3210,20 @@ class SavedModelSignatureDefImporter {
public:
// Main entry point: converts all functions (specified by SignatureDefs) in
// the given meta graph to an MLIR Module.
static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
mlir::MLIRContext* context) {
SavedModelSignatureDefImporter importer(bundle, context);
static StatusOr<mlir::OwningModuleRef> Convert(
const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
mlir::MLIRContext* context) {
SavedModelSignatureDefImporter importer(bundle, exported_names, context);
return importer.ConvertSignatures();
}
private:
SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
absl::Span<std::string> exported_names,
mlir::MLIRContext* context)
: bundle_(bundle),
exported_names_(exported_names),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
@ -3250,6 +3256,7 @@ class SavedModelSignatureDefImporter {
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
const SavedModelBundle& bundle_;
absl::Span<std::string> exported_names_;
mlir::OwningModuleRef module_;
};
@ -3265,6 +3272,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
GraphDebugInfo debug_info;
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
llvm::StringSet<> exported_name_set;
exported_name_set.insert(exported_names_.begin(), exported_names_.end());
for (const auto& key_and_signature_def : signatures) {
const std::string& sig_def_key = key_and_signature_def.first;
const SignatureDef& signature_def = key_and_signature_def.second;
@ -3274,6 +3284,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
if (sig_def_key == "__saved_model_init_op") {
continue;
}
if (!exported_name_set.empty() &&
exported_name_set.count(sig_def_key) == 0) {
continue;
}
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
debug_info, flib_def));
@ -3556,12 +3570,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes) {
return SavedModelObjectGraphImporter::Convert(
saved_model, context, exported_names, add_default_attributes);
saved_model, exported_names, context, add_default_attributes);
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
return SavedModelSignatureDefImporter::Convert(saved_model, context);
const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context) {
return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
context);
}
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {

View File

@ -55,6 +55,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
// expressed with tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
absl::Span<std::string> exported_names,
mlir::MLIRContext* context);
// Serialize a MLIR module to a string.

View File

@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
tensorflow::SavedModelBundle bundle;
tensorflow::SessionOptions session_options;
// Force saved model states to be restored to CPU.
@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
return nullptr;
}
auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context);
if (!module_or.status().ok()) {
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
return nullptr;

View File

@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
// given MLIR `context`.
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
} // namespace tensorflow

View File

@ -104,26 +104,24 @@ int main(int argc, char** argv) {
return 1;
}
std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names_vector =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector);
if (import_saved_model_object_graph) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
mlir::MLIRContext context;
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names),
&context);
input_filename, tags, exported_names, &context);
if (!module) return 1;
module->print(output->os());
} else if (import_saved_model_signature_defs) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
mlir::MLIRContext context;
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, &context);
input_filename, tags, exported_names, &context);
if (!module) return 1;
module->print(output->os());

View File

@ -401,7 +401,8 @@ class TFLiteConverterBase(object):
if not self._contains_function_with_implements_attr(saved_model_proto):
self.saved_model_dir = None
else:
self._saved_model_exported_names = []
if not self._saved_model_exported_names:
self._saved_model_exported_names = []
self._saved_model_version = saved_model_proto.saved_model_schema_version
if self._saved_model_version not in [1, 2]:
raise ValueError(
@ -761,6 +762,9 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
if not signature_keys:
signature_keys = saved_model.signatures
if len(signature_keys) != 1:
raise ValueError("Only support a single signature key.")
funcs = []
for key in signature_keys:
if key not in saved_model.signatures:

View File

@ -469,15 +469,10 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(root, save_dir, {'add': add_func, 'sub': sub_func})
# Ensure the converter generates.
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertLen(converter._funcs, 2)
# Try converting multiple functions.
with self.assertRaises(ValueError) as error:
_ = converter.convert()
self.assertIn('This converter can only convert a single ConcreteFunction',
str(error.exception))
_ = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertIn('Only support a single signature key.', str(error.exception))
@test_util.run_v2_only
def testNoConcreteFunctionModel(self):
@ -487,12 +482,9 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(root, save_dir)
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertLen(converter._funcs, 0)
with self.assertRaises(ValueError) as error:
_ = converter.convert()
self.assertIn('No ConcreteFunction is specified.', str(error.exception))
_ = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertIn('Only support a single signature key.', str(error.exception))
@test_util.run_v2_only
def testKerasSequentialModel(self):