[Grappler] Add IsAnyMatMul and restrict IsMatMul to just 'MatMul' op type

PiperOrigin-RevId: 243115109
This commit is contained in:
Eugene Zhulenev 2019-04-11 12:04:37 -07:00 committed by TensorFlower Gardener
parent 1b206c1a31
commit ceeea9c916
4 changed files with 10 additions and 7 deletions

View File

@ -47,6 +47,12 @@ bool IsAnyDiv(const NodeDef& node) {
node.op() == "FloorDiv" || node.op() == "TruncateDiv"; node.op() == "FloorDiv" || node.op() == "TruncateDiv";
} }
bool IsAnyMatMul(const NodeDef& node) {
const auto& op = node.op();
return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
IsQuantizedMatMul(node);
}
bool IsAnyMax(const NodeDef& node) { bool IsAnyMax(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax"; return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
@ -301,11 +307,7 @@ bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; } bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
bool IsMatMul(const NodeDef& node) { bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }
const auto& op = node.op();
return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
IsQuantizedMatMul(node);
}
bool IsMax(const NodeDef& node) { return node.op() == "Max"; } bool IsMax(const NodeDef& node) { return node.op() == "Max"; }

View File

@ -28,6 +28,7 @@ bool IsAll(const NodeDef& node);
bool IsAngle(const NodeDef& node); bool IsAngle(const NodeDef& node);
bool IsAny(const NodeDef& node); bool IsAny(const NodeDef& node);
bool IsAnyDiv(const NodeDef& node); bool IsAnyDiv(const NodeDef& node);
bool IsAnyMatMul(const NodeDef& node);
bool IsAnyMax(const NodeDef& node); bool IsAnyMax(const NodeDef& node);
bool IsAnyMaxPool(const NodeDef& node); bool IsAnyMaxPool(const NodeDef& node);
bool IsAnyMin(const NodeDef& node); bool IsAnyMin(const NodeDef& node);

View File

@ -2178,7 +2178,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
~FoldTransposeIntoMatMul() override = default; ~FoldTransposeIntoMatMul() override = default;
bool IsSupported(const NodeDef* node) const override { bool IsSupported(const NodeDef* node) const override {
return IsMatMul(*node); return IsAnyMatMul(*node);
} }
Status TrySimplify(NodeDef* node, string* simplified_node_name) override { Status TrySimplify(NodeDef* node, string* simplified_node_name) override {

View File

@ -2639,7 +2639,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
GraphDef* optimized_graph, NodeDef* node, bool* success) { GraphDef* optimized_graph, NodeDef* node, bool* success) {
*success = false; *success = false;
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node); const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsMatMul(*node); const bool is_matmul = IsAnyMatMul(*node);
const bool is_quantized_matmul = IsQuantizedMatMul(*node); const bool is_quantized_matmul = IsQuantizedMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
const bool is_sub = IsSub(*node); const bool is_sub = IsSub(*node);