From 7cdcd0cca2a97c45c634f45b0ace0771ce5a5498 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 2 Jun 2017 14:23:14 -0700 Subject: [PATCH] Filter more op types that don't benefit from constant folding. PiperOrigin-RevId: 157875168 --- .../grappler/optimizers/constant_folding.cc | 35 +++++++++++-------- .../grappler/optimizers/constant_folding.h | 12 ++----- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index ea5bfe164b3..c2df76e4315 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -101,6 +101,11 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) { } } // namespace +ConstantFolding::ConstantFolding() { + ops_to_preserve_ = + std::regex("Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader"); +} + Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { GraphProperties properties(item); TF_RETURN_IF_ERROR(properties.InferStatically()); @@ -184,28 +189,19 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { } bool ConstantFolding::IsFoldable(const NodeDef& node) const { - DeviceTypeVector device_types; - auto status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node, - &device_types); - if (!status.ok()) { - return false; - } - // Only fold ops with a CPU implementation available. - if (device_types[0] != DeviceType(DEVICE_CPU)) { - return false; - } - + // Skips nodes that must be preserved, and op_types that don't benefit from + // folding if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { return false; } - - if (ops_to_preserve_.find(node.op()) != ops_to_preserve_.end()) { + std::cmatch match; + if (std::regex_match(node.op().c_str(), match, ops_to_preserve_)) { return false; } // Don't fold stateful ops such as TruncatedNormal. const OpDef* op_def = nullptr; - status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); if (!status.ok()) { return false; } @@ -217,6 +213,17 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } + DeviceTypeVector device_types; + status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node, + &device_types); + if (!status.ok()) { + return false; + } + // Only fold ops with a CPU implementation available. + if (device_types[0] != DeviceType(DEVICE_CPU)) { + return false; + } + // Folding not applicable to ops with no inputs. if (node.input().empty()) { return false; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 9689e97a123..cb9729ef1ee 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ #define TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ +#include #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" @@ -29,7 +30,7 @@ const char kConstantFoldingConst[] = "ConstantFolding"; // Contant folding optimization for a graph. class ConstantFolding : public GraphOptimizer { public: - ConstantFolding() {} + ConstantFolding(); ~ConstantFolding() override {} @@ -66,14 +67,7 @@ class ConstantFolding : public GraphOptimizer { GraphDef graph_; std::unique_ptr node_map_; std::set nodes_to_preserve_; - std::set ops_to_preserve_ = {"Save", - "SaveV2", - "SaveSlices", - "Restore", - "RestoreV2", - "RestoreSlice", - "PlaceholderWithDefault", - "Const"}; + std::regex ops_to_preserve_; }; } // end namespace grappler