Refactor GetInitOp() and GetAssetFileDefs() to loader_util.{cc|h} file so that
it can be used by saved model MLIR importer. PiperOrigin-RevId: 316160822 Change-Id: Ic6e1c09715f33f21c3f015c4e90145987f354389
This commit is contained in:
parent
204f6a1c5d
commit
d6a0f20592
@ -106,6 +106,7 @@ cc_library(
|
||||
hdrs = ["loader.h"],
|
||||
deps = [
|
||||
":constants",
|
||||
":loader_util",
|
||||
":reader",
|
||||
] + if_not_mobile([
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -132,6 +133,17 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "loader_util",
|
||||
srcs = ["loader_util.cc"],
|
||||
hdrs = ["loader_util.h"],
|
||||
deps = [":constants"] + if_not_mobile([
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "bundle_v2_test",
|
||||
srcs = ["bundle_v2_test.cc"],
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -29,7 +30,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf_internal.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/protobuf/saver.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A SavedModel may store the name of the initialization op to run in the
|
||||
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
|
||||
// exists, then the collection must contain exactly one op.
|
||||
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
|
||||
string* init_op_name) {
|
||||
const auto& sig_def_map = meta_graph_def.signature_def();
|
||||
const auto& init_op_sig_it =
|
||||
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
|
||||
if (init_op_sig_it != sig_def_map.end()) {
|
||||
*init_op_name = init_op_sig_it->second.outputs()
|
||||
.find(kSavedModelInitOpSignatureKey)
|
||||
->second.name();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
string init_op_collection_key;
|
||||
if (collection_def_map.find(kSavedModelMainOpKey) !=
|
||||
collection_def_map.end()) {
|
||||
init_op_collection_key = kSavedModelMainOpKey;
|
||||
} else {
|
||||
init_op_collection_key = kSavedModelLegacyInitOpKey;
|
||||
}
|
||||
|
||||
const auto init_op_it = collection_def_map.find(init_op_collection_key);
|
||||
if (init_op_it != collection_def_map.end()) {
|
||||
if (init_op_it->second.node_list().value_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Expected exactly one main op in : ", export_dir));
|
||||
}
|
||||
*init_op_name = init_op_it->second.node_list().value(0);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
const StringPiece restore_op_name,
|
||||
const StringPiece variable_filename_const_op_name,
|
||||
@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
nullptr /* outputs */, &run_metadata, session);
|
||||
}
|
||||
|
||||
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
|
||||
std::vector<AssetFileDef>* asset_file_defs) {
|
||||
// With SavedModel v2, we write asset file def into metagraph instead of
|
||||
// collection, so read from metagraph first.
|
||||
if (meta_graph_def.asset_file_def_size() > 0) {
|
||||
for (const auto& asset : meta_graph_def.asset_file_def()) {
|
||||
asset_file_defs->push_back(asset);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Fall back to read from collection to be backward compatible with v1.
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
|
||||
if (assets_it == collection_def_map.end()) {
|
||||
return Status::OK();
|
||||
}
|
||||
const auto& any_assets = assets_it->second.any_list().value();
|
||||
for (const auto& any_asset : any_assets) {
|
||||
AssetFileDef asset_file_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
|
||||
asset_file_defs->push_back(asset_file_def);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadSavedModelDebugInfoIfPresent(
|
||||
const string& export_dir,
|
||||
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
|
||||
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
|
||||
|
||||
std::vector<AssetFileDef> asset_file_defs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
|
||||
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunRestore(run_options, export_dir,
|
||||
bundle->meta_graph_def.saver_def().restore_op_name(),
|
||||
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
|
||||
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
|
||||
string init_op_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
|
||||
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
|
||||
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
|
||||
asset_file_defs, bundle->session.get(),
|
||||
init_op_name));
|
||||
|
90
tensorflow/cc/saved_model/loader_util.cc
Normal file
90
tensorflow/cc/saved_model/loader_util.cc
Normal file
@ -0,0 +1,90 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf_internal.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// A SavedModel may store the name of the initialization op to run in the
|
||||
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
|
||||
// exists, then the collection must contain exactly one op.
|
||||
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
|
||||
string* init_op_name) {
|
||||
const auto& sig_def_map = meta_graph_def.signature_def();
|
||||
const auto& init_op_sig_it =
|
||||
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
|
||||
if (init_op_sig_it != sig_def_map.end()) {
|
||||
*init_op_name = init_op_sig_it->second.outputs()
|
||||
.find(kSavedModelInitOpSignatureKey)
|
||||
->second.name();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
string init_op_collection_key;
|
||||
if (collection_def_map.find(kSavedModelMainOpKey) !=
|
||||
collection_def_map.end()) {
|
||||
init_op_collection_key = kSavedModelMainOpKey;
|
||||
} else {
|
||||
init_op_collection_key = kSavedModelLegacyInitOpKey;
|
||||
}
|
||||
|
||||
const auto init_op_it = collection_def_map.find(init_op_collection_key);
|
||||
if (init_op_it != collection_def_map.end()) {
|
||||
if (init_op_it->second.node_list().value_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Expected exactly one main op in : ", export_dir));
|
||||
}
|
||||
*init_op_name = init_op_it->second.node_list().value(0);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
|
||||
std::vector<AssetFileDef>* asset_file_defs) {
|
||||
// With SavedModel v2, we write asset file def into metagraph instead of
|
||||
// collection, so read from metagraph first.
|
||||
if (meta_graph_def.asset_file_def_size() > 0) {
|
||||
for (const auto& asset : meta_graph_def.asset_file_def()) {
|
||||
asset_file_defs->push_back(asset);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Fall back to read from collection to be backward compatible with v1.
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
|
||||
if (assets_it == collection_def_map.end()) {
|
||||
return Status::OK();
|
||||
}
|
||||
const auto& any_assets = assets_it->second.any_list().value();
|
||||
for (const auto& any_asset : any_assets) {
|
||||
AssetFileDef asset_file_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
|
||||
asset_file_defs->push_back(asset_file_def);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
39
tensorflow/cc/saved_model/loader_util.h
Normal file
39
tensorflow/cc/saved_model/loader_util.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// A SavedModel may store the name of the initialization op to run in the
|
||||
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
|
||||
// exists, then the collection must contain exactly one op.
|
||||
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
|
||||
string* init_op_name);
|
||||
|
||||
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
|
||||
std::vector<AssetFileDef>* asset_file_defs);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
|
Loading…
Reference in New Issue
Block a user