[mlir] Simplify DRR matching patterns with equal operands for operators.

The LLVM change https://reviews.llvm.org/D89254 introduced implicit matching between same-named arguments in DRR source patterns, making explicit operand-equality constraints redundant. Update usages accordingly.
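
For context, a minimal sketch of the idiom change, using a hypothetical one-result op MyOp (the op and pattern names below are illustrative, not taken from the files in this commit):

// Before D89254: equal operands had to bind to distinct symbols,
// tied together by an explicit predicate constraint.
def EqualOperands : Constraint<CPred<"$0 == $1">>;
def FoldSelfMyOp_Before : Pat<
  (MyOp $arg0, $arg1),
  (replaceWithValue $arg0),
  [(EqualOperands $arg0, $arg1)]>;

// After D89254: rebinding the same symbol $x in the source pattern
// implies the equality, so the helper constraint can be deleted.
def FoldSelfMyOp_After : Pat<
  (MyOp $x, $x),
  (replaceWithValue $x)>;

The generated matcher compares the bound Values directly, which is exactly what the hand-written CPred<"$0 == $1"> constraints removed below were doing.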

PiperOrigin-RevId: 338090110
Change-Id: Iec3bc6d712b0d4fc652c788e9e7d55945b922e46
Authored by Roman Dzhabarov on 2020-10-20 10:46:19 -07:00; committed by TensorFlower Gardener.
parent 51c47283e7
commit 961803999a
4 changed files with 36 additions and 55 deletions

@@ -18,15 +18,13 @@ limitations under the License.
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
 include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
 
-def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
-
 // Canonicalization patterns.
 
 def DynamicBroadcastToOwnShape_1 : Pat<
-  (HLO_DynamicBroadcastInDimOp:$op $arg0,
-      (Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
-  (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
+  (HLO_DynamicBroadcastInDimOp:$op $x,
+      (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x)), $attr),
+  (replaceWithValue $x)>;
 def DynamicBroadcastToOwnShape_2 : Pat<
-  (HLO_DynamicBroadcastInDimOp:$op $arg0, (Shape_ShapeOfOp $arg1), $attr),
-  (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
+  (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
+  (replaceWithValue $x)>;

@@ -238,8 +238,6 @@ def eliminate_dq_q_pairs : Pat<
   [(NotFromQuantOpOrSameQuantType $in, $qt)]>;
 
-// Constraint that makes sure both operands are the same operands.
-def EqualOperands : Constraint<CPred<"$0 == $1">>;
 
 // Checks if the operand has rank == n
@@ -251,28 +249,26 @@ def MatchHardSwishPattern1 : Pat<
   (TFL_MulOp
     (TFL_MulOp
      $x, (TFL_AddOp
-          $y,
+          $x,
           (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
           TFL_AF_Relu6),
      TFL_AF_None),
     (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
-  (TFL_HardSwishOp $x),
-  [(EqualOperands $x, $y)]>;
+  (TFL_HardSwishOp $x)>;
 
 def MatchHardSwishPattern2 : Pat<
   (TFL_MulOp
    $x,
    (TFL_MulOp
     (TFL_AddOp
-     $y,
+     $x,
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
    TFL_AF_None),
-  (TFL_HardSwishOp $x),
-  [(EqualOperands $x, $y)]>;
+  (TFL_HardSwishOp $x)>;
 
 // Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to
 // incorrect placement in the quantization aware training.
@@ -281,14 +277,13 @@ def MatchHardSwishQuantized : Pat<
   (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
     (TFL_MulOp
      $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
-          $y,
+          $x,
           (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
           TFL_AF_Relu6), $qattr2)),
      TFL_AF_None), $qattr1)),
     (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
-  (TFL_HardSwishOp $x),
-  [(EqualOperands $x, $y)]>;
+  (TFL_HardSwishOp $x)>;
 
 // Constraint that the attribute value is less than 'n'
 class ConstDoubleValueLessThan<string n> : Constraint<
@@ -309,47 +304,44 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
   // Mul->Rsqrt->Sum->Square Or
   // Div->sqrt->Sum->Square
   def L2NormalizePattern1#FirstOp#SecondOp : Pat<
-    (FirstOp $operand1,
+    (FirstOp $x,
      (SecondOp
       (TFL_SumOp
-       (TFL_SquareOp:$sq_op $square_operand),
+       (TFL_SquareOp:$sq_op $x),
        (ConstantOp I32ElementsAttr:$axis),
        $keep_dims)),
      TFL_AF_None),
-    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
-    [(EqualOperands $operand1, $square_operand),
-     (L2NormValidReduceIndex $sq_op, $axis)]>;
+    (TFL_L2NormalizationOp $x, TFL_AF_None),
+    [(L2NormValidReduceIndex $sq_op, $axis)]>;
 
   // Below patterns for L2Normalize when there is an Add or Maximum
   // adding or clamping to a small constant scalar.
   def L2NormalizePattern2#FirstOp#SecondOp : Pat<
-    (FirstOp $operand1,
+    (FirstOp $x,
      (SecondOp
       (TFL_AddOp
        (TFL_SumOp
-        (TFL_SquareOp:$sq_op $square_operand),
+        (TFL_SquareOp:$sq_op $x),
         (ConstantOp I32ElementsAttr:$axis),
         $keep_dims),
       (ConstantOp $epsilon), TFL_AF_None)),
      TFL_AF_None),
-    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
-    [(EqualOperands $operand1, $square_operand),
-     (L2NormValidReduceIndex $sq_op, $axis),
+    (TFL_L2NormalizationOp $x, TFL_AF_None),
+    [(L2NormValidReduceIndex $sq_op, $axis),
      (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
 
   def L2NormalizePattern3#FirstOp#SecondOp : Pat<
-    (FirstOp $operand1,
+    (FirstOp $x,
      (SecondOp
       (TFL_MaximumOp
        (TFL_SumOp
-        (TFL_SquareOp:$sq_op $square_operand),
+        (TFL_SquareOp:$sq_op $x),
         (ConstantOp I32ElementsAttr:$axis),
         $keep_dims),
       (ConstantOp $epsilon))),
      TFL_AF_None),
-    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
-    [(EqualOperands $operand1, $square_operand),
-     (L2NormValidReduceIndex $sq_op, $axis),
+    (TFL_L2NormalizationOp $x, TFL_AF_None),
+    [(L2NormValidReduceIndex $sq_op, $axis),
      (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
 }
@@ -521,12 +513,11 @@ def MatchRelu1Pattern2 : Pat<
 def MatchLeakyRelu : Pat<
   (TFL_MaximumOp
-    (TFL_MulOp:$mul_out $input1,
+    (TFL_MulOp:$mul_out $x,
      (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
-    $input2),
-  (TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
+    $x),
+  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
   [(ConstDoubleValueLessThan<"1"> $alpha),
-   (EqualOperands $input1, $input2),
    (HasOneUse $mul_out)]>;
 
 def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input),
@@ -542,15 +533,14 @@ def PReluAlphaRankCheck : Constraint<
 // f(x) = Relu(x) + (-alpha * Relu(-x))
 def MatchPRelu : Pat<
   (TFL_AddOp
-    (TFL_ReluOp:$relu_out $input1),
+    (TFL_ReluOp:$relu_out $x),
     (TFL_MulOp:$mul_out
-      (TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
+      (TFL_ReluOp (TFL_NegOp:$input_neg_out $x)),
       $neg_alpha,
       TFL_AF_None),
     TFL_AF_None),
-  (TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
-  [(EqualOperands $input1, $input2),
-   (PReluAlphaRankCheck $neg_alpha, $input1),
+  (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
+  [(PReluAlphaRankCheck $neg_alpha, $x),
    (HasOneUse $relu_out),
    (HasOneUse $mul_out),
    (HasOneUse $input_neg_out)]>;

@@ -209,10 +209,8 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
 def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
                            (TF_ReshapeOp $arg, $shape)>;
 
-def IsSame : Constraint<CPred<"$0 == $1">>;
-def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)),
-                             (replaceWithValue $arg0),
-                             [(IsSame $arg0, $arg1)]>;
+def ReshapeToSelfShape : Pat<(TF_ReshapeOp $x, (TF_ShapeOp $x)),
+                             (replaceWithValue $x)>;
 
 //===----------------------------------------------------------------------===//
 // Select op patterns.

@@ -23,10 +23,6 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 // Checks if the value has only one user.
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
 
-// Constraint that makes sure both operands are the same operands.
-// TODO(b/154826385): Reconsider once equal source pattern symbols are allowed.
-def EqualOperands : Constraint<CPred<"$0 == $1">>;
-
 // Checks if the operand0's rank is one less than operand1's rank.
 def PReluAlphaRankCheck : Constraint<
   CPred<"$0.getType().cast<ShapedType>().getRank() == "
@@ -36,13 +32,12 @@ def PReluAlphaRankCheck : Constraint<
 // PReLU pattern from Keras:
 // f(x) = Relu(x) + (-alpha * Relu(-x))
 def : Pat<(TF_AddV2Op
-             (TF_ReluOp:$relu_out $input1),
+             (TF_ReluOp:$relu_out $x),
              (TF_MulOp:$mul_out
-                (TF_ReluOp (TF_NegOp:$input_neg_out $input2)),
+                (TF_ReluOp (TF_NegOp:$input_neg_out $x)),
                 $neg_alpha)),
-          (TFJS_PReluOp $input1, (TF_NegOp $neg_alpha)),
-          [(EqualOperands $input1, $input2),
-           (PReluAlphaRankCheck $neg_alpha, $input1),
+          (TFJS_PReluOp $x, (TF_NegOp $neg_alpha)),
+          [(PReluAlphaRankCheck $neg_alpha, $x),
            (HasOneUse $relu_out),
            (HasOneUse $mul_out),
            (HasOneUse $input_neg_out)