Generalize reshape op lowering to cover expansion/collapse of dims.

With addition of linalg.tensor_reshape, the reshape op lowering for
both HLO and LHLO can be generalized to cover cases where reshape is
an expansion or collapse of dimensions.

PiperOrigin-RevId: 308697032
Change-Id: I461508211a0bc86a60f41b656b5e5af5b45eb2f8
This commit is contained in:
Mahesh Ravishankar 2020-04-27 14:20:41 -07:00 committed by TensorFlower Gardener
parent a26e04d948
commit c6eaf562e7
2 changed files with 192 additions and 39 deletions

View File

@@ -359,3 +359,62 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// Collapsing reshape: tensor<1x28x28x1xf32> -> tensor<1x784xf32>. The leading
// unit dim stays its own group and (d1, d2, d3) collapse into the second
// result dim, giving the two reassociation maps checked below.
func @reshape_collapse_single_dim
(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
return %0 : tensor<1x784xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_collapse_single_dim
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
// -----
// Collapsing reshape: tensor<2x2x2x3xf32> -> tensor<2x4x3xf32>. Middle dims
// (d1, d2) fold into the size-4 result dim; d0 and d3 map through unchanged,
// so three reassociation maps are expected.
func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
return %0 : tensor<2x4x3xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK-LABEL: func @reshape_collapse
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// -----
// Expanding reshape: tensor<2x8xf32> -> tensor<2x4x2xf32>. The size-8 operand
// dim expands into result dims (d1, d2); the maps are expressed over the
// higher-rank (result) iteration space.
func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
return %0 : tensor<2x4x2xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-LABEL: func @reshape_expand
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
// -----
// Rank-1 to rank-3 expansion: tensor<8xf32> -> tensor<1x4x2xf32>. The single
// operand dim expands into all three result dims, so a single reassociation
// map covering (d0, d1, d2) is expected.
func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
return %0 : tensor<1x4x2xf32>
}
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @reshape_single_expand
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
// -----
// Multiple collapsed groups in one reshape: tensor<1x2x2x5x3x2xf32> ->
// tensor<1x4x5x6xf32>. Expected grouping over the operand dims:
// (d0), (d1, d2) -> 4, (d3) -> 5, (d4, d5) -> 6.
func @reshape_multiple_collapse
(%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
return %0 : tensor<1x4x5x6xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-LABEL: func @reshape_multiple_collapse
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]

View File

