From 8454e3ffa850527d6fecefb3110e7b1a4b6939f5 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 18 Aug 2017 17:34:57 -0700 Subject: [PATCH] Added the ability to load grappler items from a the metagraphdef files contained in a TensorFlow checkpoint directory. PiperOrigin-RevId: 165774826 --- tensorflow/contrib/makefile/Makefile | 1 + tensorflow/core/grappler/inputs/BUILD | 17 +++ .../grappler/inputs/file_input_yielder.cc | 134 ++++++++++++++++++ .../core/grappler/inputs/file_input_yielder.h | 56 ++++++++ 4 files changed, 208 insertions(+) create mode 100644 tensorflow/core/grappler/inputs/file_input_yielder.cc create mode 100644 tensorflow/core/grappler/inputs/file_input_yielder.h diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index f8837e3f586..98af47d7288 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -505,6 +505,7 @@ $(wildcard tensorflow/core/user_ops/*.cu.cc) \ $(wildcard tensorflow/core/common_runtime/gpu/*) \ $(wildcard tensorflow/core/common_runtime/gpu_device_factory.*) \ $(wildcard tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.*) \ +$(wildcard tensorflow/core/grappler/inputs/file_input_yielder.*) \ $(wildcard tensorflow/core/grappler/clusters/single_machine.*) # Filter out all the excluded files. TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) diff --git a/tensorflow/core/grappler/inputs/BUILD b/tensorflow/core/grappler/inputs/BUILD index 5c70f409697..915a3e28f88 100644 --- a/tensorflow/core/grappler/inputs/BUILD +++ b/tensorflow/core/grappler/inputs/BUILD @@ -66,3 +66,20 @@ cc_library( "//tensorflow/core/kernels:aggregate_ops", ], ) + +cc_library( + name = "file_input_yielder", + srcs = ["file_input_yielder.cc"], + hdrs = [ + "file_input_yielder.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":input_yielder", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:grappler_item_builder", + "//tensorflow/core/grappler:utils", + ], +) diff --git a/tensorflow/core/grappler/inputs/file_input_yielder.cc b/tensorflow/core/grappler/inputs/file_input_yielder.cc new file mode 100644 index 00000000000..e63a38c9746 --- /dev/null +++ b/tensorflow/core/grappler/inputs/file_input_yielder.cc @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/inputs/file_input_yielder.h" + +#include +#include +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/grappler_item_builder.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace grappler { + +FileInputYielder::FileInputYielder(const std::vector& filenames, + size_t max_iterations) + : filenames_(filenames), + current_file_(0), + current_iteration_(0), + max_iterations_(max_iterations), + bad_inputs_(0) { + CHECK_GT(filenames.size(), 0) << "List of filenames is empty."; +} + +bool FileInputYielder::NextItem(GrapplerItem* item) { + if (filenames_.size() == bad_inputs_) { + // All the input files are bad, give up. + return false; + } + + if (current_file_ >= filenames_.size()) { + if (current_iteration_ >= max_iterations_) { + return false; + } else { + ++current_iteration_; + current_file_ = 0; + bad_inputs_ = 0; + } + } + + const string& filename = filenames_[current_file_]; + ++current_file_; + + if (!Env::Default()->FileExists(filename).ok()) { + LOG(WARNING) << "Skipping non existent file " << filename; + // Attempt to process the next item on the list + bad_inputs_ += 1; + return NextItem(item); + } + + LOG(INFO) << "Loading model from " << filename; + + MetaGraphDef metagraph; + Status s = ReadBinaryProto(Env::Default(), filename, &metagraph); + if (!s.ok()) { + s = ReadTextProto(Env::Default(), filename, &metagraph); + } + if (!s.ok()) { + LOG(WARNING) << "Failed to read MetaGraphDef from " << filename << ": " + << s.ToString(); + // Attempt to process the next item on the list + bad_inputs_ += 1; + return NextItem(item); + } + + if (metagraph.collection_def().count("train_op") == 0 || + !metagraph.collection_def().at("train_op").has_node_list() || + metagraph.collection_def().at("train_op").node_list().value_size() == 0) { + LOG(ERROR) << "No train op specified"; + bad_inputs_ += 1; + metagraph = MetaGraphDef(); + return NextItem(item); + } else { + std::unordered_set train_ops; + for (const string& val : + metagraph.collection_def().at("train_op").node_list().value()) { + train_ops.insert(NodeName(val)); + } + std::unordered_set train_ops_found; + for (auto& node : metagraph.graph_def().node()) { + if (train_ops.find(node.name()) != train_ops.end()) { + train_ops_found.insert(node.name()); + } + } + if (train_ops_found.size() != train_ops.size()) { + for (const auto& train_op : train_ops) { + if (train_ops_found.find(train_op) != train_ops_found.end()) { + LOG(ERROR) << "Non existent train op specified: " << train_op; + } + } + bad_inputs_ += 1; + metagraph = MetaGraphDef(); + return NextItem(item); + } + } + + const string id = + strings::StrCat(Fingerprint64(metagraph.SerializeAsString())); + + ItemConfig cfg; + std::unique_ptr new_item = + GrapplerItemFromMetaGraphDef(id, metagraph, cfg); + if (new_item == nullptr) { + bad_inputs_ += 1; + metagraph = MetaGraphDef(); + return NextItem(item); + } + + *item = std::move(*new_item); + return true; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/inputs/file_input_yielder.h b/tensorflow/core/grappler/inputs/file_input_yielder.h new file mode 100644 index 00000000000..a17e1c9ff2a --- /dev/null +++ b/tensorflow/core/grappler/inputs/file_input_yielder.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The file input provides a mechanism to feed grappler with existing TensorFlow +// graphs stored in TensorFlow checkpoints. Note that at this point the weights +// that may be stored in the checkpoint are not restored in order to speedup the +// initialization. + +#ifndef LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ +#define LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ + +#include +#include +#include +#include "tensorflow/core/grappler/inputs/input_yielder.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +class GrapplerItem; + +class FileInputYielder : public InputYielder { + public: + // Iterates over the files specified in the list of 'filename' up to + // 'max_iterations' times. + explicit FileInputYielder( + const std::vector& filenames, + size_t max_iterations = std::numeric_limits::max()); + bool NextItem(GrapplerItem* item) override; + + private: + const std::vector filenames_; + size_t current_file_; + size_t current_iteration_; + size_t max_iterations_; + + size_t bad_inputs_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_