Add a way to specify the optimization order; refactor and add constant folding to meta optimizer.

Change: 152193646
Authored by Yao Zhang on 2017-04-04 15:13:44 -08:00; committed by TensorFlower Gardener
parent 83afdc92d9
commit 8f74d595ef
5 changed files with 59 additions and 12 deletions

tensorflow/core/grappler/optimizers/BUILD

@@ -17,6 +17,7 @@ filegroup(
     srcs = glob(
         [
             "*_optimizer.*",
+            "constant_folding.*",
             "model_pruner.*",
             "graph_rewriter.*",
         ],
@@ -175,6 +176,7 @@ cc_library(
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":constant_folding",
         ":graph_optimizer",
         ":layout_optimizer",
        ":model_pruner",

tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -72,8 +72,7 @@ class DeviceSimple : public DeviceBase {
                              Tensor* tensor) override {
     Tensor parsed(tensor_proto.dtype());
     if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
-      return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                     tensor_proto.DebugString());
+      return errors::InvalidArgument("Cannot parse tensor from tensor_proto.");
     }
     *tensor = parsed;
     return Status::OK();

tensorflow/core/grappler/optimizers/meta_optimizer.cc

@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
 #include "tensorflow/core/lib/core/status.h"
 
@@ -21,25 +23,64 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
+    const string& optimizer) {
+  VLOG(1) << "Adding graph optimization pass: " << optimizer;
+  std::unique_ptr<GraphOptimizer> graph_optimizer;
+  if (optimizer == "pruning") {
+    graph_optimizer.reset(new ModelPruner());
+  }
+  if (optimizer == "constfold") {
+    graph_optimizer.reset(new ConstantFolding());
+  }
+  if (optimizer == "layout") {
+    graph_optimizer.reset(new LayoutOptimizer());
+  }
+  return graph_optimizer;
+}
+
 Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
-  bool already_optimized = false;
+  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
+  if (cfg_.optimizers().empty()) {
     if (!cfg_.disable_model_pruning()) {
-      already_optimized = true;
-      ModelPruner pruner;
-      TF_RETURN_IF_ERROR(pruner.Optimize(nullptr, item, optimized_graph));
+      optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
     }
+    if (cfg_.constant_folding()) {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
+    }
     if (cfg_.optimize_tensor_layout()) {
-      LayoutOptimizer layout_optimizer;
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
     }
+  } else {
+    std::set<string> available_optimizers = {"pruning", "constfold", "layout"};
+    for (const auto& optimizer : cfg_.optimizers()) {
+      if (available_optimizers.find(optimizer) != available_optimizers.end()) {
+        optimizers.push_back(NewOptimizer(optimizer));
+      }
+    }
+  }
+
+  if (optimizers.empty()) {
+    *optimized_graph = item.graph;
+    return Status::OK();
+  }
+
+  bool already_optimized = false;
+  for (const auto& optimizer : optimizers) {
     if (!already_optimized) {
-      return layout_optimizer.Optimize(nullptr, item, optimized_graph);
+      TF_RETURN_IF_ERROR(optimizer->Optimize(nullptr, item, optimized_graph));
+      already_optimized = true;
     } else {
       GrapplerItem optimized_item = item;
       optimized_item.graph = *optimized_graph;
-      return layout_optimizer.Optimize(nullptr, optimized_item,
-                                       optimized_graph);
+      TF_RETURN_IF_ERROR(
+          optimizer->Optimize(nullptr, optimized_item, optimized_graph));
     }
   }
   return Status::OK();
 }
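The rewritten Optimize() first materializes the pass list, either from the individual flags or from the new optimizers field, and then chains the passes: the first pass consumes item.graph directly, and every later pass runs on a copy of the item whose graph has been replaced with the previous pass's output. Below is a minimal, hypothetical caller-side sketch of driving this with an explicit pass order; the MetaOptimizer(RewriterConfig) constructor and the RunWithExplicitOrder helper are assumptions for illustration and are not part of this diff.

// Hypothetical usage sketch, not part of the commit.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::Status RunWithExplicitOrder(
    const tensorflow::grappler::GrapplerItem& item,
    tensorflow::GraphDef* optimized_graph) {
  tensorflow::RewriterConfig cfg;
  // A non-empty optimizers list fixes which passes run and in what order.
  cfg.add_optimizers("constfold");
  cfg.add_optimizers("pruning");
  cfg.add_optimizers("layout");

  // Assumed: MetaOptimizer is constructed from the RewriterConfig it keeps
  // in cfg_ (the constructor itself is not shown in this diff).
  tensorflow::grappler::MetaOptimizer meta(cfg);
  // A null Cluster mirrors the internal calls in the change above.
  return meta.Optimize(nullptr, item, optimized_graph);
}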

tensorflow/core/grappler/optimizers/meta_optimizer.h

@@ -39,6 +39,7 @@ class MetaOptimizer : public GraphOptimizer {
                 const GraphDef& optimized_graph, double result) override;
 
  private:
+  std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
   RewriterConfig cfg_;
 };

tensorflow/core/protobuf/rewriter_config.proto

@@ -9,4 +9,8 @@ option java_package = "org.tensorflow.framework";
 message RewriterConfig {
   bool optimize_tensor_layout = 1;
   bool disable_model_pruning = 2;
+  bool constant_folding = 3;
+  // If non-empty, will use this as an alternative way to specify a list of
+  // optimizations to turn on and the order of the optimizations.
+  repeated string optimizers = 100;
 }
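Taken together with the Optimize() change above, RewriterConfig now supports two styles of configuration. A short illustrative sketch follows; the FlagBased and ListBased helper names are hypothetical, and the sketch relies only on the generated proto setters for the fields declared above.

// Illustrative only; FlagBased and ListBased are hypothetical helpers.
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

// Flag style: each pass is toggled individually and the meta optimizer
// applies the enabled passes in its built-in order
// (pruning -> constant folding -> layout).
tensorflow::RewriterConfig FlagBased() {
  tensorflow::RewriterConfig cfg;
  // Pruning is on by default (disable_model_pruning defaults to false);
  // constant folding must be requested explicitly.
  cfg.set_constant_folding(true);
  return cfg;
}

// List style: a non-empty optimizers field overrides the flags and runs
// the named passes in exactly the listed order; unrecognized names are
// silently skipped.
tensorflow::RewriterConfig ListBased() {
  tensorflow::RewriterConfig cfg;
  cfg.add_optimizers("pruning");
  cfg.add_optimizers("constfold");
  return cfg;
}

Given the Optimize() logic above, both configs select the same two passes here: pruning followed by constant folding.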