Add canonicalizer for Reshape(Broadcast(X)) pattern when it is an identity sequence
PiperOrigin-RevId: 343251257 Change-Id: I4bb27e1132f40b6527dde80ba51c37f97bf8a6f5
This commit is contained in:
parent
7d09898aee
commit
b0e0044dcf
@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
|
||||
|
||||
let results = (outs HLO_StaticShapeTensor);
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
return {};
|
||||
}
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||
MLIRContext* context) {
|
||||
results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Case Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
|
||||
def ShapeOfDynamicReshape : Pat<
|
||||
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||
(replaceWithValue $shape)>;
|
||||
|
||||
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
|
||||
def IdentityBroadcastReshape : Pat<
|
||||
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
|
||||
(replaceWithValue $input),
|
||||
[(HasSameType $input, $op)]>;
|
||||
|
||||
def IdentityBroadcastInDimReshape : Pat<
|
||||
(HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
|
||||
(replaceWithValue $input),
|
||||
[(HasSameType $input, $op)]>;
|
||||
|
@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
|
||||
// CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
|
||||
// CHECK-SAME: ]> : tensor<4x5xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identity_broadcast_reshape
|
||||
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||
%0 = "mhlo.broadcast"(%arg0) {
|
||||
broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
|
||||
return %1 : tensor<128xf32>
|
||||
// CHECK: return %arg0 : tensor<128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identity_broadcast_in_dim_reshape
|
||||
func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {
|
||||
broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
|
||||
return %1 : tensor<128xf32>
|
||||
// CHECK: return %arg0 : tensor<128xf32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user