Expose OptimizeGraph() and enable function inlining control.
PiperOrigin-RevId: 243700639
parent 4307698dce
commit dbc271b664
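Usage sketch (not part of this commit): the change exposes the runtime graph optimizer declared in grappler_item_builder.h, so a caller can drive it roughly as below. The wrapper name OptimizeForTooling is illustrative only; ItemConfig fields and RuntimeGraphOptimizer come from the diff itself.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/lib/core/status.h"

// Run the common_runtime optimizations (L1 passes and function inlining) on an
// existing GraphDef via the newly exposed entry point.
tensorflow::Status OptimizeForTooling(const tensorflow::GraphDef& graph_def,
                                      tensorflow::GraphDef* optimized) {
  tensorflow::grappler::ItemConfig cfg;
  cfg.apply_optimizations = true;  // common_runtime L1 optimizations
  cfg.inline_functions = true;     // the inlining knob this commit exposes
  return tensorflow::grappler::RuntimeGraphOptimizer(graph_def, optimized, cfg);
}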
tensorflow/core/grappler/grappler_item_builder.cc

@@ -75,14 +75,49 @@ void InitializeTensor(DataType type, Tensor* tensor) {
  }
}

// Optimize the graph def (including function inlining and other optimizations).
// Applies the same graph pruning logic to the graph as Session.Run in TF.
// If the returned status is not OK, item state may be inconsistent.
Status PruneGraph(GrapplerItem* item) {
  ModelPruner pruner;
  GraphDef pruned_graph;
  Cluster* cluster = nullptr;  // ModelPruner doesn't check cluster.
  TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
  item->graph = std::move(pruned_graph);
  return Status::OK();
}

// Replace any unknown dimensions in a shape with
// cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
                              const TensorShapeProto& shape_pb_in,
                              TensorShapeProto* shape_pb_out,
                              TensorShape* shape_out) {
  std::vector<int32> dims;
  for (const auto& dim_proto : shape_pb_in.dim()) {
    if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
        dim_proto.size() == -1) {
      dims.push_back(cfg.placeholder_unknown_output_shape_dim);
      shape_pb_out->add_dim()->set_size(
          cfg.placeholder_unknown_output_shape_dim);
    } else {
      dims.push_back(std::max<int32>(1, dim_proto.size()));
      shape_pb_out->add_dim()->set_size(dim_proto.size());
    }
  }
  return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
}

}  // namespace

Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
                             GraphDef* output_graph_def,
                             const ItemConfig& cfg) {
  // This is a temporary change that optimizes the graph in context of a single
  // gpu machine. Down the line, we may want to make grappler_item_builder aware
  // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in
  // order to get the correct session options and environment, and performing the
  // correct optimizations.
Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
                     const ItemConfig& cfg) {
  // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated
  // in order to get the correct session options and environment, and performing
  // the correct optimizations.

  if (!cfg.apply_optimizations && !cfg.erase_noinline_attributes) {
    return Status::OK();
  }
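As a quick illustration of the rule ReplaceUnknownShapeDim implements above (a standalone toy, not TF code): an unknown (-1) dimension is replaced by cfg.placeholder_unknown_output_shape_dim when that knob is non-negative, while other dimensions pass through; the in-memory TensorShape additionally clamps dimensions to at least 1.

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  const int placeholder_unknown_output_shape_dim = 64;  // e.g. a chosen batch size
  const std::vector<int> shape_in = {-1, 28, 28, 3};     // -1 = unknown dimension
  std::vector<int> shape_out;
  for (int d : shape_in) {
    if (placeholder_unknown_output_shape_dim >= 0 && d == -1) {
      shape_out.push_back(placeholder_unknown_output_shape_dim);
    } else {
      shape_out.push_back(std::max(1, d));  // mirrors the TensorShape clamp
    }
  }
  for (int d : shape_out) std::cout << d << ' ';  // prints: 64 28 28 3
  std::cout << '\n';
  return 0;
}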
@@ -124,6 +159,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
  } else {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L0);
  }
  optimizer_opts->set_do_function_inlining(cfg.inline_functions);

  // Create the function library runtime.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
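The new set_do_function_inlining call above wires cfg.inline_functions into the common_runtime OptimizerOptions. For reference, the same two switches expressed directly on a SessionOptions proto look roughly like this (a sketch, not code from this commit; MakeOptions is illustrative):

#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"

// Configure common_runtime L1 optimizations and function inlining on a session.
tensorflow::SessionOptions MakeOptions(bool apply_optimizations,
                                       bool inline_functions) {
  tensorflow::SessionOptions options;
  auto* optimizer_opts =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  optimizer_opts->set_opt_level(apply_optimizations
                                    ? tensorflow::OptimizerOptions::L1
                                    : tensorflow::OptimizerOptions::L0);
  optimizer_opts->set_do_function_inlining(inline_functions);
  return options;
}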
@@ -152,41 +188,6 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
      0, true);
  }

// Applies the same graph pruning logic to the graph as Session.Run in TF.
// If the returned status is not OK, item state may be inconsistent.
Status PruneGraph(GrapplerItem* item) {
  ModelPruner pruner;
  GraphDef pruned_graph;
  Cluster* cluster = nullptr;  // ModelPruner doesn't check cluster.
  TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
  item->graph = std::move(pruned_graph);
  return Status::OK();
}

// Replace any unknown dimensions in a shape with
// cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
                              const TensorShapeProto& shape_pb_in,
                              TensorShapeProto* shape_pb_out,
                              TensorShape* shape_out) {
  std::vector<int32> dims;
  for (const auto& dim_proto : shape_pb_in.dim()) {
    if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
        dim_proto.size() == -1) {
      dims.push_back(cfg.placeholder_unknown_output_shape_dim);
      shape_pb_out->add_dim()->set_size(
          cfg.placeholder_unknown_output_shape_dim);
    } else {
      dims.push_back(std::max<int32>(1, dim_proto.size()));
      shape_pb_out->add_dim()->set_size(dim_proto.size());
    }
  }
  return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
}

}  // namespace

// static
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
    const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
  if (id.empty()) {

@@ -592,15 +593,15 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
  }

  // Optimize the graph (function inlining, l1 optimizations, etc).
  VLOG(1) << "Number of nodes in graph before OptimizeGraph: "
  VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
          << new_item->graph.node_size();
  Status optimize_status =
      OptimizeGraph(new_item->graph, &new_item->graph, cfg);
      RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
  if (!optimize_status.ok()) {
    LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
    return nullptr;
  }
  VLOG(1) << "Number of nodes in graph after OptimizeGraph: "
  VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
          << new_item->graph.node_size();

  if (cfg.prune_graph) {
tensorflow/core/grappler/grappler_item_builder.h

@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"

@@ -38,8 +39,6 @@ struct ItemConfig {
  // Dimension to use if a placeholder node has an _output_shapes attribute with
  // a dimension of -1.
  int placeholder_unknown_output_shape_dim = -1;
  // If true, does L1 optimizations.
  bool apply_optimizations = false;
  // If true, erases all "_noinline" attributes from user-defined functions.
  // Has no effect if "inline_functions" is disabled.
  bool erase_noinline_attributes = false;

@@ -51,8 +50,23 @@ struct ItemConfig {
  std::set<string> feed_nodes;
  // Override fetch nodes list.
  std::set<string> fetch_nodes;

  // Configs for graph optimizations from common_runtime. This is NOT Grappler
  // function optimizer. When Grappler is invoked at runtime, it is typically
  // running after common_runtime pass.
  //
  // If true, does L1 optimizations.
  bool apply_optimizations = false;
  // If true, does function inlining.
  bool inline_functions = false;
};

// Method for optimizing the graph def (including function inlining and other
// optimizations). This is optimizations from common_runtime, NOT Grappler
// function optimizer.
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
                             GraphDef* output_graph_def, const ItemConfig& cfg);

// Factory method for creating a GrapplerItem from a MetaGraphDef.
// Returns nullptr if the given meta_graph cannot be converted.
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
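Putting the header changes together, a caller might build a GrapplerItem as sketched below (assumes the grappler_item_builder build target; BuildItem and the item id "my_item" are illustrative, and the result is nullptr if conversion fails):

#include <memory>

#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

// Build a GrapplerItem from a MetaGraphDef, exercising the relocated
// common_runtime knobs on ItemConfig.
std::unique_ptr<tensorflow::grappler::GrapplerItem> BuildItem(
    const tensorflow::MetaGraphDef& meta_graph) {
  tensorflow::grappler::ItemConfig cfg;
  cfg.placeholder_unknown_output_shape_dim = 64;  // pin unknown (-1) dims to 64
  cfg.apply_optimizations = true;                 // common_runtime L1 passes
  cfg.inline_functions = true;                    // inline user-defined functions
  return tensorflow::grappler::GrapplerItemFromMetaGraphDef("my_item",
                                                            meta_graph, cfg);
}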