diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td
index b8b6cb80fba..bdb3e3cf490 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td
@@ -18,15 +18,13 @@ limitations under the License.
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
 include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
 
-def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
-
 // Canonicalization patterns.
 def DynamicBroadcastToOwnShape_1 : Pat<
-  (HLO_DynamicBroadcastInDimOp:$op $arg0,
-      (Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
-  (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
+  (HLO_DynamicBroadcastInDimOp:$op $x,
+      (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x)), $attr),
+  (replaceWithValue $x)>;
 def DynamicBroadcastToOwnShape_2 : Pat<
-  (HLO_DynamicBroadcastInDimOp:$op $arg0, (Shape_ShapeOfOp $arg1), $attr),
-  (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
+  (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
+  (replaceWithValue $x)>;
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 653c33ea9df..57925663d74 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -238,8 +238,6 @@ def eliminate_dq_q_pairs : Pat<
   [(NotFromQuantOpOrSameQuantType $in, $qt)]>;
 
-// Constraint that makes sure both operands are the same operands.
-def EqualOperands : Constraint<CPred<"$0 == $1">>;
 
 // Checks if the operand has rank == n
 class OperandHasRank<int n> : Constraint<
   CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;
@@ -251,28 +249,26 @@ def MatchHardSwishPattern1 : Pat<
   (TFL_MulOp
    (TFL_MulOp
     $x, (TFL_AddOp
-         $y,
+         $x,
          (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6),
     TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
-  (TFL_HardSwishOp $x),
-  [(EqualOperands $x, $y)]>;
+  (TFL_HardSwishOp $x)>;
 
 def MatchHardSwishPattern2 : Pat<
   (TFL_MulOp
    $x,
    (TFL_MulOp
     (TFL_AddOp
-     $y,
+     $x,
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
    TFL_AF_None),
-  (TFL_HardSwishOp $x),
-  [(EqualOperands $x, $y)]>;
+  (TFL_HardSwishOp $x)>;
 
 // Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to
 // incorrect placement in the quantization aware training.
@@ -281,14 +277,13 @@ def MatchHardSwishQuantized : Pat<
   (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
      (TFL_MulOp $x, (TFL_DequantizeOp (TFL_QuantizeOp
       (TFL_AddOp
-       $y,
+       $x,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6), $qattr2)),
      TFL_AF_None), $qattr1)),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
-  (TFL_HardSwishOp $x),
-  [(EqualOperands $x, $y)]>;
+  (TFL_HardSwishOp $x)>;
 
 // Constraint that the attribute value is less than 'n'
 class ConstDoubleValueLessThan<string n> : Constraint<
@@ -309,47 +304,44 @@ multiclass L2NormalizePatterns<Op FirstOp, Op SecondOp> {
   // Mul->Rsqrt->Sum->Square Or
   // Div->sqrt->Sum->Square
   def L2NormalizePattern1#FirstOp#SecondOp : Pat<
-    (FirstOp $operand1,
+    (FirstOp $x,
      (SecondOp
       (TFL_SumOp
-       (TFL_SquareOp:$sq_op $square_operand),
+       (TFL_SquareOp:$sq_op $x),
        (ConstantOp I32ElementsAttr:$axis),
        $keep_dims)),
      TFL_AF_None),
-    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
-    [(EqualOperands $operand1, $square_operand),
-     (L2NormValidReduceIndex $sq_op, $axis)]>;
+    (TFL_L2NormalizationOp $x, TFL_AF_None),
+    [(L2NormValidReduceIndex $sq_op, $axis)]>;
 
   // Below patterns for L2Normalize when there is an Add or Maximum
   // adding or clamping to a small constant scalar.
   def L2NormalizePattern2#FirstOp#SecondOp : Pat<
-    (FirstOp $operand1,
+    (FirstOp $x,
      (SecondOp
       (TFL_AddOp
       (TFL_SumOp
-        (TFL_SquareOp:$sq_op $square_operand),
+        (TFL_SquareOp:$sq_op $x),
        (ConstantOp I32ElementsAttr:$axis),
        $keep_dims),
       (ConstantOp $epsilon), TFL_AF_None)),
      TFL_AF_None),
-    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
-    [(EqualOperands $operand1, $square_operand),
-     (L2NormValidReduceIndex $sq_op, $axis),
+    (TFL_L2NormalizationOp $x, TFL_AF_None),
+    [(L2NormValidReduceIndex $sq_op, $axis),
      (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
 
   def L2NormalizePattern3#FirstOp#SecondOp : Pat<
-    (FirstOp $operand1,
+    (FirstOp $x,
      (SecondOp
       (TFL_MaximumOp
       (TFL_SumOp
-        (TFL_SquareOp:$sq_op $square_operand),
+        (TFL_SquareOp:$sq_op $x),
        (ConstantOp I32ElementsAttr:$axis),
        $keep_dims),
       (ConstantOp $epsilon))),
      TFL_AF_None),
-    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
-    [(EqualOperands $operand1, $square_operand),
-     (L2NormValidReduceIndex $sq_op, $axis),
+    (TFL_L2NormalizationOp $x, TFL_AF_None),
+    [(L2NormValidReduceIndex $sq_op, $axis),
      (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
 }
 
@@ -521,12 +513,11 @@ def MatchRelu1Pattern2 : Pat<
 
 def MatchLeakyRelu : Pat<
   (TFL_MaximumOp
-    (TFL_MulOp:$mul_out $input1,
+    (TFL_MulOp:$mul_out $x,
      (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
-    $input2),
-  (TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
+    $x),
+  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
   [(ConstDoubleValueLessThan<"1"> $alpha),
-   (EqualOperands $input1, $input2),
    (HasOneUse $mul_out)]>;
 
 def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input),
@@ -542,15 +533,14 @@ def PReluAlphaRankCheck : Constraint<
 // PReLU pattern from Keras:
 // f(x) = Relu(x) + (-alpha * Relu(-x))
 def MatchPRelu : Pat<
   (TFL_AddOp
-    (TFL_ReluOp:$relu_out $input1),
+    (TFL_ReluOp:$relu_out $x),
     (TFL_MulOp:$mul_out
-      (TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
+      (TFL_ReluOp (TFL_NegOp:$input_neg_out $x)),
       $neg_alpha, TFL_AF_None),
     TFL_AF_None),
-  (TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
-  [(EqualOperands $input1, $input2),
-   (PReluAlphaRankCheck $neg_alpha, $input1),
+  (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
+  [(PReluAlphaRankCheck $neg_alpha, $x),
    (HasOneUse $relu_out),
    (HasOneUse $mul_out),
    (HasOneUse $input_neg_out)]>;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
index d5b7eb7a739..945573aa978 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
@@ -209,10 +209,8 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
 def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
                            (TF_ReshapeOp $arg, $shape)>;
 
-def IsSame : Constraint<CPred<"$0 == $1">>;
-def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)),
-                             (replaceWithValue $arg0),
-                             [(IsSame $arg0, $arg1)]>;
+def ReshapeToSelfShape : Pat<(TF_ReshapeOp $x, (TF_ShapeOp $x)),
+                             (replaceWithValue $x)>;
 
 //===----------------------------------------------------------------------===//
 // Select op patterns.
diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td b/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td
index c5a059e5b6b..c17939fd962 100644
--- a/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td
+++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td
@@ -23,10 +23,6 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 // Checks if the value has only one user.
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
 
-// Constraint that makes sure both operands are the same operands.
-// TODO(b/154826385): Reconsider once equal source pattern symbols are allowed.
-def EqualOperands : Constraint<CPred<"$0 == $1">>;
-
 // Checks if the operand0's rank is one less than operand1's rank.
 def PReluAlphaRankCheck : Constraint<
   CPred<"$0.getType().cast<ShapedType>().getRank() == "
@@ -36,13 +32,12 @@ def PReluAlphaRankCheck : Constraint<
 // PReLU pattern from Keras:
 // f(x) = Relu(x) + (-alpha * Relu(-x))
 def : Pat<(TF_AddV2Op
-           (TF_ReluOp:$relu_out $input1),
+           (TF_ReluOp:$relu_out $x),
            (TF_MulOp:$mul_out
-            (TF_ReluOp (TF_NegOp:$input_neg_out $input2)),
+            (TF_ReluOp (TF_NegOp:$input_neg_out $x)),
             $neg_alpha)),
-          (TFJS_PReluOp $input1, (TF_NegOp $neg_alpha)),
-          [(EqualOperands $input1, $input2),
-           (PReluAlphaRankCheck $neg_alpha, $input1),
+          (TFJS_PReluOp $x, (TF_NegOp $neg_alpha)),
+          [(PReluAlphaRankCheck $neg_alpha, $x),
            (HasOneUse $relu_out),
            (HasOneUse $mul_out),
            (HasOneUse $input_neg_out)]>;
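Every hunk above removes the same workaround. The hand-written equality constraints (EqualBinaryOperands, EqualOperands, IsSame, all spelled Constraint<CPred<"$0 == $1">>) existed because MLIR's Declarative Rewrite Rules (DRR) used to reject binding the same symbol twice in a source pattern; the TODO(b/154826385) deleted from the TFJS file tracks exactly that limitation. Now that mlir-tblgen accepts a repeated binding and emits the operand-equality check itself, rebinding $x in the source pattern replaces the explicit constraint. A minimal before/after sketch of the idiom, using a hypothetical FoldSubSelf pattern over the real TF_SubOp/TF_ZerosLikeOp ops (illustrative only, not part of this change, and ignoring NaN/Inf semantics):

  // Before: bind two symbols, then assert equality via a C++ predicate.
  def EqualOperands : Constraint<CPred<"$0 == $1">>;
  def FoldSubSelfOld : Pat<
    (TF_SubOp $lhs, $rhs),
    (TF_ZerosLikeOp $lhs),
    [(EqualOperands $lhs, $rhs)]>;

  // After: binding $x twice makes mlir-tblgen generate the same
  // operand-equality check in the matcher automatically.
  def FoldSubSelfNew : Pat<
    (TF_SubOp $x, $x),
    (TF_ZerosLikeOp $x)>;

Besides being shorter, the rebinding form keeps the equality requirement visible in the source DAG itself rather than in a trailing constraint list, which is why the patterns above can also drop their $y/$arg1/$input2/$square_operand placeholder names.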