@@ -47,14 +47,14 @@ ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder* b) {
return b->getArrayAttr(iteratorTypes);
}
// Returns the Value that holds the XLA op's result: for LHLO (buffer
// semantics) the output buffer is the last operand; for HLO (tensor
// semantics) it is the op's first SSA result.
template <bool isLHLO = true>
Value getResultValue(Operation* op) {
return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
}
template <bool isLHLO = true>
ShapedType getXLAOpResultType(Operation* op) {
if (isLHLO) {
return op->getOperand(op->getNumOperands() - 1)
.getType()
.cast<ShapedType>();
}
return op->getResult(0).getType().cast<ShapedType>();
return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
}
template <bool isLHLO = true>
@@ -206,9 +206,7 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto operandType = op.operand().getType().template cast<ShapedType>();
auto resultType = getXLAOpResultType<isLHLO>(op);
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
ArrayAttr indexingMapsAttr =
static_cast<const Derived&>(*this).getIndexingMapsAttr(op, &rewriter);
ArrayAttr indexingMapsAttr = Derived::getIndexingMapsAttr(op, &rewriter);
if (!indexingMapsAttr) return failure();
OpBuilder::InsertionGuard linalgOpGuard(rewriter);
@@ -241,7 +239,7 @@ class BroadcastInDimConverter
using DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) const {
static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<isLHLO>(broadcastOp);
auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>();
@@ -278,29 +276,6 @@ class BroadcastInDimConverter
}
};
// Lowers an (L)HLO transpose by producing linalg indexing maps: the operand's
// map is the inverse of op.permutation() (inputExprs[perm[i]] = d_i) and the
// result uses the identity map over the result rank.
// NOTE(review): in this diff rendering the same class also appears further
// down with getIndexingMapsAttr made static; this is the pre-move copy.
template <typename OpTy, bool isLHLO = true>
class TransposeConverter
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO> {
public:
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
// Returns [operandMap, identityMap] for use as linalg indexing_maps.
ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) const {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultType.getRank());
// Invert the permutation: loop dim `index` writes operand dim `value`.
for (auto permutation : llvm::enumerate(op.permutation())) {
inputExprs[permutation.value().getZExtValue()] =
b->getAffineDimExpr(permutation.index());
}
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
}
};
/// Pattern for the special case where reshape is adding or removing a dimension
/// of size 1. These can be lowered to a linalg.generic op.
///
@@ -314,15 +289,33 @@ class TransposeConverter
/// can have indexing maps
/// [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1,
/// d2)>]
// TODO(ravishankarm): This pattern needs to be removed. The general reshape
// lowering hits a corner case where the following sequence of operations
// cannot be fused cause the resulting indexing map is not invertible.
//
// %r = linalg.reshape %s [affine_map<(d0, d1, d2) -> (d0, d1)>,
// affine_map<(d0, d1, d2) -> (d2)>]
// : tensor<5x5xf32> into tensor<5x1x5xf32>
// %f = linalg.generic
// {...
// indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
// affine_map<(d0, d1, d2) -> (d0, d2)>],
// iterator_types = ["parallel", "parallel", "parallel"]} %r {..}
// : tensor<5x1x5xf32> -> tensor<5x5xf32>
//
// The resolution of this requires a canonicalization on linalg ops where the
// dims of size 1 are removed. This pattern can be removed after that.
template <typename OpTy, bool isLHLO = true>
class ReshapeAddRemoveDimConverter
: public DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
OpTy, isLHLO> {
public:
using DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
OpTy, isLHLO>::DataMovementOpConverter;
ReshapeAddRemoveDimConverter(MLIRContext* context)
: DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
OpTy, isLHLO>(context, 100) {}
ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) const {
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto operandType =
@@ -373,6 +366,106 @@ class ReshapeAddRemoveDimConverter
}
};
// Lowers an (L)HLO transpose to linalg indexing maps. The operand's map is
// the inverse of op.permutation() (loop dim i feeds operand dim perm[i]);
// the result is accessed through the identity map over the result rank.
template <typename OpTy, bool isLHLO = true>
class TransposeConverter
    : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  // Returns [operandMap, identityMap] suitable for linalg indexing_maps.
  static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
    auto resultType =
        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = resultType.getRank();
    // Scatter loop dims through the permutation to build the operand map.
    SmallVector<AffineExpr, 2> operandExprs(nloops);
    unsigned loopDim = 0;
    for (const auto& permValue : op.permutation()) {
      operandExprs[permValue.getZExtValue()] = b->getAffineDimExpr(loopDim++);
    }
    AffineMap operandMap =
        AffineMap::get(nloops, /*symbolCount=*/0, operandExprs,
                       b->getContext());
    return b->getAffineMapArrayAttr(
        {operandMap, b->getMultiDimIdentityMap(nloops)});
  }
};
// Converts reshape ops that can be proven to be either a collapse of
// dimensions or expansion of dimensions of the operand.
//
// The reassociation is computed greedily: walk the higher-rank shape
// (srcShape) and the lower-rank shape (dstShape) in parallel, folding
// consecutive src dims into one group until their product matches the current
// dst dim. Only static shapes are supported; any mismatch makes the pattern
// fail to apply rather than assert.
template <typename OpTy, bool isLHLO = true>
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy reshapeOp, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
      return failure();
    ShapedType operandType =
        reshapeOp.operand().getType().template cast<ShapedType>();
    ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);
    // Reassociation maps can only be computed from static shapes.
    if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // TODO(ravishankarm): To make this pattern not match the pattern that
    // ReshapeAddRemoveDimConverter is for, check that condition here. Remove
    // this when ReshapeAddRemoveDimConverter pattern is removed.
    if (ReshapeAddRemoveDimConverter<OpTy, isLHLO>::getIndexingMapsAttr(
            reshapeOp, &rewriter))
      return failure();

    // Compute the reassociation maps for the linalg operation. srcShape is
    // always the higher-rank side, dstShape the lower-rank side.
    ArrayRef<int64_t> srcShape =
        (operandType.getRank() > resultType.getRank() ? operandType.getShape()
                                                      : resultType.getShape());
    ArrayRef<int64_t> dstShape =
        (operandType.getRank() > resultType.getRank() ? resultType.getShape()
                                                      : operandType.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<SmallVector<AffineExpr, 4>, 4> exprs(dstShape.size());
    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
      int64_t dstSize = dstShape[currDstDim];
      int64_t srcSize = srcShape[currSrcDim];
      // Fold src dims into the current group until the product reaches the
      // dst dim size. The `currSrcDim + 1` bound guards the post-increment
      // read of srcShape[currSrcDim] below; without it the loop could index
      // past the end of srcShape when the shapes cannot be reassociated.
      while (srcSize < dstSize && currSrcDim + 1 < srcShape.size()) {
        exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= srcShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in dstShape is not 1, treat subsequent dims in
        // srcShape which are 1 to be collapsed.
        if (currDstDim == dstShape.size() - 1 ||
            dstShape[currDstDim + 1] != 1) {
          while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
            exprs[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        // The dims cannot be partitioned into a valid reassociation.
        return failure();
      }
      currDstDim++;
    }
    // All source dims must have been consumed by some group.
    if (currSrcDim != srcShape.size()) return failure();

    SmallVector<ArrayRef<AffineExpr>, 4> reassociationMaps;
    for (auto& expr : exprs) reassociationMaps.push_back(expr);

    if (isLHLO) {
      // Buffer semantics: materialize the reshape as a view of the input
      // buffer and copy it into the output buffer (the last operand).
      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
          reshapeOp.getLoc(), resultType, args[0], reassociationMaps);
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(
          reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
          /*outputPermutation =*/nullptr);
    } else {
      // Tensor semantics: replace directly with a linalg.tensor_reshape.
      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
          reshapeOp, resultType, args[0], reassociationMaps);
    }
    return success();
  }
};
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
public:
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
@@ -578,8 +671,6 @@ namespace xla_hlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<BroadcastInDimConverter<xla_hlo::BroadcastInDimOp, false>,
ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>,
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
@@ -599,7 +690,10 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>>(context);
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {