Prevents local copying of CollectionDef map.
- Earlier we were locally copying the map which could be potentially very big. Change: 125595701
This commit is contained in:
parent
2729f5c661
commit
7614900108
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/saver.pb.h"
|
||||
@ -124,21 +125,23 @@ Status RunInitOp(const StringPiece export_dir,
|
||||
|
||||
tensorflow::Status LoadSessionBundleFromPath(
|
||||
const tensorflow::SessionOptions& options, const StringPiece export_dir,
|
||||
SessionBundle* bundle) {
|
||||
SessionBundle* const bundle) {
|
||||
LOG(INFO) << "Attempting to load a SessionBundle from: " << export_dir;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetMetaGraphDefFromExport(export_dir, &(bundle->meta_graph_def)));
|
||||
|
||||
auto collection_def = bundle->meta_graph_def.collection_def();
|
||||
if (collection_def.find(kGraphKey) != collection_def.end()) {
|
||||
const auto& collection_def_map = bundle->meta_graph_def.collection_def();
|
||||
const auto graph_it = bundle->meta_graph_def.collection_def().find(kGraphKey);
|
||||
if (graph_it != collection_def_map.end()) {
|
||||
const CollectionDef& graph_collection_def = graph_it->second;
|
||||
// Use serving graph_def in MetaGraphDef collection_def.
|
||||
if (collection_def[kGraphKey].any_list().value_size() != 1) {
|
||||
if (graph_collection_def.any_list().value_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Expected exactly one serving GraphDef in : ",
|
||||
bundle->meta_graph_def.DebugString()));
|
||||
}
|
||||
tensorflow::GraphDef graph_def;
|
||||
collection_def[kGraphKey].any_list().value(0).UnpackTo(&graph_def);
|
||||
graph_collection_def.any_list().value(0).UnpackTo(&graph_def);
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateSessionFromGraphDef(options, graph_def, &bundle->session));
|
||||
} else {
|
||||
@ -149,11 +152,14 @@ tensorflow::Status LoadSessionBundleFromPath(
|
||||
}
|
||||
|
||||
std::vector<AssetFile> asset_files;
|
||||
auto any_assets = collection_def[kAssetsKey].any_list().value();
|
||||
for (const auto any_asset : any_assets) {
|
||||
AssetFile asset_file;
|
||||
any_asset.UnpackTo(&asset_file);
|
||||
asset_files.push_back(asset_file);
|
||||
const auto assets_it = collection_def_map.find(kAssetsKey);
|
||||
if (assets_it != collection_def_map.end()) {
|
||||
const auto& any_assets = assets_it->second.any_list().value();
|
||||
for (const auto& any_asset : any_assets) {
|
||||
AssetFile asset_file;
|
||||
any_asset.UnpackTo(&asset_file);
|
||||
asset_files.push_back(asset_file);
|
||||
}
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -162,14 +168,15 @@ tensorflow::Status LoadSessionBundleFromPath(
|
||||
bundle->meta_graph_def.saver_def().filename_tensor_name(),
|
||||
bundle->session.get()));
|
||||
|
||||
if (collection_def.find(kInitOpKey) != collection_def.end()) {
|
||||
if (collection_def[kInitOpKey].node_list().value_size() != 1) {
|
||||
const auto init_op_it = collection_def_map.find(kInitOpKey);
|
||||
if (init_op_it != collection_def_map.end()) {
|
||||
if (init_op_it->second.node_list().value_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Expected exactly one serving init op in : ",
|
||||
bundle->meta_graph_def.DebugString()));
|
||||
}
|
||||
return RunInitOp(export_dir, asset_files,
|
||||
collection_def[kInitOpKey].node_list().value(0),
|
||||
init_op_it->second.node_list().value(0),
|
||||
bundle->session.get());
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user