it can be used by saved model MLIR importer. PiperOrigin-RevId: 316160822 Change-Id: Ic6e1c09715f33f21c3f015c4e90145987f354389
91 lines
3.4 KiB
C++
91 lines
3.4 KiB
C++
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/cc/saved_model/loader_util.h"
|
|
|
|
#include <vector>
|
|
|
|
#include "tensorflow/cc/saved_model/constants.h"
|
|
#include "tensorflow/core/lib/strings/strcat.h"
|
|
#include "tensorflow/core/platform/errors.h"
|
|
#include "tensorflow/core/platform/protobuf_internal.h"
|
|
|
|
namespace tensorflow {
|
|
namespace internal {
|
|
|
|
// A SavedModel may store the name of the initialization op to run in the
|
|
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
|
|
// exists, then the collection must contain exactly one op.
|
|
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
|
|
string* init_op_name) {
|
|
const auto& sig_def_map = meta_graph_def.signature_def();
|
|
const auto& init_op_sig_it =
|
|
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
|
|
if (init_op_sig_it != sig_def_map.end()) {
|
|
*init_op_name = init_op_sig_it->second.outputs()
|
|
.find(kSavedModelInitOpSignatureKey)
|
|
->second.name();
|
|
return Status::OK();
|
|
}
|
|
|
|
const auto& collection_def_map = meta_graph_def.collection_def();
|
|
string init_op_collection_key;
|
|
if (collection_def_map.find(kSavedModelMainOpKey) !=
|
|
collection_def_map.end()) {
|
|
init_op_collection_key = kSavedModelMainOpKey;
|
|
} else {
|
|
init_op_collection_key = kSavedModelLegacyInitOpKey;
|
|
}
|
|
|
|
const auto init_op_it = collection_def_map.find(init_op_collection_key);
|
|
if (init_op_it != collection_def_map.end()) {
|
|
if (init_op_it->second.node_list().value_size() != 1) {
|
|
return errors::FailedPrecondition(
|
|
strings::StrCat("Expected exactly one main op in : ", export_dir));
|
|
}
|
|
*init_op_name = init_op_it->second.node_list().value(0);
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
// Appends the AssetFileDefs recorded in `meta_graph_def` to
// `*asset_file_defs`.
//
// SavedModel v2 writes asset file defs directly into the metagraph, so that
// field is read first. When it is empty, the kSavedModelAssetsKey collection
// is read instead, for backward compatibility with v1. Existing elements of
// `*asset_file_defs` are kept; new entries are appended after them.
//
// Returns OK on success, including when no assets are present at all, or the
// error produced by ParseAny when a collection entry cannot be parsed as a
// "tensorflow.AssetFileDef".
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  // v2 path: assets are stored on the metagraph itself.
  if (meta_graph_def.asset_file_def_size() > 0) {
    const auto& metagraph_assets = meta_graph_def.asset_file_def();
    asset_file_defs->insert(asset_file_defs->end(), metagraph_assets.begin(),
                            metagraph_assets.end());
    return Status::OK();
  }

  // v1 path: assets are packed as Any messages inside a collection.
  const auto& collections = meta_graph_def.collection_def();
  const auto assets_entry = collections.find(kSavedModelAssetsKey);
  if (assets_entry != collections.end()) {
    for (const auto& packed_asset : assets_entry->second.any_list().value()) {
      AssetFileDef unpacked_asset;
      TF_RETURN_IF_ERROR(ParseAny(packed_asset, &unpacked_asset,
                                  "tensorflow.AssetFileDef"));
      asset_file_defs->push_back(unpacked_asset);
    }
  }
  return Status::OK();
}
|
|
|
|
} // namespace internal
|
|
} // namespace tensorflow
|