Export the only one function of a saved model only when it matches with exported_names argument.
PiperOrigin-RevId: 310375397 Change-Id: I93fb94f1c2e2d77e39dc4269206438f48cad0e46
This commit is contained in:
parent
96f4a930db
commit
1d4b4a6706
@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
saved_model_exported_names.begin(), saved_model_exported_names.end());
|
||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||
|
||||
if (exported_names.size() != 1) {
|
||||
return errors::Unimplemented("Only support a single exported name.");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
ImportSavedModel(model_flags.saved_model_dir(),
|
||||
model_flags.saved_model_version(), tags,
|
||||
|
@ -160,6 +160,11 @@ int main(int argc, char **argv) {
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_vector);
|
||||
|
||||
if (exported_names.size() != 1) {
|
||||
llvm::errs() << "There should be only one exported name";
|
||||
return kTrFailure;
|
||||
}
|
||||
|
||||
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
|
||||
tags, exported_names, &context);
|
||||
} else {
|
||||
|
@ -174,7 +174,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
return module;
|
||||
} else if (saved_model_version == 1) {
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, context);
|
||||
input_filename, tags, exported_names, context);
|
||||
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
|
@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
// Convert the SavedModelBundle to an MLIR module.
|
||||
|
||||
mlir::MLIRContext context;
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context);
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
@ -57,6 +58,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/IR/Verifier.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
@ -65,6 +67,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
@ -2428,8 +2431,8 @@ class SavedModelObjectGraphImporter : public ImporterBase {
|
||||
// Main entry point: converts all functions in the given meta graph to an MLIR
|
||||
// Module.
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes);
|
||||
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context, bool add_default_attributes);
|
||||
|
||||
private:
|
||||
explicit SavedModelObjectGraphImporter(
|
||||
@ -3129,8 +3132,8 @@ Status CreateSavedModelIR(
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
||||
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context, bool add_default_attributes) {
|
||||
GraphDebugInfo dummy_debug_info;
|
||||
const GraphDebugInfo& debug_info =
|
||||
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
|
||||
@ -3207,17 +3210,20 @@ class SavedModelSignatureDefImporter {
|
||||
public:
|
||||
// Main entry point: converts all functions (specified by SignatureDefs) in
|
||||
// the given meta graph to an MLIR Module.
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
|
||||
mlir::MLIRContext* context) {
|
||||
SavedModelSignatureDefImporter importer(bundle, context);
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||
const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context) {
|
||||
SavedModelSignatureDefImporter importer(bundle, exported_names, context);
|
||||
|
||||
return importer.ConvertSignatures();
|
||||
}
|
||||
|
||||
private:
|
||||
SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
|
||||
absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context)
|
||||
: bundle_(bundle),
|
||||
exported_names_(exported_names),
|
||||
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
|
||||
|
||||
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
|
||||
@ -3250,6 +3256,7 @@ class SavedModelSignatureDefImporter {
|
||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
|
||||
|
||||
const SavedModelBundle& bundle_;
|
||||
absl::Span<std::string> exported_names_;
|
||||
mlir::OwningModuleRef module_;
|
||||
};
|
||||
|
||||
@ -3265,6 +3272,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||
GraphDebugInfo debug_info;
|
||||
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
|
||||
|
||||
llvm::StringSet<> exported_name_set;
|
||||
exported_name_set.insert(exported_names_.begin(), exported_names_.end());
|
||||
|
||||
for (const auto& key_and_signature_def : signatures) {
|
||||
const std::string& sig_def_key = key_and_signature_def.first;
|
||||
const SignatureDef& signature_def = key_and_signature_def.second;
|
||||
@ -3274,6 +3284,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||
if (sig_def_key == "__saved_model_init_op") {
|
||||
continue;
|
||||
}
|
||||
if (!exported_name_set.empty() &&
|
||||
exported_name_set.count(sig_def_key) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
|
||||
debug_info, flib_def));
|
||||
@ -3556,12 +3570,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
||||
return SavedModelObjectGraphImporter::Convert(
|
||||
saved_model, context, exported_names, add_default_attributes);
|
||||
saved_model, exported_names, context, add_default_attributes);
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
|
||||
const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
|
||||
return SavedModelSignatureDefImporter::Convert(saved_model, context);
|
||||
const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context) {
|
||||
return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
|
||||
context);
|
||||
}
|
||||
|
||||
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
|
||||
|
@ -55,6 +55,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
// expressed with tf_executor dialect.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef>
|
||||
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
|
||||
absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Serialize a MLIR module to a string.
|
||||
|
@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
|
||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||
tensorflow::SavedModelBundle bundle;
|
||||
tensorflow::SessionOptions session_options;
|
||||
// Force saved model states to be restored to CPU.
|
||||
@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
|
@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -104,26 +104,24 @@ int main(int argc, char** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names_vector =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_vector);
|
||||
|
||||
if (import_saved_model_object_graph) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names),
|
||||
&context);
|
||||
input_filename, tags, exported_names, &context);
|
||||
if (!module) return 1;
|
||||
|
||||
module->print(output->os());
|
||||
} else if (import_saved_model_signature_defs) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, &context);
|
||||
input_filename, tags, exported_names, &context);
|
||||
if (!module) return 1;
|
||||
|
||||
module->print(output->os());
|
||||
|
@ -401,7 +401,8 @@ class TFLiteConverterBase(object):
|
||||
if not self._contains_function_with_implements_attr(saved_model_proto):
|
||||
self.saved_model_dir = None
|
||||
else:
|
||||
self._saved_model_exported_names = []
|
||||
if not self._saved_model_exported_names:
|
||||
self._saved_model_exported_names = []
|
||||
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
||||
if self._saved_model_version not in [1, 2]:
|
||||
raise ValueError(
|
||||
@ -761,6 +762,9 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
|
||||
if not signature_keys:
|
||||
signature_keys = saved_model.signatures
|
||||
|
||||
if len(signature_keys) != 1:
|
||||
raise ValueError("Only support a single signature key.")
|
||||
|
||||
funcs = []
|
||||
for key in signature_keys:
|
||||
if key not in saved_model.signatures:
|
||||
|
@ -469,15 +469,10 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
save(root, save_dir, {'add': add_func, 'sub': sub_func})
|
||||
|
||||
# Ensure the converter generates.
|
||||
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
|
||||
self.assertLen(converter._funcs, 2)
|
||||
|
||||
# Try converting multiple functions.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_ = converter.convert()
|
||||
self.assertIn('This converter can only convert a single ConcreteFunction',
|
||||
str(error.exception))
|
||||
_ = lite.TFLiteConverterV2.from_saved_model(save_dir)
|
||||
self.assertIn('Only support a single signature key.', str(error.exception))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testNoConcreteFunctionModel(self):
|
||||
@ -487,12 +482,9 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
save(root, save_dir)
|
||||
|
||||
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
|
||||
self.assertLen(converter._funcs, 0)
|
||||
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_ = converter.convert()
|
||||
self.assertIn('No ConcreteFunction is specified.', str(error.exception))
|
||||
_ = lite.TFLiteConverterV2.from_saved_model(save_dir)
|
||||
self.assertIn('Only support a single signature key.', str(error.exception))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testKerasSequentialModel(self):
|
||||
|
Loading…
Reference in New Issue
Block a user