Prevents local copying of CollectionDef map.

- Earlier we were locally copying the map which could be potentially very big.
Change: 125595701
This commit is contained in:
Vinu Rajashekhar 2016-06-22 11:34:57 -08:00 committed by TensorFlower Gardener
parent 2729f5c661
commit 7614900108

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
@ -124,21 +125,23 @@ Status RunInitOp(const StringPiece export_dir,
tensorflow::Status LoadSessionBundleFromPath(
const tensorflow::SessionOptions& options, const StringPiece export_dir,
SessionBundle* bundle) {
SessionBundle* const bundle) {
LOG(INFO) << "Attempting to load a SessionBundle from: " << export_dir;
TF_RETURN_IF_ERROR(
GetMetaGraphDefFromExport(export_dir, &(bundle->meta_graph_def)));
auto collection_def = bundle->meta_graph_def.collection_def();
if (collection_def.find(kGraphKey) != collection_def.end()) {
const auto& collection_def_map = bundle->meta_graph_def.collection_def();
const auto graph_it = bundle->meta_graph_def.collection_def().find(kGraphKey);
if (graph_it != collection_def_map.end()) {
const CollectionDef& graph_collection_def = graph_it->second;
// Use serving graph_def in MetaGraphDef collection_def.
if (collection_def[kGraphKey].any_list().value_size() != 1) {
if (graph_collection_def.any_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one serving GraphDef in : ",
bundle->meta_graph_def.DebugString()));
}
tensorflow::GraphDef graph_def;
collection_def[kGraphKey].any_list().value(0).UnpackTo(&graph_def);
graph_collection_def.any_list().value(0).UnpackTo(&graph_def);
TF_RETURN_IF_ERROR(
CreateSessionFromGraphDef(options, graph_def, &bundle->session));
} else {
@ -149,11 +152,14 @@ tensorflow::Status LoadSessionBundleFromPath(
}
std::vector<AssetFile> asset_files;
auto any_assets = collection_def[kAssetsKey].any_list().value();
for (const auto any_asset : any_assets) {
AssetFile asset_file;
any_asset.UnpackTo(&asset_file);
asset_files.push_back(asset_file);
const auto assets_it = collection_def_map.find(kAssetsKey);
if (assets_it != collection_def_map.end()) {
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFile asset_file;
any_asset.UnpackTo(&asset_file);
asset_files.push_back(asset_file);
}
}
TF_RETURN_IF_ERROR(
@ -162,14 +168,15 @@ tensorflow::Status LoadSessionBundleFromPath(
bundle->meta_graph_def.saver_def().filename_tensor_name(),
bundle->session.get()));
if (collection_def.find(kInitOpKey) != collection_def.end()) {
if (collection_def[kInitOpKey].node_list().value_size() != 1) {
const auto init_op_it = collection_def_map.find(kInitOpKey);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one serving init op in : ",
bundle->meta_graph_def.DebugString()));
}
return RunInitOp(export_dir, asset_files,
collection_def[kInitOpKey].node_list().value(0),
init_op_it->second.node_list().value(0),
bundle->session.get());
}