diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
index ecee1d681d6..120f3e06380 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
@@ -359,3 +359,62 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
 // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
 // CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32
+
+// -----
+
+func @reshape_collapse_single_dim
+    (%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
+  return %0 : tensor<1x784xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-LABEL: func @reshape_collapse_single_dim
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+
+// -----
+
+func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
+  return %0 : tensor<2x4x3xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-LABEL: func @reshape_collapse
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+
+// -----
+
+func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
+  return %0 : tensor<2x4x2xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-LABEL: func @reshape_expand
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+
+// -----
+
+func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
+  return %0 : tensor<1x4x2xf32>
+}
+// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @reshape_single_expand
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+
+// -----
+
+func @reshape_multiple_collapse
+    (%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
+  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
+  return %0 : tensor<1x4x5x6xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+// CHECK-LABEL: func @reshape_multiple_collapse
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
index cb6c4d368c2..4de45601296 100644
--- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
@@ -47,14 +47,14 @@ ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder* b) {
   return b->getArrayAttr(iteratorTypes);
 }
 
+template <bool isLHLO = true>
+Value getResultValue(Operation* op) {
+  return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
+}
+
 template <bool isLHLO = true>
 ShapedType getXLAOpResultType(Operation* op) {
-  if (isLHLO) {
-    return op->getOperand(op->getNumOperands() - 1)
-        .getType()
-        .cast<ShapedType>();
-  }
-  return op->getResult(0).getType().cast<ShapedType>();
+  return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
 }
 
 template <bool isLHLO = true>
@@ -206,9 +206,7 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
     if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
     auto operandType = op.operand().getType().template cast<ShapedType>();
     auto resultType = getXLAOpResultType<isLHLO>(op);
-    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
-    ArrayAttr indexingMapsAttr =
-        static_cast<const Derived&>(*this).getIndexingMapsAttr(op, &rewriter);
+    ArrayAttr indexingMapsAttr = Derived::getIndexingMapsAttr(op, &rewriter);
     if (!indexingMapsAttr) return failure();
 
     OpBuilder::InsertionGuard linalgOpGuard(rewriter);
@@ -241,7 +239,7 @@ class BroadcastInDimConverter
   using DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>, OpTy,
                                 isLHLO>::DataMovementOpConverter;
 
-  ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) const {
+  static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) {
     auto resultType = getXLAOpResultType<isLHLO>(broadcastOp);
     auto operandType =
         broadcastOp.operand().getType().template cast<ShapedType>();
@@ -278,29 +276,6 @@ class BroadcastInDimConverter
   }
 };
 
-template <typename OpTy, bool isLHLO = false>
-class TransposeConverter
-    : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
-                                     isLHLO> {
- public:
-  using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
-                                isLHLO>::DataMovementOpConverter;
-  ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) const {
-    auto resultType =
-        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
-    auto nloops = resultType.getRank();
-    SmallVector<AffineExpr, 4> inputExprs;
-    inputExprs.resize(resultType.getRank());
-    for (auto permutation : llvm::enumerate(op.permutation())) {
-      inputExprs[permutation.value().getZExtValue()] =
-          b->getAffineDimExpr(permutation.index());
-    }
-    return b->getAffineMapArrayAttr(
-        {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
-         b->getMultiDimIdentityMap(nloops)});
-  }
-};
-
 /// Pattern for the special case where reshape is adding or removing a dimension
 /// of size 1. These can be lowered to a linalg.generic op.
 ///
@@ -314,15 +289,33 @@ class TransposeConverter
 /// can have indexing maps
 /// [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1,
 /// d2)>]
+
+// TODO(ravishankarm): This pattern needs to be removed. The general reshape
+// lowering hits a corner case where the following sequence of operations
+// cannot be fused because the resulting indexing map is not invertible.
+//
+// %r = linalg.reshape %s [affine_map<(d0, d1, d2) -> (d0, d1)>,
+//                         affine_map<(d0, d1, d2) -> (d2)>]
+//     : tensor<5x5xf32> into tensor<5x1x5xf32>
+// %f = linalg.generic
+//        {...
+//         indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+//                          affine_map<(d0, d1, d2) -> (d0, d2)>],
+//         iterator_types = ["parallel", "parallel", "parallel"]} %r {..}
+//     : tensor<5x1x5xf32> -> tensor<5x5xf32>
+//
+// The resolution of this requires a canonicalization on linalg ops where the
+// dims of size 1 are removed. This pattern can be removed after that.
 template <typename OpTy, bool isLHLO = false>
 class ReshapeAddRemoveDimConverter
     : public DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
                                      OpTy, isLHLO> {
  public:
-  using DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
-                                OpTy, isLHLO>::DataMovementOpConverter;
+  ReshapeAddRemoveDimConverter(MLIRContext* context)
+      : DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
+                                OpTy, isLHLO>(context, 100) {}
 
-  ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) const {
+  static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
     auto resultType =
         getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
     auto operandType =
@@ -373,6 +366,106 @@ class ReshapeAddRemoveDimConverter
   }
 };
 
+template <typename OpTy, bool isLHLO = false>
+class TransposeConverter
+    : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
+                                     isLHLO> {
+ public:
+  using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
+                                isLHLO>::DataMovementOpConverter;
+  static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
+    auto resultType =
+        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
+    auto nloops = resultType.getRank();
+    SmallVector<AffineExpr, 4> inputExprs;
+    inputExprs.resize(resultType.getRank());
+    for (auto permutation : llvm::enumerate(op.permutation())) {
+      inputExprs[permutation.value().getZExtValue()] =
+          b->getAffineDimExpr(permutation.index());
+    }
+    return b->getAffineMapArrayAttr(
+        {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
+         b->getMultiDimIdentityMap(nloops)});
+  }
+};
+
+// Converts reshape ops that can be proven to be either a collapse of
+// dimensions or an expansion of dimensions of the operand.
+template <typename OpTy, bool isLHLO = false>
+class ReshapeOpConverter : public OpConversionPattern<OpTy> {
+ public:
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      OpTy reshapeOp, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
+      return failure();
+    ShapedType operandType =
+        reshapeOp.operand().getType().template cast<ShapedType>();
+    ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);
+
+    if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
+      return failure();
+
+    // TODO(ravishankarm): To make this pattern not match the pattern that
+    // ReshapeAddRemoveDimConverter is for, check that condition here. Remove
+    // this when the ReshapeAddRemoveDimConverter pattern is removed.
+    if (ReshapeAddRemoveDimConverter<OpTy, isLHLO>::getIndexingMapsAttr(
+            reshapeOp, &rewriter))
+      return failure();
+
+    // Compute the reassociation maps for the linalg operation.
+    ArrayRef<int64_t> srcShape =
+        (operandType.getRank() > resultType.getRank() ? operandType.getShape()
+                                                      : resultType.getShape());
+    ArrayRef<int64_t> dstShape =
+        (operandType.getRank() > resultType.getRank() ? resultType.getShape()
+                                                      : operandType.getShape());
+    unsigned currSrcDim = 0, currDstDim = 0;
+    SmallVector<SmallVector<AffineExpr, 4>, 4> exprs(dstShape.size());
+    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+      int64_t dstSize = dstShape[currDstDim];
+      int64_t srcSize = srcShape[currSrcDim];
+      while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+        exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
+        srcSize *= srcShape[currSrcDim];
+      }
+      if (srcSize == dstSize) {
+        exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
+        // If the next dim in dstShape is not 1, fold subsequent dims in
+        // srcShape that are 1 into the current reassociation group.
+        if (currDstDim == dstShape.size() - 1 ||
+            dstShape[currDstDim + 1] != 1) {
+          while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+            exprs[currDstDim].push_back(
+                rewriter.getAffineDimExpr(currSrcDim++));
+          }
+        }
+      } else {
+        return failure();
+      }
+      currDstDim++;
+    }
+    if (currSrcDim != srcShape.size()) return failure();
+
+    SmallVector<ArrayRef<AffineExpr>, 4> reassociationMaps;
+    for (auto& expr : exprs) reassociationMaps.push_back(expr);
+
+    if (isLHLO) {
+      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
+          reshapeOp.getLoc(), resultType, args[0], reassociationMaps);
+      rewriter.replaceOpWithNewOp<linalg::CopyOp>(
+          reshapeOp, reshapeBuffer, args[1], /*inputPermutation=*/nullptr,
+          /*outputPermutation=*/nullptr);
+    } else {
+      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+          reshapeOp, resultType, args[0], reassociationMaps);
+    }
+    return success();
+  }
+};
+
 class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
  public:
   using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
@@ -578,8 +671,6 @@ namespace xla_hlo {
 void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
   patterns->insert<BroadcastInDimConverter<xla_hlo::BroadcastInDimOp, false>,
-                   ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
-                   TransposeConverter<xla_hlo::TransposeOp, false>,
                    PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
                    PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
                    PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
@@ -599,7 +690,10 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                    PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
                    PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
                    PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::SignOp, false>>(context);
+                   PointwiseToLinalgConverter<xla_hlo::SignOp, false>,
+                   ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
+                   ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
+                   TransposeConverter<xla_hlo::TransposeOp, false>>(context);
 }
 
 std::unique_ptr<OpPassBase<FuncOp>> createLegalizeHloToLinalgPass() {
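Note: the reassociation grouping inside ReshapeOpConverter can be sanity-checked in isolation. Below is a minimal standalone C++ sketch of that loop, not part of the patch: plain dimension indices stand in for AffineDimExprs, the helper name computeReassociation is hypothetical, and the inner-loop bounds guard and the empty-group check are extra safety additions. As in the pattern, srcShape must be the higher-rank shape. Run on the @reshape_multiple_collapse test case, it prints the same four groupings that MAP0-MAP3 check for.

// Standalone sketch of ReshapeOpConverter's reassociation grouping, with
// dimension indices in place of AffineDimExprs. Assumes static shapes and
// that srcShape has rank >= dstShape's rank, as the pattern itself requires.
#include <cstdint>
#include <iostream>
#include <vector>

using Group = std::vector<unsigned>;

// Returns one group of source dims per destination dim, or {} on failure.
std::vector<Group> computeReassociation(const std::vector<int64_t>& srcShape,
                                        const std::vector<int64_t>& dstShape) {
  unsigned currSrcDim = 0, currDstDim = 0;
  std::vector<Group> groups(dstShape.size());
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    // Accumulate source dims until their product reaches the destination dim.
    // (The "+ 1" bounds guard is an addition over the patch's loop.)
    while (srcSize < dstSize && currSrcDim + 1 < srcShape.size()) {
      groups[currDstDim].push_back(currSrcDim++);
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize != dstSize) return {};  // no contiguous grouping exists
    groups[currDstDim].push_back(currSrcDim++);
    // Fold trailing unit dims into this group unless the next destination
    // dim is itself 1 and should claim them instead.
    if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
      while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1)
        groups[currDstDim].push_back(currSrcDim++);
    }
    ++currDstDim;
  }
  if (currSrcDim != srcShape.size()) return {};  // leftover source dims
  for (const Group& g : groups)
    if (g.empty()) return {};  // every destination dim needs source dims
  return groups;
}

int main() {
  // Mirrors the @reshape_multiple_collapse test:
  // tensor<1x2x2x5x3x2xf32> -> tensor<1x4x5x6xf32>
  // Expected groups: (d0), (d1, d2), (d3), (d4, d5).
  for (const Group& g : computeReassociation({1, 2, 2, 5, 3, 2}, {1, 4, 5, 6})) {
    std::cout << "(";
    for (size_t i = 0; i < g.size(); ++i)
      std::cout << (i ? ", d" : "d") << g[i];
    std::cout << ")\n";
  }
  return 0;
}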