From b0e0044dcf0774ef312bec59294d4b53170e6ffa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Nov 2020 02:32:08 -0800 Subject: [PATCH] Add canonicalizer for Reshape(Broadcast(X)) pattern when it is an identity sequence PiperOrigin-RevId: 343251257 Change-Id: I4bb27e1132f40b6527dde80ba51c37f97bf8a6f5 --- .../mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 1 + .../mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc | 6 ++++++ .../hlo/lib/Dialect/mhlo/IR/hlo_patterns.td | 12 ++++++++++++ .../compiler/mlir/hlo/tests/canonicalize.mlir | 18 ++++++++++++++++++ 4 files changed, 37 insertions(+) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 0ebf3ac6bd1..dd6d3b3fbaf 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape", let results = (outs HLO_StaticShapeTensor); let hasFolder = 1; + let hasCanonicalizer = 1; let hasCustomHLOConverter = 1; } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 389c5794c91..2b92afe956b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { return {}; } +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert( + context); +} + //===----------------------------------------------------------------------===// // Case Op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td index 776732b178b..01564b86381 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat< def ShapeOfDynamicReshape : Pat< (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)), (replaceWithValue $shape)>; + +def HasSameType : Constraint>; + +def IdentityBroadcastReshape : Pat< + (HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)), + (replaceWithValue $input), + [(HasSameType $input, $op)]>; + +def IdentityBroadcastInDimReshape : Pat< + (HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)), + (replaceWithValue $input), + [(HasSameType $input, $op)]>; diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index 8470f363fcb..41eedeeabe5 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> { // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1] // CHECK-SAME: ]> : tensor<4x5xi32> } + +// CHECK-LABEL: @identity_broadcast_reshape +func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { + %0 = "mhlo.broadcast"(%arg0) { + broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32> + %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32> + return %1 : tensor<128xf32> + // CHECK: return %arg0 : tensor<128xf32> +} + +// CHECK-LABEL: @identity_broadcast_in_dim_reshape +func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32> + %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32> + return %1 : tensor<128xf32> + // CHECK: return %arg0 : tensor<128xf32> +}