[Grappler] Add IsAnyMatMul and restrict IsMatMul to just 'MatMul' op type

PiperOrigin-RevId: 243115109
This commit is contained in:
Eugene Zhulenev 2019-04-11 12:04:37 -07:00 committed by TensorFlower Gardener
parent 1b206c1a31
commit ceeea9c916
4 changed files with 10 additions and 7 deletions

View File

@ -47,6 +47,12 @@ bool IsAnyDiv(const NodeDef& node) {
node.op() == "FloorDiv" || node.op() == "TruncateDiv";
}
bool IsAnyMatMul(const NodeDef& node) {
const auto& op = node.op();
return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
IsQuantizedMatMul(node);
}
bool IsAnyMax(const NodeDef& node) {
const auto& op = node.op();
return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
@ -301,11 +307,7 @@ bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
bool IsMatMul(const NodeDef& node) {
const auto& op = node.op();
return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
IsQuantizedMatMul(node);
}
bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }
bool IsMax(const NodeDef& node) { return node.op() == "Max"; }

View File

@ -28,6 +28,7 @@ bool IsAll(const NodeDef& node);
bool IsAngle(const NodeDef& node);
bool IsAny(const NodeDef& node);
bool IsAnyDiv(const NodeDef& node);
bool IsAnyMatMul(const NodeDef& node);
bool IsAnyMax(const NodeDef& node);
bool IsAnyMaxPool(const NodeDef& node);
bool IsAnyMin(const NodeDef& node);

View File

@ -2178,7 +2178,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
~FoldTransposeIntoMatMul() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsMatMul(*node);
return IsAnyMatMul(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {

View File

@ -2639,7 +2639,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
GraphDef* optimized_graph, NodeDef* node, bool* success) {
*success = false;
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsMatMul(*node);
const bool is_matmul = IsAnyMatMul(*node);
const bool is_quantized_matmul = IsQuantizedMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
const bool is_sub = IsSub(*node);