[Grappler] Add IsAnyMatMul and restrict IsMatMul to just 'MatMul' op type
PiperOrigin-RevId: 243115109
This commit is contained in:
parent
1b206c1a31
commit
ceeea9c916
@ -47,6 +47,12 @@ bool IsAnyDiv(const NodeDef& node) {
|
|||||||
node.op() == "FloorDiv" || node.op() == "TruncateDiv";
|
node.op() == "FloorDiv" || node.op() == "TruncateDiv";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsAnyMatMul(const NodeDef& node) {
|
||||||
|
const auto& op = node.op();
|
||||||
|
return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
|
||||||
|
IsQuantizedMatMul(node);
|
||||||
|
}
|
||||||
|
|
||||||
bool IsAnyMax(const NodeDef& node) {
|
bool IsAnyMax(const NodeDef& node) {
|
||||||
const auto& op = node.op();
|
const auto& op = node.op();
|
||||||
return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
|
return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
|
||||||
@ -301,11 +307,7 @@ bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
|
|||||||
|
|
||||||
bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
|
bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
|
||||||
|
|
||||||
bool IsMatMul(const NodeDef& node) {
|
bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }
|
||||||
const auto& op = node.op();
|
|
||||||
return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
|
|
||||||
IsQuantizedMatMul(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
|
bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ bool IsAll(const NodeDef& node);
|
|||||||
bool IsAngle(const NodeDef& node);
|
bool IsAngle(const NodeDef& node);
|
||||||
bool IsAny(const NodeDef& node);
|
bool IsAny(const NodeDef& node);
|
||||||
bool IsAnyDiv(const NodeDef& node);
|
bool IsAnyDiv(const NodeDef& node);
|
||||||
|
bool IsAnyMatMul(const NodeDef& node);
|
||||||
bool IsAnyMax(const NodeDef& node);
|
bool IsAnyMax(const NodeDef& node);
|
||||||
bool IsAnyMaxPool(const NodeDef& node);
|
bool IsAnyMaxPool(const NodeDef& node);
|
||||||
bool IsAnyMin(const NodeDef& node);
|
bool IsAnyMin(const NodeDef& node);
|
||||||
|
@ -2178,7 +2178,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
|
|||||||
~FoldTransposeIntoMatMul() override = default;
|
~FoldTransposeIntoMatMul() override = default;
|
||||||
|
|
||||||
bool IsSupported(const NodeDef* node) const override {
|
bool IsSupported(const NodeDef* node) const override {
|
||||||
return IsMatMul(*node);
|
return IsAnyMatMul(*node);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
|
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
|
||||||
|
@ -2639,7 +2639,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
|
|||||||
GraphDef* optimized_graph, NodeDef* node, bool* success) {
|
GraphDef* optimized_graph, NodeDef* node, bool* success) {
|
||||||
*success = false;
|
*success = false;
|
||||||
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
|
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
|
||||||
const bool is_matmul = IsMatMul(*node);
|
const bool is_matmul = IsAnyMatMul(*node);
|
||||||
const bool is_quantized_matmul = IsQuantizedMatMul(*node);
|
const bool is_quantized_matmul = IsQuantizedMatMul(*node);
|
||||||
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
|
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
|
||||||
const bool is_sub = IsSub(*node);
|
const bool is_sub = IsSub(*node);
|
||||||
|
Loading…
Reference in New Issue
Block a user