Filter more op types that don't benefit from constant folding.
PiperOrigin-RevId: 157875168
This commit is contained in:
parent
366990d92d
commit
7cdcd0cca2
@ -101,6 +101,11 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ConstantFolding::ConstantFolding() {
  // Ops matching this regex are never constant-folded:
  //  - Placeholder*: values are fed at runtime, so folding is meaningless.
  //  - Const: already a constant; folding would be a no-op.
  //  - *Save* / *Restore*: checkpointing ops with filesystem side effects.
  //  - *Reader: I/O ops whose output depends on external state.
  // NOTE(review): matched with std::regex_match elsewhere, so the pattern
  // must cover the full op name — confirm new entries anchor accordingly.
  ops_to_preserve_ =
      std::regex("Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader");
}
|
||||
|
||||
Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
|
||||
GraphProperties properties(item);
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically());
|
||||
@ -184,28 +189,19 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
|
||||
}
|
||||
|
||||
bool ConstantFolding::IsFoldable(const NodeDef& node) const {
|
||||
DeviceTypeVector device_types;
|
||||
auto status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node,
|
||||
&device_types);
|
||||
if (!status.ok()) {
|
||||
return false;
|
||||
}
|
||||
// Only fold ops with a CPU implementation available.
|
||||
if (device_types[0] != DeviceType(DEVICE_CPU)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Skips nodes that must be preserved, and op_types that don't benefit from
|
||||
// folding
|
||||
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ops_to_preserve_.find(node.op()) != ops_to_preserve_.end()) {
|
||||
std::cmatch match;
|
||||
if (std::regex_match(node.op().c_str(), match, ops_to_preserve_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Don't fold stateful ops such as TruncatedNormal.
|
||||
const OpDef* op_def = nullptr;
|
||||
status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
|
||||
Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
|
||||
if (!status.ok()) {
|
||||
return false;
|
||||
}
|
||||
@ -217,6 +213,17 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
DeviceTypeVector device_types;
|
||||
status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node,
|
||||
&device_types);
|
||||
if (!status.ok()) {
|
||||
return false;
|
||||
}
|
||||
// Only fold ops with a CPU implementation available.
|
||||
if (device_types[0] != DeviceType(DEVICE_CPU)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Folding not applicable to ops with no inputs.
|
||||
if (node.input().empty()) {
|
||||
return false;
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
|
||||
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
|
||||
|
||||
#include <regex>
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
@ -29,7 +30,7 @@ const char kConstantFoldingConst[] = "ConstantFolding";
|
||||
// Constant folding optimization for a graph.
|
||||
class ConstantFolding : public GraphOptimizer {
|
||||
public:
|
||||
ConstantFolding() {}
|
||||
ConstantFolding();
|
||||
|
||||
~ConstantFolding() override {}
|
||||
|
||||
@ -66,14 +67,7 @@ class ConstantFolding : public GraphOptimizer {
|
||||
GraphDef graph_;
|
||||
std::unique_ptr<NodeMap> node_map_;
|
||||
std::set<string> nodes_to_preserve_;
|
||||
std::set<string> ops_to_preserve_ = {"Save",
|
||||
"SaveV2",
|
||||
"SaveSlices",
|
||||
"Restore",
|
||||
"RestoreV2",
|
||||
"RestoreSlice",
|
||||
"PlaceholderWithDefault",
|
||||
"Const"};
|
||||
std::regex ops_to_preserve_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
Loading…
x
Reference in New Issue
Block a user