Add a way to specify the optimization order; refactor and add constant folding to meta optimizer.
Change: 152193646
This commit is contained in:
parent
83afdc92d9
commit
8f74d595ef
@ -17,6 +17,7 @@ filegroup(
|
||||
srcs = glob(
|
||||
[
|
||||
"*_optimizer.*",
|
||||
"constant_folding.*",
|
||||
"model_pruner.*",
|
||||
"graph_rewriter.*",
|
||||
],
|
||||
@ -175,6 +176,7 @@ cc_library(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":constant_folding",
|
||||
":graph_optimizer",
|
||||
":layout_optimizer",
|
||||
":model_pruner",
|
||||
|
@ -72,8 +72,7 @@ class DeviceSimple : public DeviceBase {
|
||||
Tensor* tensor) override {
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
tensor_proto.DebugString());
|
||||
return errors::InvalidArgument("Cannot parse tensor from tensor_proto.");
|
||||
}
|
||||
*tensor = parsed;
|
||||
return Status::OK();
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -21,25 +23,64 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
|
||||
const string& optimizer) {
|
||||
VLOG(1) << "Adding graph optimization pass: " << optimizer;
|
||||
std::unique_ptr<GraphOptimizer> graph_optimizer;
|
||||
if (optimizer == "pruning") {
|
||||
graph_optimizer.reset(new ModelPruner());
|
||||
}
|
||||
if (optimizer == "constfold") {
|
||||
graph_optimizer.reset(new ConstantFolding());
|
||||
}
|
||||
if (optimizer == "layout") {
|
||||
graph_optimizer.reset(new LayoutOptimizer());
|
||||
}
|
||||
return graph_optimizer;
|
||||
}
|
||||
|
||||
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
bool already_optimized = false;
|
||||
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
|
||||
if (cfg_.optimizers().empty()) {
|
||||
if (!cfg_.disable_model_pruning()) {
|
||||
already_optimized = true;
|
||||
ModelPruner pruner;
|
||||
TF_RETURN_IF_ERROR(pruner.Optimize(nullptr, item, optimized_graph));
|
||||
optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
|
||||
}
|
||||
if (cfg_.constant_folding()) {
|
||||
optimizers.push_back(
|
||||
std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
|
||||
}
|
||||
if (cfg_.optimize_tensor_layout()) {
|
||||
LayoutOptimizer layout_optimizer;
|
||||
optimizers.push_back(
|
||||
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
|
||||
}
|
||||
} else {
|
||||
std::set<string> avaliable_optimizers = {"pruning", "constfold", "layout"};
|
||||
for (const auto& optimizer : cfg_.optimizers()) {
|
||||
if (avaliable_optimizers.find(optimizer) != avaliable_optimizers.end()) {
|
||||
optimizers.push_back(NewOptimizer(optimizer));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (optimizers.empty()) {
|
||||
*optimized_graph = item.graph;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool already_optimized = false;
|
||||
for (const auto& optimizer : optimizers) {
|
||||
if (!already_optimized) {
|
||||
return layout_optimizer.Optimize(nullptr, item, optimized_graph);
|
||||
TF_RETURN_IF_ERROR(optimizer->Optimize(nullptr, item, optimized_graph));
|
||||
already_optimized = true;
|
||||
} else {
|
||||
GrapplerItem optimized_item = item;
|
||||
optimized_item.graph = *optimized_graph;
|
||||
return layout_optimizer.Optimize(nullptr, optimized_item,
|
||||
optimized_graph);
|
||||
TF_RETURN_IF_ERROR(
|
||||
optimizer->Optimize(nullptr, optimized_item, optimized_graph));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -39,6 +39,7 @@ class MetaOptimizer : public GraphOptimizer {
|
||||
const GraphDef& optimized_graph, double result) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
|
||||
RewriterConfig cfg_;
|
||||
};
|
||||
|
||||
|
@ -9,4 +9,8 @@ option java_package = "org.tensorflow.framework";
|
||||
message RewriterConfig {
|
||||
bool optimize_tensor_layout = 1;
|
||||
bool disable_model_pruning = 2;
|
||||
bool constant_folding = 3;
|
||||
// If non-empty, will use this as an alternative way to specify a list of
|
||||
// optimizations to turn on and the order of the optimizations.
|
||||
repeated string optimizers = 100;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user