Generalize reshape op lowering to cover expansion/collapse of dims.
With the addition of linalg.tensor_reshape, the reshape op lowering for both HLO and LHLO can be generalized to cover cases where the reshape is an expansion or a collapse of dimensions. PiperOrigin-RevId: 308697032 Change-Id: I461508211a0bc86a60f41b656b5e5af5b45eb2f8
This commit is contained in:
parent
a26e04d948
commit
c6eaf562e7
@ -359,3 +359,62 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
|
||||
// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
// Test input: reshapes a 1x28x28x1 tensor to 1x784, i.e. the three trailing
// operand dims (28, 28, 1) fold into a single result dim of size 784.
func @reshape_collapse_single_dim
|
||||
(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
|
||||
return %0 : tensor<1x784xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape_collapse_single_dim
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test input: reshapes a 2x2x2x3 tensor to 2x4x3, i.e. the two middle operand
// dims (2, 2) fold into one result dim of size 4 while the outer dims are kept.
func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
|
||||
return %0 : tensor<2x4x3xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
|
||||
// CHECK-LABEL: func @reshape_collapse
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test input: reshapes a 2x8 tensor to 2x4x2, i.e. the operand dim of size 8 is
// split into two result dims (4, 2).
func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
|
||||
return %0 : tensor<2x4x2xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-LABEL: func @reshape_expand
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test input: reshapes a rank-1 tensor of 8 elements to 1x4x2, i.e. the single
// operand dim is split into all three result dims.
func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
|
||||
return %0 : tensor<1x4x2xf32>
|
||||
}
|
||||
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-LABEL: func @reshape_single_expand
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test input: reshapes a 1x2x2x5x3x2 tensor to 1x4x5x6, i.e. two separate
// groups of operand dims fold at once: (2, 2) into 4 and (3, 2) into 6.
func @reshape_multiple_collapse
|
||||
(%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
|
||||
return %0 : tensor<1x4x5x6xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
|
||||
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
|
||||
// CHECK-LABEL: func @reshape_multiple_collapse
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
|
||||
|
@ -47,14 +47,14 @@ ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder* b) {
|
||||
return b->getArrayAttr(iteratorTypes);
|
||||
}
|
||||
|
||||
template <bool isLHLO = true>
|
||||
Value getResultValue(Operation* op) {
|
||||
return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
|
||||
}
|
||||
|
||||
// Returns the ShapedType of the value that carries the result of `op`.
//
// NOTE(review): this span comes from a rendered diff with no +/- markers, so it
// presumably contains both the previous body (the `if (isLHLO)` branch and the
// first trailing return) and its replacement (the final return that delegates
// to getResultValue). Confirm against the actual file before relying on it.
template <bool isLHLO = true>
|
||||
ShapedType getXLAOpResultType(Operation* op) {
|
||||
// LHLO form: the type is taken from the last operand.
if (isLHLO) {
|
||||
return op->getOperand(op->getNumOperands() - 1)
|
||||
.getType()
|
||||
.cast<ShapedType>();
|
||||
}
|
||||
// Otherwise the type is taken from the first result.
return op->getResult(0).getType().cast<ShapedType>();
|
||||
// Same lookup expressed through getResultValue<isLHLO>.
return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
|
||||
}
|
||||
|
||||
template <bool isLHLO = true>
|
||||
@ -206,9 +206,7 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
|
||||
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
|
||||
auto operandType = op.operand().getType().template cast<ShapedType>();
|
||||
auto resultType = getXLAOpResultType<isLHLO>(op);
|
||||
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
|
||||
ArrayAttr indexingMapsAttr =
|
||||
static_cast<const Derived&>(*this).getIndexingMapsAttr(op, &rewriter);
|
||||
ArrayAttr indexingMapsAttr = Derived::getIndexingMapsAttr(op, &rewriter);
|
||||
if (!indexingMapsAttr) return failure();
|
||||
|
||||
OpBuilder::InsertionGuard linalgOpGuard(rewriter);
|
||||
@ -241,7 +239,7 @@ class BroadcastInDimConverter
|
||||
using DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO>::DataMovementOpConverter;
|
||||
|
||||
ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) const {
|
||||
static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) {
|
||||
auto resultType = getXLAOpResultType<isLHLO>(broadcastOp);
|
||||
auto operandType =
|
||||
broadcastOp.operand().getType().template cast<ShapedType>();
|
||||
@ -278,29 +276,6 @@ class BroadcastInDimConverter
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class TransposeConverter
|
||||
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO> {
|
||||
public:
|
||||
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO>::DataMovementOpConverter;
|
||||
ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) const {
|
||||
auto resultType =
|
||||
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
|
||||
auto nloops = resultType.getRank();
|
||||
SmallVector<AffineExpr, 2> inputExprs;
|
||||
inputExprs.resize(resultType.getRank());
|
||||
for (auto permutation : llvm::enumerate(op.permutation())) {
|
||||
inputExprs[permutation.value().getZExtValue()] =
|
||||
b->getAffineDimExpr(permutation.index());
|
||||
}
|
||||
return b->getAffineMapArrayAttr(
|
||||
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)});
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern for the special case where reshape is adding or removing a dimension
|
||||
/// of size 1. These can be lowered to a linalg.generic op.
|
||||
///
|
||||
@ -314,15 +289,33 @@ class TransposeConverter
|
||||
/// can have indexing maps
|
||||
/// [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1,
|
||||
/// d2)>]
|
||||
|
||||
// TODO(ravishankarm): This pattern needs to be removed. The general reshape
|
||||
// lowering hits a corner case where the following sequence of operations
|
||||
// cannot be fused because the resulting indexing map is not invertible.
|
||||
//
|
||||
// %r = linalg.reshape %s [affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
// affine_map<(d0, d1, d2) -> (d2)>]
|
||||
// : tensor<5x5xf32> into tensor<5x1x5xf32>
|
||||
// %f = linalg.generic
|
||||
// {...
|
||||
// indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
// affine_map<(d0, d1, d2) -> (d0, d2)>],
|
||||
// iterator_types = ["parallel", "parallel", "parallel"]} %r {..}
|
||||
// : tensor<5x1x5xf32> -> tensor<5x5xf32>
|
||||
//
|
||||
// The resolution of this requires a canonicalization on linalg ops where the
|
||||
// dims of size 1 are removed. This pattern can be removed after that.
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class ReshapeAddRemoveDimConverter
|
||||
: public DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
|
||||
OpTy, isLHLO> {
|
||||
public:
|
||||
using DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
|
||||
OpTy, isLHLO>::DataMovementOpConverter;
|
||||
ReshapeAddRemoveDimConverter(MLIRContext* context)
|
||||
: DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
|
||||
OpTy, isLHLO>(context, 100) {}
|
||||
|
||||
ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) const {
|
||||
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
|
||||
auto resultType =
|
||||
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
|
||||
auto operandType =
|
||||
@ -373,6 +366,106 @@ class ReshapeAddRemoveDimConverter
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class TransposeConverter
|
||||
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO> {
|
||||
public:
|
||||
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO>::DataMovementOpConverter;
|
||||
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
|
||||
auto resultType =
|
||||
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
|
||||
auto nloops = resultType.getRank();
|
||||
SmallVector<AffineExpr, 2> inputExprs;
|
||||
inputExprs.resize(resultType.getRank());
|
||||
for (auto permutation : llvm::enumerate(op.permutation())) {
|
||||
inputExprs[permutation.value().getZExtValue()] =
|
||||
b->getAffineDimExpr(permutation.index());
|
||||
}
|
||||
return b->getAffineMapArrayAttr(
|
||||
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)});
|
||||
}
|
||||
};
|
||||
|
||||
// Converts reshape ops that can be proven to be either a collapse of dimensions
|
||||
// or expansion of dimensions of the operand.
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
OpTy reshapeOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
|
||||
return failure();
|
||||
ShapedType operandType =
|
||||
reshapeOp.operand().getType().template cast<ShapedType>();
|
||||
ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);
|
||||
|
||||
if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// TODO(ravishankarm): To make this pattern not match the pattern that
|
||||
// ReshapeAddRemoveDimConverter is for, check that condition here. Remove
|
||||
// this when ReshapeAddRemoveDimConverter pattern is removed.
|
||||
if (ReshapeAddRemoveDimConverter<OpTy, isLHLO>::getIndexingMapsAttr(
|
||||
reshapeOp, &rewriter))
|
||||
return failure();
|
||||
|
||||
// Compute the reassociation maps for the linalg operation.
|
||||
ArrayRef<int64_t> srcShape =
|
||||
(operandType.getRank() > resultType.getRank() ? operandType.getShape()
|
||||
: resultType.getShape());
|
||||
ArrayRef<int64_t> dstShape =
|
||||
(operandType.getRank() > resultType.getRank() ? resultType.getShape()
|
||||
: operandType.getShape());
|
||||
unsigned currSrcDim = 0, currDstDim = 0;
|
||||
SmallVector<SmallVector<AffineExpr, 4>, 4> exprs(dstShape.size());
|
||||
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
|
||||
int64_t dstSize = dstShape[currDstDim];
|
||||
int64_t srcSize = srcShape[currSrcDim];
|
||||
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
|
||||
exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
|
||||
srcSize *= srcShape[currSrcDim];
|
||||
}
|
||||
if (srcSize == dstSize) {
|
||||
exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
|
||||
// If the next dim in dstShape is not 1, treat subsequent dims in
|
||||
// srcShape which are 1 to be collapsed.
|
||||
if (currDstDim == dstShape.size() - 1 ||
|
||||
dstShape[currDstDim + 1] != 1) {
|
||||
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
|
||||
exprs[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
currDstDim++;
|
||||
}
|
||||
if (currSrcDim != srcShape.size()) return failure();
|
||||
|
||||
SmallVector<ArrayRef<AffineExpr>, 4> reassociationMaps;
|
||||
for (auto& expr : exprs) reassociationMaps.push_back(expr);
|
||||
|
||||
if (isLHLO) {
|
||||
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
|
||||
reshapeOp.getLoc(), resultType, args[0], reassociationMaps);
|
||||
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
|
||||
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
|
||||
/*outputPermutation =*/nullptr);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||
reshapeOp, resultType, args[0], reassociationMaps);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
|
||||
@ -578,8 +671,6 @@ namespace xla_hlo {
|
||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
patterns->insert<BroadcastInDimConverter<xla_hlo::BroadcastInDimOp, false>,
|
||||
ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
|
||||
TransposeConverter<xla_hlo::TransposeOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
|
||||
@ -599,7 +690,10 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>>(context);
|
||||
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
|
||||
ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
|
||||
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
|
||||
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
|
Loading…
Reference in New Issue
Block a user