Export the only one function of a saved model only when it matches with exported_names argument.
PiperOrigin-RevId: 310375397 Change-Id: I93fb94f1c2e2d77e39dc4269206438f48cad0e46
This commit is contained in:
parent
96f4a930db
commit
1d4b4a6706
@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
|||||||
saved_model_exported_names.begin(), saved_model_exported_names.end());
|
saved_model_exported_names.begin(), saved_model_exported_names.end());
|
||||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||||
|
|
||||||
|
if (exported_names.size() != 1) {
|
||||||
|
return errors::Unimplemented("Only support a single exported name.");
|
||||||
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(auto module,
|
TF_ASSIGN_OR_RETURN(auto module,
|
||||||
ImportSavedModel(model_flags.saved_model_dir(),
|
ImportSavedModel(model_flags.saved_model_dir(),
|
||||||
model_flags.saved_model_version(), tags,
|
model_flags.saved_model_version(), tags,
|
||||||
|
@ -160,6 +160,11 @@ int main(int argc, char **argv) {
|
|||||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||||
absl::Span<std::string> exported_names(exported_names_vector);
|
absl::Span<std::string> exported_names(exported_names_vector);
|
||||||
|
|
||||||
|
if (exported_names.size() != 1) {
|
||||||
|
llvm::errs() << "There should be only one exported name";
|
||||||
|
return kTrFailure;
|
||||||
|
}
|
||||||
|
|
||||||
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
|
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
|
||||||
tags, exported_names, &context);
|
tags, exported_names, &context);
|
||||||
} else {
|
} else {
|
||||||
|
@ -174,7 +174,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
|||||||
return module;
|
return module;
|
||||||
} else if (saved_model_version == 1) {
|
} else if (saved_model_version == 1) {
|
||||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||||
input_filename, tags, context);
|
input_filename, tags, exported_names, context);
|
||||||
|
|
||||||
if (!module)
|
if (!module)
|
||||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||||
|
@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
|
|||||||
// Convert the SavedModelBundle to an MLIR module.
|
// Convert the SavedModelBundle to an MLIR module.
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
|
auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context);
|
||||||
if (!module_or.status().ok()) {
|
if (!module_or.status().ok()) {
|
||||||
Set_TF_Status_from_Status(status, module_or.status());
|
Set_TF_Status_from_Status(status, module_or.status());
|
||||||
return "// error";
|
return "// error";
|
||||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
#include "llvm/ADT/Twine.h"
|
#include "llvm/ADT/Twine.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@ -57,6 +58,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
#include "mlir/IR/Types.h" // from @llvm-project
|
#include "mlir/IR/Types.h" // from @llvm-project
|
||||||
#include "mlir/IR/Verifier.h" // from @llvm-project
|
#include "mlir/IR/Verifier.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||||
@ -65,6 +67,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||||
@ -2428,8 +2431,8 @@ class SavedModelObjectGraphImporter : public ImporterBase {
|
|||||||
// Main entry point: converts all functions in the given meta graph to an MLIR
|
// Main entry point: converts all functions in the given meta graph to an MLIR
|
||||||
// Module.
|
// Module.
|
||||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
|
||||||
absl::Span<std::string> exported_names, bool add_default_attributes);
|
mlir::MLIRContext* context, bool add_default_attributes);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
explicit SavedModelObjectGraphImporter(
|
explicit SavedModelObjectGraphImporter(
|
||||||
@ -3129,8 +3132,8 @@ Status CreateSavedModelIR(
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
|
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
|
||||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
|
||||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
mlir::MLIRContext* context, bool add_default_attributes) {
|
||||||
GraphDebugInfo dummy_debug_info;
|
GraphDebugInfo dummy_debug_info;
|
||||||
const GraphDebugInfo& debug_info =
|
const GraphDebugInfo& debug_info =
|
||||||
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
|
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
|
||||||
@ -3207,17 +3210,20 @@ class SavedModelSignatureDefImporter {
|
|||||||
public:
|
public:
|
||||||
// Main entry point: converts all functions (specified by SignatureDefs) in
|
// Main entry point: converts all functions (specified by SignatureDefs) in
|
||||||
// the given meta graph to an MLIR Module.
|
// the given meta graph to an MLIR Module.
|
||||||
static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
|
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||||
|
const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
|
||||||
mlir::MLIRContext* context) {
|
mlir::MLIRContext* context) {
|
||||||
SavedModelSignatureDefImporter importer(bundle, context);
|
SavedModelSignatureDefImporter importer(bundle, exported_names, context);
|
||||||
|
|
||||||
return importer.ConvertSignatures();
|
return importer.ConvertSignatures();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
|
SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
|
||||||
|
absl::Span<std::string> exported_names,
|
||||||
mlir::MLIRContext* context)
|
mlir::MLIRContext* context)
|
||||||
: bundle_(bundle),
|
: bundle_(bundle),
|
||||||
|
exported_names_(exported_names),
|
||||||
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
|
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
|
||||||
|
|
||||||
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
|
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
|
||||||
@ -3250,6 +3256,7 @@ class SavedModelSignatureDefImporter {
|
|||||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
|
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
|
||||||
|
|
||||||
const SavedModelBundle& bundle_;
|
const SavedModelBundle& bundle_;
|
||||||
|
absl::Span<std::string> exported_names_;
|
||||||
mlir::OwningModuleRef module_;
|
mlir::OwningModuleRef module_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -3265,6 +3272,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
|||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
|
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
|
||||||
|
|
||||||
|
llvm::StringSet<> exported_name_set;
|
||||||
|
exported_name_set.insert(exported_names_.begin(), exported_names_.end());
|
||||||
|
|
||||||
for (const auto& key_and_signature_def : signatures) {
|
for (const auto& key_and_signature_def : signatures) {
|
||||||
const std::string& sig_def_key = key_and_signature_def.first;
|
const std::string& sig_def_key = key_and_signature_def.first;
|
||||||
const SignatureDef& signature_def = key_and_signature_def.second;
|
const SignatureDef& signature_def = key_and_signature_def.second;
|
||||||
@ -3274,6 +3284,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
|||||||
if (sig_def_key == "__saved_model_init_op") {
|
if (sig_def_key == "__saved_model_init_op") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (!exported_name_set.empty() &&
|
||||||
|
exported_name_set.count(sig_def_key) == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
|
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
|
||||||
debug_info, flib_def));
|
debug_info, flib_def));
|
||||||
@ -3556,12 +3570,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
|||||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
||||||
return SavedModelObjectGraphImporter::Convert(
|
return SavedModelObjectGraphImporter::Convert(
|
||||||
saved_model, context, exported_names, add_default_attributes);
|
saved_model, exported_names, context, add_default_attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
|
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
|
||||||
const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
|
const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
|
||||||
return SavedModelSignatureDefImporter::Convert(saved_model, context);
|
mlir::MLIRContext* context) {
|
||||||
|
return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
|
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
|
||||||
|
@ -55,6 +55,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
|||||||
// expressed with tf_executor dialect.
|
// expressed with tf_executor dialect.
|
||||||
stream_executor::port::StatusOr<mlir::OwningModuleRef>
|
stream_executor::port::StatusOr<mlir::OwningModuleRef>
|
||||||
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
|
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
|
||||||
|
absl::Span<std::string> exported_names,
|
||||||
mlir::MLIRContext* context);
|
mlir::MLIRContext* context);
|
||||||
|
|
||||||
// Serialize a MLIR module to a string.
|
// Serialize a MLIR module to a string.
|
||||||
|
@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
|||||||
|
|
||||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||||
absl::string_view saved_model_dir,
|
absl::string_view saved_model_dir,
|
||||||
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
|
const std::unordered_set<std::string>& tags,
|
||||||
|
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||||
tensorflow::SavedModelBundle bundle;
|
tensorflow::SavedModelBundle bundle;
|
||||||
tensorflow::SessionOptions session_options;
|
tensorflow::SessionOptions session_options;
|
||||||
// Force saved model states to be restored to CPU.
|
// Force saved model states to be restored to CPU.
|
||||||
@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
|
auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context);
|
||||||
if (!module_or.status().ok()) {
|
if (!module_or.status().ok()) {
|
||||||
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
|
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
|||||||
// given MLIR `context`.
|
// given MLIR `context`.
|
||||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||||
absl::string_view saved_model_dir,
|
absl::string_view saved_model_dir,
|
||||||
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
|
const std::unordered_set<std::string>& tags,
|
||||||
|
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -104,26 +104,24 @@ int main(int argc, char** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (import_saved_model_object_graph) {
|
std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
|
||||||
std::unordered_set<std::string> tags =
|
std::vector<std::string> exported_names_vector =
|
||||||
absl::StrSplit(saved_model_tags, ',');
|
|
||||||
std::vector<std::string> exported_names =
|
|
||||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||||
|
absl::Span<std::string> exported_names(exported_names_vector);
|
||||||
|
|
||||||
|
if (import_saved_model_object_graph) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
|
||||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||||
input_filename, tags, absl::Span<std::string>(exported_names),
|
input_filename, tags, exported_names, &context);
|
||||||
&context);
|
|
||||||
if (!module) return 1;
|
if (!module) return 1;
|
||||||
|
|
||||||
module->print(output->os());
|
module->print(output->os());
|
||||||
} else if (import_saved_model_signature_defs) {
|
} else if (import_saved_model_signature_defs) {
|
||||||
std::unordered_set<std::string> tags =
|
|
||||||
absl::StrSplit(saved_model_tags, ',');
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
|
||||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||||
input_filename, tags, &context);
|
input_filename, tags, exported_names, &context);
|
||||||
if (!module) return 1;
|
if (!module) return 1;
|
||||||
|
|
||||||
module->print(output->os());
|
module->print(output->os());
|
||||||
|
@ -401,6 +401,7 @@ class TFLiteConverterBase(object):
|
|||||||
if not self._contains_function_with_implements_attr(saved_model_proto):
|
if not self._contains_function_with_implements_attr(saved_model_proto):
|
||||||
self.saved_model_dir = None
|
self.saved_model_dir = None
|
||||||
else:
|
else:
|
||||||
|
if not self._saved_model_exported_names:
|
||||||
self._saved_model_exported_names = []
|
self._saved_model_exported_names = []
|
||||||
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
||||||
if self._saved_model_version not in [1, 2]:
|
if self._saved_model_version not in [1, 2]:
|
||||||
@ -761,6 +762,9 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
|
|||||||
if not signature_keys:
|
if not signature_keys:
|
||||||
signature_keys = saved_model.signatures
|
signature_keys = saved_model.signatures
|
||||||
|
|
||||||
|
if len(signature_keys) != 1:
|
||||||
|
raise ValueError("Only support a single signature key.")
|
||||||
|
|
||||||
funcs = []
|
funcs = []
|
||||||
for key in signature_keys:
|
for key in signature_keys:
|
||||||
if key not in saved_model.signatures:
|
if key not in saved_model.signatures:
|
||||||
|
@ -469,15 +469,10 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
|||||||
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||||
save(root, save_dir, {'add': add_func, 'sub': sub_func})
|
save(root, save_dir, {'add': add_func, 'sub': sub_func})
|
||||||
|
|
||||||
# Ensure the converter generates.
|
|
||||||
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
|
|
||||||
self.assertLen(converter._funcs, 2)
|
|
||||||
|
|
||||||
# Try converting multiple functions.
|
# Try converting multiple functions.
|
||||||
with self.assertRaises(ValueError) as error:
|
with self.assertRaises(ValueError) as error:
|
||||||
_ = converter.convert()
|
_ = lite.TFLiteConverterV2.from_saved_model(save_dir)
|
||||||
self.assertIn('This converter can only convert a single ConcreteFunction',
|
self.assertIn('Only support a single signature key.', str(error.exception))
|
||||||
str(error.exception))
|
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testNoConcreteFunctionModel(self):
|
def testNoConcreteFunctionModel(self):
|
||||||
@ -487,12 +482,9 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
|||||||
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||||
save(root, save_dir)
|
save(root, save_dir)
|
||||||
|
|
||||||
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
|
|
||||||
self.assertLen(converter._funcs, 0)
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as error:
|
with self.assertRaises(ValueError) as error:
|
||||||
_ = converter.convert()
|
_ = lite.TFLiteConverterV2.from_saved_model(save_dir)
|
||||||
self.assertIn('No ConcreteFunction is specified.', str(error.exception))
|
self.assertIn('Only support a single signature key.', str(error.exception))
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testKerasSequentialModel(self):
|
def testKerasSequentialModel(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user