Add an option to apply ModelPruner when building a grappler item and an option to provide specific feed nodes to the item builder.

PiperOrigin-RevId: 171758733
This commit is contained in:
Max Galkin 2017-10-10 17:22:48 -07:00 committed by TensorFlower Gardener
parent 010506f4fe
commit 3601966630
3 changed files with 57 additions and 2 deletions

View File

@ -100,6 +100,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/inputs:utils",
"//tensorflow/core/grappler/optimizers:model_pruner",
],
)

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/inputs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/protobuf_internal.h"
@ -133,12 +134,24 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));
// Optimize the graph.
GraphOptimizer optimizer(*optimizer_opts);
::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr);
graphptr->ToGraphDef(output_graph_def);
return Status::OK();
}
// Applies the same graph pruning logic to the graph as Session.Run in TF.
// If the returned status is not OK, item state may be inconsistent.
Status PruneGraph(GrapplerItem* item) {
ModelPruner pruner;
GraphDef pruned_graph;
Cluster* cluster = nullptr; // ModelPruner doesn't check cluster.
TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
item->graph = std::move(pruned_graph);
return Status::OK();
}
} // namespace
// static
@ -152,6 +165,18 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
new_item->id = id;
new_item->graph = meta_graph.graph_def();
// Fill in feed nodes from config, if any provided.
for (const auto& feed_node : cfg.feed_nodes) {
const string feed_name = NodeName(feed_node);
if (feed_name.empty()) {
LOG(ERROR) << "Invalid feed node name " << feed_node
<< ", skipping this input.";
return nullptr;
}
LOG(INFO) << "Will use feed node " << feed_name;
new_item->feed.emplace_back(feed_name, Tensor());
}
// Attempt to detect the fetch node(s).
if (meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
@ -339,9 +364,23 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}
}
}
Tensor fake_input(type, shape);
InitializeTensor(type, &fake_input);
new_item->feed.emplace_back(node.name(), fake_input);
if (cfg.feed_nodes.empty()) {
// No specific feed nodes were given. Assume all placeholders are fed.
new_item->feed.emplace_back(node.name(), fake_input);
} else if (cfg.feed_nodes.count(node.name()) > 0) {
// If specific feed nodes were given, only update their tensors.
auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
[&node](std::pair<string, Tensor>& f) {
return f.first == node.name();
});
QCHECK(it != new_item->feed.end());
it->second = fake_input;
}
// Set the shape of the node in the graph. This is needed for statically
// inferring shapes and is a no-op when dynamically inferring shapes as
// the Placeholder shape will match the shape passed from new_item->feed.
@ -418,6 +457,16 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
return nullptr;
}
if (cfg.prune_graph) {
VLOG(1) << "Pruning graph...";
auto status = PruneGraph(new_item.get());
if (!status.ok()) {
LOG(ERROR) << "Pruning failed: " << status.error_message();
return nullptr;
}
VLOG(1) << "Pruning ran succesfully.";
}
// Validate feed, fetch and init nodes
std::unordered_set<string> nodes;
for (const auto& node : new_item->graph.node()) {

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_GRAPPLER_GRAPPLER_ITEM_BUILDER_H_
#include <memory>
#include <set>
#include <string>
#include "tensorflow/core/grappler/grappler_item.h"
@ -45,6 +46,10 @@ struct ItemConfig {
bool erase_noinline_attributes = false;
// If non-empty, override the directory of asset paths.
string assets_directory_override;
// If true, runs ModelPruner on the graph.
bool prune_graph = false;
// Override feed nodes list.
std::set<string> feed_nodes;
};
// Factory method for creating a GrapplerItem from a MetaGraphDef.