[mlir] Simplify DDR matching patterns with equal operands for operators.
This https://reviews.llvm.org/D89254 diff introduced implicit matching between same name arguments. Modify usages accordingly. PiperOrigin-RevId: 338090110 Change-Id: Iec3bc6d712b0d4fc652c788e9e7d55945b922e46
This commit is contained in:
parent
51c47283e7
commit
961803999a
@ -18,15 +18,13 @@ limitations under the License.
|
||||
include "mlir/Dialect/Shape/IR/ShapeOps.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
|
||||
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
|
||||
|
||||
// Canonicalization patterns.
|
||||
|
||||
def DynamicBroadcastToOwnShape_1 : Pat<
|
||||
(HLO_DynamicBroadcastInDimOp:$op $arg0,
|
||||
(Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
|
||||
(replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
|
||||
(HLO_DynamicBroadcastInDimOp:$op $x,
|
||||
(Shape_ToExtentTensorOp (Shape_ShapeOfOp $x)), $attr),
|
||||
(replaceWithValue $x)>;
|
||||
def DynamicBroadcastToOwnShape_2 : Pat<
|
||||
(HLO_DynamicBroadcastInDimOp:$op $arg0, (Shape_ShapeOfOp $arg1), $attr),
|
||||
(replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
|
||||
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
|
||||
(replaceWithValue $x)>;
|
||||
|
||||
|
@ -238,8 +238,6 @@ def eliminate_dq_q_pairs : Pat<
|
||||
[(NotFromQuantOpOrSameQuantType $in, $qt)]>;
|
||||
|
||||
|
||||
// Constraint that makes sure both operands are the same operands.
|
||||
def EqualOperands : Constraint<CPred<"$0 == $1">>;
|
||||
|
||||
|
||||
// Checks if the operand has rank == n
|
||||
@ -251,28 +249,26 @@ def MatchHardSwishPattern1 : Pat<
|
||||
(TFL_MulOp
|
||||
(TFL_MulOp
|
||||
$x, (TFL_AddOp
|
||||
$y,
|
||||
$x,
|
||||
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
|
||||
TFL_AF_Relu6),
|
||||
TFL_AF_None),
|
||||
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
|
||||
TFL_AF_None),
|
||||
(TFL_HardSwishOp $x),
|
||||
[(EqualOperands $x, $y)]>;
|
||||
(TFL_HardSwishOp $x)>;
|
||||
|
||||
def MatchHardSwishPattern2 : Pat<
|
||||
(TFL_MulOp
|
||||
$x,
|
||||
(TFL_MulOp
|
||||
(TFL_AddOp
|
||||
$y,
|
||||
$x,
|
||||
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
|
||||
TFL_AF_Relu6),
|
||||
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
|
||||
TFL_AF_None),
|
||||
TFL_AF_None),
|
||||
(TFL_HardSwishOp $x),
|
||||
[(EqualOperands $x, $y)]>;
|
||||
(TFL_HardSwishOp $x)>;
|
||||
|
||||
// Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to
|
||||
// incorrect placement in the quantization aware training.
|
||||
@ -281,14 +277,13 @@ def MatchHardSwishQuantized : Pat<
|
||||
(TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
|
||||
(TFL_MulOp
|
||||
$x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
|
||||
$y,
|
||||
$x,
|
||||
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
|
||||
TFL_AF_Relu6), $qattr2)),
|
||||
TFL_AF_None), $qattr1)),
|
||||
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
|
||||
TFL_AF_None),
|
||||
(TFL_HardSwishOp $x),
|
||||
[(EqualOperands $x, $y)]>;
|
||||
(TFL_HardSwishOp $x)>;
|
||||
|
||||
// Constraint that the attribute value is less than 'n'
|
||||
class ConstDoubleValueLessThan<string n> : Constraint<
|
||||
@ -309,47 +304,44 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
|
||||
// Mul->Rsqrt->Sum->Square Or
|
||||
// Div->sqrt->Sum->Square
|
||||
def L2NormalizePattern1#FirstOp#SecondOp : Pat<
|
||||
(FirstOp $operand1,
|
||||
(FirstOp $x,
|
||||
(SecondOp
|
||||
(TFL_SumOp
|
||||
(TFL_SquareOp:$sq_op $square_operand),
|
||||
(TFL_SquareOp:$sq_op $x),
|
||||
(ConstantOp I32ElementsAttr:$axis),
|
||||
$keep_dims)),
|
||||
TFL_AF_None),
|
||||
(TFL_L2NormalizationOp $operand1, TFL_AF_None),
|
||||
[(EqualOperands $operand1, $square_operand),
|
||||
(L2NormValidReduceIndex $sq_op, $axis)]>;
|
||||
(TFL_L2NormalizationOp $x, TFL_AF_None),
|
||||
[(L2NormValidReduceIndex $sq_op, $axis)]>;
|
||||
|
||||
// Below patterns for L2Normalize when there is an Add or Maximum
|
||||
// adding or clamping to a small constant scalar.
|
||||
def L2NormalizePattern2#FirstOp#SecondOp : Pat<
|
||||
(FirstOp $operand1,
|
||||
(FirstOp $x,
|
||||
(SecondOp
|
||||
(TFL_AddOp
|
||||
(TFL_SumOp
|
||||
(TFL_SquareOp:$sq_op $square_operand),
|
||||
(TFL_SquareOp:$sq_op $x),
|
||||
(ConstantOp I32ElementsAttr:$axis),
|
||||
$keep_dims),
|
||||
(ConstantOp $epsilon), TFL_AF_None)),
|
||||
TFL_AF_None),
|
||||
(TFL_L2NormalizationOp $operand1, TFL_AF_None),
|
||||
[(EqualOperands $operand1, $square_operand),
|
||||
(L2NormValidReduceIndex $sq_op, $axis),
|
||||
(TFL_L2NormalizationOp $x, TFL_AF_None),
|
||||
[(L2NormValidReduceIndex $sq_op, $axis),
|
||||
(ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
|
||||
|
||||
def L2NormalizePattern3#FirstOp#SecondOp : Pat<
|
||||
(FirstOp $operand1,
|
||||
(FirstOp $x,
|
||||
(SecondOp
|
||||
(TFL_MaximumOp
|
||||
(TFL_SumOp
|
||||
(TFL_SquareOp:$sq_op $square_operand),
|
||||
(TFL_SquareOp:$sq_op $x),
|
||||
(ConstantOp I32ElementsAttr:$axis),
|
||||
$keep_dims),
|
||||
(ConstantOp $epsilon))),
|
||||
TFL_AF_None),
|
||||
(TFL_L2NormalizationOp $operand1, TFL_AF_None),
|
||||
[(EqualOperands $operand1, $square_operand),
|
||||
(L2NormValidReduceIndex $sq_op, $axis),
|
||||
(TFL_L2NormalizationOp $x, TFL_AF_None),
|
||||
[(L2NormValidReduceIndex $sq_op, $axis),
|
||||
(ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
|
||||
|
||||
}
|
||||
@ -521,12 +513,11 @@ def MatchRelu1Pattern2 : Pat<
|
||||
|
||||
def MatchLeakyRelu : Pat<
|
||||
(TFL_MaximumOp
|
||||
(TFL_MulOp:$mul_out $input1,
|
||||
(TFL_MulOp:$mul_out $x,
|
||||
(ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
|
||||
$input2),
|
||||
(TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
|
||||
$x),
|
||||
(TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
|
||||
[(ConstDoubleValueLessThan<"1"> $alpha),
|
||||
(EqualOperands $input1, $input2),
|
||||
(HasOneUse $mul_out)]>;
|
||||
|
||||
def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input),
|
||||
@ -542,15 +533,14 @@ def PReluAlphaRankCheck : Constraint<
|
||||
// f(x) = Relu(x) + (-alpha * Relu(-x))
|
||||
def MatchPRelu : Pat<
|
||||
(TFL_AddOp
|
||||
(TFL_ReluOp:$relu_out $input1),
|
||||
(TFL_ReluOp:$relu_out $x),
|
||||
(TFL_MulOp:$mul_out
|
||||
(TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
|
||||
(TFL_ReluOp (TFL_NegOp:$input_neg_out $x)),
|
||||
$neg_alpha,
|
||||
TFL_AF_None),
|
||||
TFL_AF_None),
|
||||
(TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
|
||||
[(EqualOperands $input1, $input2),
|
||||
(PReluAlphaRankCheck $neg_alpha, $input1),
|
||||
(TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
|
||||
[(PReluAlphaRankCheck $neg_alpha, $x),
|
||||
(HasOneUse $relu_out),
|
||||
(HasOneUse $mul_out),
|
||||
(HasOneUse $input_neg_out)]>;
|
||||
|
@ -209,10 +209,8 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
|
||||
def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
|
||||
(TF_ReshapeOp $arg, $shape)>;
|
||||
|
||||
def IsSame : Constraint<CPred<"$0 == $1">>;
|
||||
def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)),
|
||||
(replaceWithValue $arg0),
|
||||
[(IsSame $arg0, $arg1)]>;
|
||||
def ReshapeToSelfShape : Pat<(TF_ReshapeOp $x, (TF_ShapeOp $x)),
|
||||
(replaceWithValue $x)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Select op patterns.
|
||||
|
@ -23,10 +23,6 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
// Checks if the value has only one user.
|
||||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||
|
||||
// Constraint that makes sure both operands are the same operands.
|
||||
// TODO(b/154826385): Reconsider once equal source pattern symbols are allowed.
|
||||
def EqualOperands : Constraint<CPred<"$0 == $1">>;
|
||||
|
||||
// Checks if the operand0's rank is one less than operand1's rank.
|
||||
def PReluAlphaRankCheck : Constraint<
|
||||
CPred<"$0.getType().cast<ShapedType>().getRank() == "
|
||||
@ -36,13 +32,12 @@ def PReluAlphaRankCheck : Constraint<
|
||||
// PReLU pattern from Keras:
|
||||
// f(x) = Relu(x) + (-alpha * Relu(-x))
|
||||
def : Pat<(TF_AddV2Op
|
||||
(TF_ReluOp:$relu_out $input1),
|
||||
(TF_ReluOp:$relu_out $x),
|
||||
(TF_MulOp:$mul_out
|
||||
(TF_ReluOp (TF_NegOp:$input_neg_out $input2)),
|
||||
(TF_ReluOp (TF_NegOp:$input_neg_out $x)),
|
||||
$neg_alpha)),
|
||||
(TFJS_PReluOp $input1, (TF_NegOp $neg_alpha)),
|
||||
[(EqualOperands $input1, $input2),
|
||||
(PReluAlphaRankCheck $neg_alpha, $input1),
|
||||
(TFJS_PReluOp $x, (TF_NegOp $neg_alpha)),
|
||||
[(PReluAlphaRankCheck $neg_alpha, $x),
|
||||
(HasOneUse $relu_out),
|
||||
(HasOneUse $mul_out),
|
||||
(HasOneUse $input_neg_out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user