From ceeea9c91647e21cb846599c0d326f22de98c534 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 11 Apr 2019 12:04:37 -0700 Subject: [PATCH] [Grappler] Add IsAnyMatMul and restrict IsMatMul to just 'MatMul' op type PiperOrigin-RevId: 243115109 --- tensorflow/core/grappler/op_types.cc | 12 +++++++----- tensorflow/core/grappler/op_types.h | 1 + .../core/grappler/optimizers/arithmetic_optimizer.cc | 2 +- .../core/grappler/optimizers/constant_folding.cc | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 12542947b01..bcac0e79a24 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -47,6 +47,12 @@ bool IsAnyDiv(const NodeDef& node) { node.op() == "FloorDiv" || node.op() == "TruncateDiv"; } +bool IsAnyMatMul(const NodeDef& node) { + const auto& op = node.op(); + return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" || + IsQuantizedMatMul(node); +} + bool IsAnyMax(const NodeDef& node) { const auto& op = node.op(); return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax"; @@ -301,11 +307,7 @@ bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; } bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; } -bool IsMatMul(const NodeDef& node) { - const auto& op = node.op(); - return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" || - IsQuantizedMatMul(node); -} +bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; } bool IsMax(const NodeDef& node) { return node.op() == "Max"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 643ecd4e6eb..d0562c32e4c 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -28,6 +28,7 @@ bool IsAll(const NodeDef& node); bool IsAngle(const NodeDef& node); bool IsAny(const NodeDef& node); bool IsAnyDiv(const NodeDef& node); +bool IsAnyMatMul(const NodeDef& node); bool IsAnyMax(const NodeDef& node); bool IsAnyMaxPool(const NodeDef& node); bool IsAnyMin(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 7dc62fe54c1..afadb539fb8 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2178,7 +2178,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { ~FoldTransposeIntoMatMul() override = default; bool IsSupported(const NodeDef* node) const override { - return IsMatMul(*node); + return IsAnyMatMul(*node); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 12fba979a96..670bd04e821 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2639,7 +2639,7 @@ Status ConstantFolding::SimplifyArithmeticOperations( GraphDef* optimized_graph, NodeDef* node, bool* success) { *success = false; const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node); - const bool is_matmul = IsMatMul(*node); + const bool is_matmul = IsAnyMatMul(*node); const bool is_quantized_matmul = IsQuantizedMatMul(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node); const bool is_sub = IsSub(*node);