Expose OptimizeGraph() and enable function inlining control.

PiperOrigin-RevId: 243700639
This commit is contained in:
Lifeng Nai 2019-04-15 15:30:39 -07:00 committed by TensorFlower Gardener
parent 4307698dce
commit dbc271b664
2 changed files with 63 additions and 48 deletions

View File

@ -75,14 +75,49 @@ void InitializeTensor(DataType type, Tensor* tensor) {
}
}
// Optimize the graph def (including function inlining and other optimizations).
// Prunes the item's graph in place using ModelPruner, i.e. the same graph
// pruning logic that Session.Run applies in TF.
// On success item->graph is replaced by the pruned graph. If the returned
// status is not OK, item state may be inconsistent.
Status PruneGraph(GrapplerItem* item) {
  // ModelPruner doesn't check cluster, so none is provided.
  Cluster* const unused_cluster = nullptr;
  GraphDef result;
  ModelPruner model_pruner;
  TF_RETURN_IF_ERROR(model_pruner.Optimize(unused_cluster, *item, &result));
  // Hand the pruned graph over to the item without copying it.
  item->graph = std::move(result);
  return Status::OK();
}
// Copies the dimensions of shape_pb_in into shape_pb_out and builds shape_out
// from them. When cfg.placeholder_unknown_output_shape_dim is no less than 0,
// every unknown dimension (size == -1) is replaced by that value in both
// outputs. Otherwise, and for all other dimensions, shape_pb_out keeps the
// original size while the value used for shape_out is raised to at least 1.
// Returns the status of TensorShapeUtils::MakeShape.
Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
                              const TensorShapeProto& shape_pb_in,
                              TensorShapeProto* shape_pb_out,
                              TensorShape* shape_out) {
  const bool substitute_unknown =
      cfg.placeholder_unknown_output_shape_dim >= 0;
  std::vector<int32> extents;
  for (const auto& dim_proto : shape_pb_in.dim()) {
    if (!substitute_unknown || dim_proto.size() != -1) {
      // Keep the proto size untouched; only the TensorShape extent is
      // clamped to a minimum of 1.
      extents.push_back(std::max<int32>(1, dim_proto.size()));
      shape_pb_out->add_dim()->set_size(dim_proto.size());
      continue;
    }
    // Unknown dimension with a configured replacement value.
    extents.push_back(cfg.placeholder_unknown_output_shape_dim);
    shape_pb_out->add_dim()->set_size(
        cfg.placeholder_unknown_output_shape_dim);
  }
  return TensorShapeUtils::MakeShape(extents.data(), extents.size(),
                                     shape_out);
}
} // namespace
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
GraphDef* output_graph_def,
const ItemConfig& cfg) {
// This is a temporary change that optimizes the graph in context of a single
// gpu machine. Down the line, we may want to make grappler_item_builder aware
// of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in
// order to get the correct session options and environment, and performing the
// correct optimizations.
Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
const ItemConfig& cfg) {
// of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated
// in order to get the correct session options and environment, and performing
// the correct optimizations.
if (!cfg.apply_optimizations && !cfg.erase_noinline_attributes) {
return Status::OK();
}
@ -124,6 +159,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
} else {
optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L0);
}
optimizer_opts->set_do_function_inlining(cfg.inline_functions);
// Create the function library runtime.
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
@ -152,41 +188,6 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
0, true);
}
// Prunes the item's graph in place using ModelPruner, i.e. the same graph
// pruning logic that Session.Run applies in TF.
// On success item->graph is replaced by the pruned graph. If the returned
// status is not OK, item state may be inconsistent.
Status PruneGraph(GrapplerItem* item) {
  // ModelPruner doesn't check cluster, so none is provided.
  Cluster* const unused_cluster = nullptr;
  GraphDef result;
  ModelPruner model_pruner;
  TF_RETURN_IF_ERROR(model_pruner.Optimize(unused_cluster, *item, &result));
  // Hand the pruned graph over to the item without copying it.
  item->graph = std::move(result);
  return Status::OK();
}
// Copies the dimensions of shape_pb_in into shape_pb_out and builds shape_out
// from them. When cfg.placeholder_unknown_output_shape_dim is no less than 0,
// every unknown dimension (size == -1) is replaced by that value in both
// outputs. Otherwise, and for all other dimensions, shape_pb_out keeps the
// original size while the value used for shape_out is raised to at least 1.
// Returns the status of TensorShapeUtils::MakeShape.
Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
                              const TensorShapeProto& shape_pb_in,
                              TensorShapeProto* shape_pb_out,
                              TensorShape* shape_out) {
  const bool substitute_unknown =
      cfg.placeholder_unknown_output_shape_dim >= 0;
  std::vector<int32> extents;
  for (const auto& dim_proto : shape_pb_in.dim()) {
    if (!substitute_unknown || dim_proto.size() != -1) {
      // Keep the proto size untouched; only the TensorShape extent is
      // clamped to a minimum of 1.
      extents.push_back(std::max<int32>(1, dim_proto.size()));
      shape_pb_out->add_dim()->set_size(dim_proto.size());
      continue;
    }
    // Unknown dimension with a configured replacement value.
    extents.push_back(cfg.placeholder_unknown_output_shape_dim);
    shape_pb_out->add_dim()->set_size(
        cfg.placeholder_unknown_output_shape_dim);
  }
  return TensorShapeUtils::MakeShape(extents.data(), extents.size(),
                                     shape_out);
}
} // namespace
// static
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
if (id.empty()) {
@ -592,15 +593,15 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}
// Optimize the graph (function inlining, l1 optimizations, etc).
VLOG(1) << "Number of nodes in graph before OptimizeGraph: "
VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
<< new_item->graph.node_size();
Status optimize_status =
OptimizeGraph(new_item->graph, &new_item->graph, cfg);
RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
if (!optimize_status.ok()) {
LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
return nullptr;
}
VLOG(1) << "Number of nodes in graph after OptimizeGraph: "
VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
<< new_item->graph.node_size();
if (cfg.prune_graph) {

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
@ -38,8 +39,6 @@ struct ItemConfig {
// Dimension to use if a placeholder node has an _output_shapes attribute with
// a dimension of -1.
int placeholder_unknown_output_shape_dim = -1;
// If true, does L1 optimizations.
bool apply_optimizations = false;
// If true, erases all "_noinline" attributes from user-defined functions.
// Has no effect if "inline_functions" is disabled.
bool erase_noinline_attributes = false;
@ -51,8 +50,23 @@ struct ItemConfig {
std::set<string> feed_nodes;
// Override fetch nodes list.
std::set<string> fetch_nodes;
// Configs for graph optimizations from common_runtime. This is NOT the
// Grappler function optimizer. When Grappler is invoked at runtime, it
// typically runs after the common_runtime pass.
//
// If true, does L1 optimizations.
bool apply_optimizations = false;
// If true, does function inlining.
bool inline_functions = false;
};
// Method for optimizing the graph def (including function inlining and other
// optimizations). These are the optimizations from common_runtime, NOT the
// Grappler function optimizer.
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
GraphDef* output_graph_def, const ItemConfig& cfg);
// Factory method for creating a GrapplerItem from a MetaGraphDef.
// Returns nullptr if the given meta_graph cannot be converted.
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(