Add a way to specify the optimization order; refactor and add constant folding to meta optimizer.
Change: 152193646
This commit is contained in:
parent 83afdc92d9
commit 8f74d595ef
@@ -17,6 +17,7 @@ filegroup(
     srcs = glob(
         [
             "*_optimizer.*",
+            "constant_folding.*",
            "model_pruner.*",
            "graph_rewriter.*",
        ],
@@ -175,6 +176,7 @@ cc_library(
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":constant_folding",
         ":graph_optimizer",
         ":layout_optimizer",
         ":model_pruner",
@@ -72,8 +72,7 @@ class DeviceSimple : public DeviceBase {
                              Tensor* tensor) override {
     Tensor parsed(tensor_proto.dtype());
     if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
-      return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                     tensor_proto.DebugString());
+      return errors::InvalidArgument("Cannot parse tensor from tensor_proto.");
     }
     *tensor = parsed;
     return Status::OK();
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -21,25 +23,64 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
+    const string& optimizer) {
+  VLOG(1) << "Adding graph optimization pass: " << optimizer;
+  std::unique_ptr<GraphOptimizer> graph_optimizer;
+  if (optimizer == "pruning") {
+    graph_optimizer.reset(new ModelPruner());
+  }
+  if (optimizer == "constfold") {
+    graph_optimizer.reset(new ConstantFolding());
+  }
+  if (optimizer == "layout") {
+    graph_optimizer.reset(new LayoutOptimizer());
+  }
+  return graph_optimizer;
+}
+
 Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
-  bool already_optimized = false;
-  if (!cfg_.disable_model_pruning()) {
-    already_optimized = true;
-    ModelPruner pruner;
-    TF_RETURN_IF_ERROR(pruner.Optimize(nullptr, item, optimized_graph));
+  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
+  if (cfg_.optimizers().empty()) {
+    if (!cfg_.disable_model_pruning()) {
+      optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
+    }
+    if (cfg_.constant_folding()) {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
+    }
+    if (cfg_.optimize_tensor_layout()) {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
+    }
+  } else {
+    std::set<string> available_optimizers = {"pruning", "constfold", "layout"};
+    for (const auto& optimizer : cfg_.optimizers()) {
+      if (available_optimizers.find(optimizer) != available_optimizers.end()) {
+        optimizers.push_back(NewOptimizer(optimizer));
+      }
+    }
   }
-  if (cfg_.optimize_tensor_layout()) {
-    LayoutOptimizer layout_optimizer;
-    if (!already_optimized) {
-      return layout_optimizer.Optimize(nullptr, item, optimized_graph);
+
+  if (optimizers.empty()) {
+    *optimized_graph = item.graph;
+    return Status::OK();
+  }
+
+  bool already_optimized = false;
+  for (const auto& optimizer : optimizers) {
+    if (!already_optimized) {
+      TF_RETURN_IF_ERROR(optimizer->Optimize(nullptr, item, optimized_graph));
+      already_optimized = true;
     } else {
       GrapplerItem optimized_item = item;
       optimized_item.graph = *optimized_graph;
-      return layout_optimizer.Optimize(nullptr, optimized_item,
-                                       optimized_graph);
+      TF_RETURN_IF_ERROR(
+          optimizer->Optimize(nullptr, optimized_item, optimized_graph));
     }
   }
 
   return Status::OK();
 }
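For illustration only (not part of the diff): when RewriterConfig.optimizers is left empty, the boolean flags still select the passes, and the code above applies them in the fixed order pruning, constant folding, layout. A minimal sketch of such a flag-based configuration, assuming the generated proto header lives at tensorflow/core/protobuf/rewriter_config.pb.h:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // assumed header path

// Sketch: legacy flag-based setup; the meta optimizer maps it onto the same
// ordered optimizer list (pruning -> constant folding -> layout).
tensorflow::RewriterConfig MakeFlagBasedConfig() {
  tensorflow::RewriterConfig cfg;
  cfg.set_disable_model_pruning(false);  // keep the pruning pass enabled
  cfg.set_constant_folding(true);        // new field introduced by this change
  cfg.set_optimize_tensor_layout(true);
  return cfg;  // cfg.optimizers() stays empty, so the flag branch is taken
}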
@@ -39,6 +39,7 @@ class MetaOptimizer : public GraphOptimizer {
                 const GraphDef& optimized_graph, double result) override;
 
  private:
+  std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
   RewriterConfig cfg_;
 };
 
@@ -9,4 +9,8 @@ option java_package = "org.tensorflow.framework";
 message RewriterConfig {
   bool optimize_tensor_layout = 1;
   bool disable_model_pruning = 2;
+  bool constant_folding = 3;
+  // If non-empty, will use this as an alternative way to specify a list of
+  // optimizations to turn on and the order of the optimizations.
+  repeated string optimizers = 100;
 }
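For illustration only (not part of the diff): the new repeated optimizers field gives an explicit, ordered pass list; names come from the set the meta optimizer recognizes ("pruning", "constfold", "layout"), and unrecognized names are skipped. A hedged sketch of a caller requesting a custom order, again assuming the generated header path:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // assumed header path

// Sketch: explicit pass ordering via RewriterConfig.optimizers.
tensorflow::RewriterConfig MakeOrderedConfig() {
  tensorflow::RewriterConfig cfg;
  // Run constant folding first, then pruning, then the layout optimizer.
  cfg.add_optimizers("constfold");
  cfg.add_optimizers("pruning");
  cfg.add_optimizers("layout");
  return cfg;
}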