Added the ability to load grappler items from a the metagraphdef files contained in a TensorFlow checkpoint directory.
PiperOrigin-RevId: 165774826
This commit is contained in:
parent
402d2522f7
commit
8454e3ffa8
@ -505,6 +505,7 @@ $(wildcard tensorflow/core/user_ops/*.cu.cc) \
|
||||
$(wildcard tensorflow/core/common_runtime/gpu/*) \
|
||||
$(wildcard tensorflow/core/common_runtime/gpu_device_factory.*) \
|
||||
$(wildcard tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.*) \
|
||||
$(wildcard tensorflow/core/grappler/inputs/file_input_yielder.*) \
|
||||
$(wildcard tensorflow/core/grappler/clusters/single_machine.*)
|
||||
# Filter out all the excluded files.
|
||||
TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
|
||||
|
@ -66,3 +66,20 @@ cc_library(
|
||||
"//tensorflow/core/kernels:aggregate_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "file_input_yielder",
|
||||
srcs = ["file_input_yielder.cc"],
|
||||
hdrs = [
|
||||
"file_input_yielder.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":input_yielder",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:grappler_item_builder",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
],
|
||||
)
|
||||
|
134
tensorflow/core/grappler/inputs/file_input_yielder.cc
Normal file
134
tensorflow/core/grappler/inputs/file_input_yielder.cc
Normal file
@ -0,0 +1,134 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/inputs/file_input_yielder.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
FileInputYielder::FileInputYielder(const std::vector<string>& filenames,
|
||||
size_t max_iterations)
|
||||
: filenames_(filenames),
|
||||
current_file_(0),
|
||||
current_iteration_(0),
|
||||
max_iterations_(max_iterations),
|
||||
bad_inputs_(0) {
|
||||
CHECK_GT(filenames.size(), 0) << "List of filenames is empty.";
|
||||
}
|
||||
|
||||
bool FileInputYielder::NextItem(GrapplerItem* item) {
|
||||
if (filenames_.size() == bad_inputs_) {
|
||||
// All the input files are bad, give up.
|
||||
return false;
|
||||
}
|
||||
|
||||
if (current_file_ >= filenames_.size()) {
|
||||
if (current_iteration_ >= max_iterations_) {
|
||||
return false;
|
||||
} else {
|
||||
++current_iteration_;
|
||||
current_file_ = 0;
|
||||
bad_inputs_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
const string& filename = filenames_[current_file_];
|
||||
++current_file_;
|
||||
|
||||
if (!Env::Default()->FileExists(filename).ok()) {
|
||||
LOG(WARNING) << "Skipping non existent file " << filename;
|
||||
// Attempt to process the next item on the list
|
||||
bad_inputs_ += 1;
|
||||
return NextItem(item);
|
||||
}
|
||||
|
||||
LOG(INFO) << "Loading model from " << filename;
|
||||
|
||||
MetaGraphDef metagraph;
|
||||
Status s = ReadBinaryProto(Env::Default(), filename, &metagraph);
|
||||
if (!s.ok()) {
|
||||
s = ReadTextProto(Env::Default(), filename, &metagraph);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to read MetaGraphDef from " << filename << ": "
|
||||
<< s.ToString();
|
||||
// Attempt to process the next item on the list
|
||||
bad_inputs_ += 1;
|
||||
return NextItem(item);
|
||||
}
|
||||
|
||||
if (metagraph.collection_def().count("train_op") == 0 ||
|
||||
!metagraph.collection_def().at("train_op").has_node_list() ||
|
||||
metagraph.collection_def().at("train_op").node_list().value_size() == 0) {
|
||||
LOG(ERROR) << "No train op specified";
|
||||
bad_inputs_ += 1;
|
||||
metagraph = MetaGraphDef();
|
||||
return NextItem(item);
|
||||
} else {
|
||||
std::unordered_set<string> train_ops;
|
||||
for (const string& val :
|
||||
metagraph.collection_def().at("train_op").node_list().value()) {
|
||||
train_ops.insert(NodeName(val));
|
||||
}
|
||||
std::unordered_set<string> train_ops_found;
|
||||
for (auto& node : metagraph.graph_def().node()) {
|
||||
if (train_ops.find(node.name()) != train_ops.end()) {
|
||||
train_ops_found.insert(node.name());
|
||||
}
|
||||
}
|
||||
if (train_ops_found.size() != train_ops.size()) {
|
||||
for (const auto& train_op : train_ops) {
|
||||
if (train_ops_found.find(train_op) != train_ops_found.end()) {
|
||||
LOG(ERROR) << "Non existent train op specified: " << train_op;
|
||||
}
|
||||
}
|
||||
bad_inputs_ += 1;
|
||||
metagraph = MetaGraphDef();
|
||||
return NextItem(item);
|
||||
}
|
||||
}
|
||||
|
||||
const string id =
|
||||
strings::StrCat(Fingerprint64(metagraph.SerializeAsString()));
|
||||
|
||||
ItemConfig cfg;
|
||||
std::unique_ptr<GrapplerItem> new_item =
|
||||
GrapplerItemFromMetaGraphDef(id, metagraph, cfg);
|
||||
if (new_item == nullptr) {
|
||||
bad_inputs_ += 1;
|
||||
metagraph = MetaGraphDef();
|
||||
return NextItem(item);
|
||||
}
|
||||
|
||||
*item = std::move(*new_item);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
56
tensorflow/core/grappler/inputs/file_input_yielder.h
Normal file
56
tensorflow/core/grappler/inputs/file_input_yielder.h
Normal file
@ -0,0 +1,56 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// The file input provides a mechanism to feed grappler with existing TensorFlow
|
||||
// graphs stored in TensorFlow checkpoints. Note that at this point the weights
|
||||
// that may be stored in the checkpoint are not restored in order to speedup the
|
||||
// initialization.
|
||||
|
||||
#ifndef LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_
|
||||
#define LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/grappler/inputs/input_yielder.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class GrapplerItem;
|
||||
|
||||
class FileInputYielder : public InputYielder {
|
||||
public:
|
||||
// Iterates over the files specified in the list of 'filename' up to
|
||||
// 'max_iterations' times.
|
||||
explicit FileInputYielder(
|
||||
const std::vector<string>& filenames,
|
||||
size_t max_iterations = std::numeric_limits<size_t>::max());
|
||||
bool NextItem(GrapplerItem* item) override;
|
||||
|
||||
private:
|
||||
const std::vector<string> filenames_;
|
||||
size_t current_file_;
|
||||
size_t current_iteration_;
|
||||
size_t max_iterations_;
|
||||
|
||||
size_t bad_inputs_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_
|
Loading…
x
Reference in New Issue
Block a user