Add an option to apply ModelPruner when building a grappler item and an option to provide specific feed nodes to the item builder.
PiperOrigin-RevId: 171758733
This commit is contained in:
parent
010506f4fe
commit
3601966630
@ -100,6 +100,7 @@ cc_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/inputs:utils",
|
||||
"//tensorflow/core/grappler/optimizers:model_pruner",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/grappler/inputs/utils.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/protobuf_internal.h"
|
||||
@ -133,12 +134,24 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
|
||||
ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));
|
||||
|
||||
// Optimize the graph.
|
||||
GraphOptimizer optimizer(*optimizer_opts);
|
||||
::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
|
||||
optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr);
|
||||
graphptr->ToGraphDef(output_graph_def);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Applies the same graph pruning logic to the graph as Session.Run in TF.
|
||||
// If the returned status is not OK, item state may be inconsistent.
|
||||
Status PruneGraph(GrapplerItem* item) {
|
||||
ModelPruner pruner;
|
||||
GraphDef pruned_graph;
|
||||
Cluster* cluster = nullptr; // ModelPruner doesn't check cluster.
|
||||
TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
|
||||
item->graph = std::move(pruned_graph);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// static
|
||||
@ -152,6 +165,18 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
new_item->id = id;
|
||||
new_item->graph = meta_graph.graph_def();
|
||||
|
||||
// Fill in feed nodes from config, if any provided.
|
||||
for (const auto& feed_node : cfg.feed_nodes) {
|
||||
const string feed_name = NodeName(feed_node);
|
||||
if (feed_name.empty()) {
|
||||
LOG(ERROR) << "Invalid feed node name " << feed_node
|
||||
<< ", skipping this input.";
|
||||
return nullptr;
|
||||
}
|
||||
LOG(INFO) << "Will use feed node " << feed_name;
|
||||
new_item->feed.emplace_back(feed_name, Tensor());
|
||||
}
|
||||
|
||||
// Attempt to detect the fetch node(s).
|
||||
if (meta_graph.collection_def().count("train_op") > 0) {
|
||||
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
|
||||
@ -339,9 +364,23 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor fake_input(type, shape);
|
||||
InitializeTensor(type, &fake_input);
|
||||
|
||||
if (cfg.feed_nodes.empty()) {
|
||||
// No specific feed nodes were given. Assume all placeholders are fed.
|
||||
new_item->feed.emplace_back(node.name(), fake_input);
|
||||
} else if (cfg.feed_nodes.count(node.name()) > 0) {
|
||||
// If specific feed nodes were given, only update their tensors.
|
||||
auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
|
||||
[&node](std::pair<string, Tensor>& f) {
|
||||
return f.first == node.name();
|
||||
});
|
||||
QCHECK(it != new_item->feed.end());
|
||||
it->second = fake_input;
|
||||
}
|
||||
|
||||
// Set the shape of the node in the graph. This is needed for statically
|
||||
// inferring shapes and is a no-op when dynamically inferring shapes as
|
||||
// the Placeholder shape will match the shape passed from new_item->feed.
|
||||
@ -418,6 +457,16 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (cfg.prune_graph) {
|
||||
VLOG(1) << "Pruning graph...";
|
||||
auto status = PruneGraph(new_item.get());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Pruning failed: " << status.error_message();
|
||||
return nullptr;
|
||||
}
|
||||
VLOG(1) << "Pruning ran succesfully.";
|
||||
}
|
||||
|
||||
// Validate feed, fetch and init nodes
|
||||
std::unordered_set<string> nodes;
|
||||
for (const auto& node : new_item->graph.node()) {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_GRAPPLER_GRAPPLER_ITEM_BUILDER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
|
||||
@ -45,6 +46,10 @@ struct ItemConfig {
|
||||
bool erase_noinline_attributes = false;
|
||||
// If non-empty, override the directory of asset paths.
|
||||
string assets_directory_override;
|
||||
// If true, runs ModelPruner on the graph.
|
||||
bool prune_graph = false;
|
||||
// Override feed nodes list.
|
||||
std::set<string> feed_nodes;
|
||||
};
|
||||
|
||||
// Factory method for creating a GrapplerItem from a MetaGraphDef.
|
||||
|
Loading…
Reference in New Issue
Block a user