diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc index 8715492af4b..d706d763856 100644 --- a/tensorflow/contrib/session_bundle/session_bundle.cc +++ b/tensorflow/contrib/session_bundle/session_bundle.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" @@ -124,21 +125,23 @@ Status RunInitOp(const StringPiece export_dir, tensorflow::Status LoadSessionBundleFromPath( const tensorflow::SessionOptions& options, const StringPiece export_dir, - SessionBundle* bundle) { + SessionBundle* const bundle) { LOG(INFO) << "Attempting to load a SessionBundle from: " << export_dir; TF_RETURN_IF_ERROR( GetMetaGraphDefFromExport(export_dir, &(bundle->meta_graph_def))); - auto collection_def = bundle->meta_graph_def.collection_def(); - if (collection_def.find(kGraphKey) != collection_def.end()) { + const auto& collection_def_map = bundle->meta_graph_def.collection_def(); + const auto graph_it = bundle->meta_graph_def.collection_def().find(kGraphKey); + if (graph_it != collection_def_map.end()) { + const CollectionDef& graph_collection_def = graph_it->second; // Use serving graph_def in MetaGraphDef collection_def. - if (collection_def[kGraphKey].any_list().value_size() != 1) { + if (graph_collection_def.any_list().value_size() != 1) { return errors::FailedPrecondition( strings::StrCat("Expected exactly one serving GraphDef in : ", bundle->meta_graph_def.DebugString())); } tensorflow::GraphDef graph_def; - collection_def[kGraphKey].any_list().value(0).UnpackTo(&graph_def); + graph_collection_def.any_list().value(0).UnpackTo(&graph_def); TF_RETURN_IF_ERROR( CreateSessionFromGraphDef(options, graph_def, &bundle->session)); } else { @@ -149,11 +152,14 @@ tensorflow::Status LoadSessionBundleFromPath( } std::vector<AssetFile> asset_files; - auto any_assets = collection_def[kAssetsKey].any_list().value(); - for (const auto any_asset : any_assets) { - AssetFile asset_file; - any_asset.UnpackTo(&asset_file); - asset_files.push_back(asset_file); + const auto assets_it = collection_def_map.find(kAssetsKey); + if (assets_it != collection_def_map.end()) { + const auto& any_assets = assets_it->second.any_list().value(); + for (const auto& any_asset : any_assets) { + AssetFile asset_file; + any_asset.UnpackTo(&asset_file); + asset_files.push_back(asset_file); + } } TF_RETURN_IF_ERROR( @@ -162,14 +168,15 @@ tensorflow::Status LoadSessionBundleFromPath( bundle->meta_graph_def.saver_def().filename_tensor_name(), bundle->session.get())); - if (collection_def.find(kInitOpKey) != collection_def.end()) { - if (collection_def[kInitOpKey].node_list().value_size() != 1) { + const auto init_op_it = collection_def_map.find(kInitOpKey); + if (init_op_it != collection_def_map.end()) { + if (init_op_it->second.node_list().value_size() != 1) { return errors::FailedPrecondition( strings::StrCat("Expected exactly one serving init op in : ", bundle->meta_graph_def.DebugString())); } return RunInitOp(export_dir, asset_files, - collection_def[kInitOpKey].node_list().value(0), + init_op_it->second.node_list().value(0), bundle->session.get()); }