diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 4e41c2bb129..bd96e2b33cc 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -17,6 +17,7 @@ filegroup(
     srcs = glob(
         [
            "*_optimizer.*",
+            "constant_folding.*",
             "model_pruner.*",
             "graph_rewriter.*",
         ],
@@ -175,6 +176,7 @@ cc_library(
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":constant_folding",
         ":graph_optimizer",
         ":layout_optimizer",
         ":model_pruner",
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 7cddedef2e2..8f79c55810b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -72,8 +72,7 @@ class DeviceSimple : public DeviceBase {
                              Tensor* tensor) override {
     Tensor parsed(tensor_proto.dtype());
     if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
-      return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                     tensor_proto.DebugString());
+      return errors::InvalidArgument("Cannot parse tensor from tensor_proto.");
     }
     *tensor = parsed;
     return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 44a1f5bab92..d82d5a469de 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -21,25 +23,64 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
+    const string& optimizer) {
+  VLOG(1) << "Adding graph optimization pass: " << optimizer;
+  std::unique_ptr<GraphOptimizer> graph_optimizer;
+  if (optimizer == "pruning") {
+    graph_optimizer.reset(new ModelPruner());
+  }
+  if (optimizer == "constfold") {
+    graph_optimizer.reset(new ConstantFolding());
+  }
+  if (optimizer == "layout") {
+    graph_optimizer.reset(new LayoutOptimizer());
+  }
+  return graph_optimizer;
+}
+
 Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
-  bool already_optimized = false;
-  if (!cfg_.disable_model_pruning()) {
-    already_optimized = true;
-    ModelPruner pruner;
-    TF_RETURN_IF_ERROR(pruner.Optimize(nullptr, item, optimized_graph));
+  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
+  if (cfg_.optimizers().empty()) {
+    if (!cfg_.disable_model_pruning()) {
+      optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
+    }
+    if (cfg_.constant_folding()) {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
+    }
+    if (cfg_.optimize_tensor_layout()) {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
+    }
+  } else {
+    std::set<string> available_optimizers = {"pruning", "constfold", "layout"};
+    for (const auto& optimizer : cfg_.optimizers()) {
+      if (available_optimizers.find(optimizer) != available_optimizers.end()) {
+        optimizers.push_back(NewOptimizer(optimizer));
+      }
+    }
   }
-  if (cfg_.optimize_tensor_layout()) {
-    LayoutOptimizer layout_optimizer;
+
+  if (optimizers.empty()) {
+    *optimized_graph = item.graph;
+    return Status::OK();
+  }
+
+  bool already_optimized = false;
+  for (const auto& optimizer : optimizers) {
     if (!already_optimized) {
-      return layout_optimizer.Optimize(nullptr, item, optimized_graph);
+      TF_RETURN_IF_ERROR(optimizer->Optimize(nullptr, item, optimized_graph));
+      already_optimized = true;
     } else {
       GrapplerItem optimized_item = item;
       optimized_item.graph = *optimized_graph;
-      return layout_optimizer.Optimize(nullptr, optimized_item,
-                                       optimized_graph);
+      TF_RETURN_IF_ERROR(
+          optimizer->Optimize(nullptr, optimized_item, optimized_graph));
     }
   }
+  return Status::OK();
 }
 
 
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index d7ff03f5907..9def2cd711f 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -39,6 +39,7 @@ class MetaOptimizer : public GraphOptimizer {
                 const GraphDef& optimized_graph, double result) override;
 
  private:
+  std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
   RewriterConfig cfg_;
 };
 
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index aef69461d88..6e9eff62254 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -9,4 +9,8 @@ option java_package = "org.tensorflow.framework";
 message RewriterConfig {
   bool optimize_tensor_layout = 1;
   bool disable_model_pruning = 2;
+  bool constant_folding = 3;
+  // If non-empty, will use this as an alternative way to specify a list of
+  // optimizations to turn on and the order of the optimizations.
+  repeated string optimizers = 100;
 }