Filter more op types that don't benefit from constant folding.

PiperOrigin-RevId: 157875168
This commit is contained in:
Benoit Steiner 2017-06-02 14:23:14 -07:00 committed by TensorFlower Gardener
parent 366990d92d
commit 7cdcd0cca2
2 changed files with 24 additions and 23 deletions

View File

@ -101,6 +101,11 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) {
}
} // namespace
ConstantFolding::ConstantFolding() {
ops_to_preserve_ =
std::regex("Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader");
}
Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically());
@ -184,28 +189,19 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
}
bool ConstantFolding::IsFoldable(const NodeDef& node) const {
DeviceTypeVector device_types;
auto status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node,
&device_types);
if (!status.ok()) {
return false;
}
// Only fold ops with a CPU implementation available.
if (device_types[0] != DeviceType(DEVICE_CPU)) {
return false;
}
// Skips nodes that must be preserved, and op_types that don't benefit from
// folding
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
return false;
}
if (ops_to_preserve_.find(node.op()) != ops_to_preserve_.end()) {
std::cmatch match;
if (std::regex_match(node.op().c_str(), match, ops_to_preserve_)) {
return false;
}
// Don't fold stateful ops such as TruncatedNormal.
const OpDef* op_def = nullptr;
status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
if (!status.ok()) {
return false;
}
@ -217,6 +213,17 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return false;
}
DeviceTypeVector device_types;
status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node,
&device_types);
if (!status.ok()) {
return false;
}
// Only fold ops with a CPU implementation available.
if (device_types[0] != DeviceType(DEVICE_CPU)) {
return false;
}
// Folding not applicable to ops with no inputs.
if (node.input().empty()) {
return false;

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
#include <regex>
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
@ -29,7 +30,7 @@ const char kConstantFoldingConst[] = "ConstantFolding";
// Contant folding optimization for a graph.
class ConstantFolding : public GraphOptimizer {
public:
ConstantFolding() {}
ConstantFolding();
~ConstantFolding() override {}
@ -66,14 +67,7 @@ class ConstantFolding : public GraphOptimizer {
GraphDef graph_;
std::unique_ptr<NodeMap> node_map_;
std::set<string> nodes_to_preserve_;
std::set<string> ops_to_preserve_ = {"Save",
"SaveV2",
"SaveSlices",
"Restore",
"RestoreV2",
"RestoreSlice",
"PlaceholderWithDefault",
"Const"};
std::regex ops_to_preserve_;
};
} // end namespace grappler