[XLA][MLIR] Insert linalg.reshape when lowering LHLO BroadcastInDimOp.

If broadcasting requires expanding size-1 dimensions of the operand, we first insert a linalg.reshape that collapses them away.

PiperOrigin-RevId: 314918927
Change-Id: I3e9cd380b93421858343ba299e659c0396ca7bcf
Author: Alexander Belyaev (2020-06-05 06:40:28 -07:00), committed by TensorFlower Gardener
parent 079520f252
commit 967782ff2a
2 changed files with 243 additions and 27 deletions
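
Concretely: when an operand dimension of static size 1 has to be expanded into a larger result dimension, the operand is first collapsed with `linalg.reshape` so that the subsequent `linalg.generic` only indexes non-degenerate dimensions. A minimal sketch of the lowering, based on the `@static_broadcast_in_dim_expansion` test below (SSA names are illustrative and the `linalg.generic` is abbreviated):

    // Before: broadcast memref<1x5xf32> into memref<5x10x100xf32> along dims [2, 0].
    "xla_lhlo.broadcast_in_dim"(%operand, %result) {
      broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
    } : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()

    // After: the size-1 dimension is reshaped away first ...
    %reshaped = linalg.reshape %operand [affine_map<(d0, d1) -> (d0, d1)>]
        : memref<1x5xf32> into memref<5xf32>
    // ... so the operand's indexing map needs no constant-0 expression:
    //   operand map: affine_map<(d0, d1, d2) -> (d0)>
    //   result map:  affine_map<(d0, d1, d2) -> (d0, d1, d2)>
    linalg.generic {...} %reshaped, %result { ... }
        : memref<5xf32>, memref<5x10x100xf32>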


@@ -228,32 +228,54 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
 // -----
-// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)>
-// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func @broadcast_in_dim_with_expansion
-func @broadcast_in_dim_with_expansion(%operand: memref<5x7x1xf32>,
-                                      %result: memref<7x10x6x4x5xf32>) {
+// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
+func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
+                                           %result: memref<5x10xf32>) {
   "xla_lhlo.broadcast_in_dim"(%operand, %result) {
-    broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
-  } : (memref<5x7x1xf32>, memref<7x10x6x4x5xf32>) -> ()
+    broadcast_dimensions = dense<[0]> : tensor<1xi64>
+  } : (memref<5xf32>, memref<5x10xf32>) -> ()
   return
 }
+// CHECK-NOT: linalg.reshape
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
 // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
 // CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32

 // -----
-// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK-DAG: #[[REASSOCIATION:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0)>
 // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @broadcast_in_dim_scalar
-func @broadcast_in_dim_scalar(%operand: memref<f32>,
-                              %result: memref<7x10x6xf32>) {
+// CHECK-LABEL: func @static_broadcast_in_dim_expansion
+func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
+                                        %result: memref<5x10x100xf32>) {
   "xla_lhlo.broadcast_in_dim"(%operand, %result) {
-    broadcast_dimensions = dense<[]> : tensor<0xi64>
-  } : (memref<f32>, memref<7x10x6xf32>) -> ()
+    broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
+  } : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
   return
 }
+// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]]
+// CHECK-SAME:   memref<1x5xf32> into memref<5xf32>
+// CHECK: linalg.generic {{{.*}}indexing_maps =
+// CHECK-SAME:   [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]]
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
+// CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32
+
+// -----
+// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @static_broadcast_in_dim_scalar
+func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
+                                     %result: memref<5x10xf32>) {
+  "xla_lhlo.broadcast_in_dim"(%operand, %result) {
+    broadcast_dimensions = dense<[]> : tensor<0xi64>
+  } : (memref<f32>, memref<5x10xf32>) -> ()
+  return
+}
+// CHECK-NOT: linalg.reshape
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]]
 // CHECK-NEXT: ^bb0(%[[CONST:.*]]: f32, %[[RESULT:.*]]: f32):
 // CHECK-NEXT:   linalg.yield %[[CONST]] : f32
@@ -262,19 +284,39 @@ func @broadcast_in_dim_scalar(%operand: memref<f32>,
 // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
-func @broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, %result: memref<1x5xf32>) {
+// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
+func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
+                                              %result: memref<1x5xf32>) {
   "xla_lhlo.broadcast_in_dim"(%operand, %result) {
     broadcast_dimensions = dense<[0]> : tensor<1xi64>
   } : (memref<1xf32>, memref<1x5xf32>) -> ()
   return
 }
+// CHECK-NOT: linalg.reshape
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
 // CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
 // CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32

 // -----
+// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
+func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
+                                               %result: memref<5x5xf32>) {
+  "xla_lhlo.broadcast_in_dim"(%operand, %result) {
+    broadcast_dimensions = dense<[1]> : tensor<1xi64>
+  } : (memref<1xf32>, memref<5x5xf32>) -> ()
+  return
+}
+// CHECK-NOT: linalg.reshape
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]]
+// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
+// CHECK-NEXT: ^bb0(%{{.+}}: f32):
+// CHECK-NEXT:   linalg.yield %[[VALUE]] : f32
+
+// -----
 // CHECK-LABEL: func @constant
 func @constant(%value: memref<i32>) {
   "xla_lhlo.constant"(%value) {


@@ -271,16 +271,17 @@ class BroadcastConverter
   }
 };

-template <typename OpTy, bool isLHLO = true>
-class BroadcastInDimConverter
-    : public DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>,
-                                     OpTy, isLHLO> {
+class HloBroadcastInDimConverter
+    : public DataMovementOpConverter<HloBroadcastInDimConverter,
+                                     xla_hlo::BroadcastInDimOp, false> {
  public:
-  using DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>, OpTy,
-                                isLHLO>::DataMovementOpConverter;
+  using DataMovementOpConverter<HloBroadcastInDimConverter,
+                                xla_hlo::BroadcastInDimOp,
+                                false>::DataMovementOpConverter;

-  static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) {
-    auto resultType = getXLAOpResultType<isLHLO>(broadcastOp);
+  static ArrayAttr getIndexingMapsAttr(xla_hlo::BroadcastInDimOp broadcastOp,
+                                       Builder* b) {
+    auto resultType = getXLAOpResultType<false>(broadcastOp);
     auto operandType =
         broadcastOp.operand().getType().template cast<ShapedType>();
     unsigned nloops = resultType.getRank();
@@ -302,8 +303,6 @@ class BroadcastInDimConverter
       int size = broadcastDim.value().getSExtValue();
       bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
                               resultType.getShape()[size] != 1;
-      // TODO(pifon): Add support for args with dynamic shapes for the case
-      // when a dimension of size 1 is broadcasted into dim of size N.
       dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
                                           : b->getAffineDimExpr(size));
     }
@@ -314,6 +313,181 @@ class BroadcastInDimConverter
   }
 };

+class LhloBroadcastInDimConverter
+    : public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
+ public:
+  using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    xla_lhlo::BroadcastInDimOpOperandAdaptor operand_adaptor(args);
+    auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
+    auto result_shape = result_type.getShape();
+
+    auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);
+
+    Value operand = std::get<0>(operand_and_dims);
+    auto broadcast_dims = std::get<1>(operand_and_dims);
+
+    auto loc = op.getLoc();
+    auto nloops = result_type.getRank();
+    auto operand_type = operand.getType().cast<MemRefType>();
+
+    // For a degenerate case, i.e. broadcasting with expansion of
+    // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
+    // Instead the value is loaded and used directly in `linalg.yield`.
+    if (operand_type.getRank() == 1 &&
+        operand_type.getDimSize(0) <
+            result_type.getDimSize(broadcast_dims.front())) {
+      Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+      Value val =
+          rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
+      auto linalgOp = rewriter.create<linalg::GenericOp>(
+          loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
+          rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(1),
+          rewriter.getAffineMapArrayAttr(
+              {rewriter.getMultiDimIdentityMap(nloops)}),
+          GetNParallelLoopsAttrs(nloops, &rewriter),
+          /*doc=*/nullptr, /*library_call=*/nullptr);
+
+      auto* region = &linalgOp.region();
+      auto* block = rewriter.createBlock(region, region->end());
+      block->addArgument(result_type.getElementType());
+
+      rewriter.setInsertionPointToEnd(block);
+      rewriter.create<linalg::YieldOp>(loc, val);
+    } else {
+      ArrayAttr indexingMapsAttr = getIndexingMapsAttr(
+          op, broadcast_dims, result_shape, operand_type, &rewriter);
+
+      OpBuilder::InsertionGuard linalgOpGuard(rewriter);
+      auto linalgOp = rewriter.create<linalg::GenericOp>(
+          loc, llvm::None,
+          llvm::makeArrayRef({operand, operand_adaptor.output()}),
+          rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1),
+          indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter),
+          /*doc=*/nullptr, /*library_call=*/nullptr);
+
+      auto* region = &linalgOp.region();
+      auto* block = rewriter.createBlock(region, region->end());
+      block->addArguments(operand_type.getElementType());
+      block->addArgument(result_type.getElementType());
+
+      rewriter.setInsertionPointToEnd(block);
+      rewriter.create<linalg::YieldOp>(loc, block->getArgument(0));
+    }
+    rewriter.replaceOp(op, llvm::None);
+    return success();
+  }
+
+  // Inserts 'linalg.reshape' if there is a size-1 dim expansion.
+  std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
+      xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const {
+    xla_lhlo::BroadcastInDimOpOperandAdaptor operand_adaptor(args);
+    Value operand = operand_adaptor.operand();
+    auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
+    auto operand_shape = operand_type.getShape();
+
+    Value result = operand_adaptor.output();
+    auto result_type = result.getType().cast<MemRefType>();
+    auto result_shape = result_type.getShape();
+
+    SmallVector<int64_t, 2> operand_strides;
+    int64_t operand_offset;
+    if (failed(getStridesAndOffset(operand_type, operand_strides,
+                                   operand_offset))) {
+      op.emitOpError() << "Failed to get offset and strides.";
+    }
+
+    SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
+    SmallVector<SmallVector<AffineExpr, 2>, 4> collapsed_dims_list;
+    SmallVector<AffineExpr, 2> collapsed_dims;
+    for (const auto& item :
+         enumerate(op.broadcast_dimensions().getIntValues())) {
+      size_t index = item.index();
+      int dim = item.value().getSExtValue();
+
+      collapsed_dims.push_back(rewriter.getAffineDimExpr(index));
+
+      bool expansion_needed =
+          operand_shape[index] == 1 && result_shape[dim] != 1;
+      if (expansion_needed) {
+        continue;
+      }
+      new_shape.push_back(operand_shape[index]);
+      new_strides.push_back(operand_strides[index]);
+      broadcast_dims.push_back(dim);
+
+      collapsed_dims_list.push_back(collapsed_dims);
+      collapsed_dims.clear();
+    }
+    // If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
+    // and all dimensions need expansion. Such memref will be reshaped to a 1D
+    // memref with a single element. New shape and strides needs to be updated
+    // accordingly.
+    if (collapsed_dims_list.empty()) {
+      collapsed_dims_list.push_back({});
+      new_shape.push_back(1);
+      new_strides.push_back(1);
+      broadcast_dims.push_back(0);
+    }
+    for (const auto& dims : collapsed_dims) {
+      collapsed_dims_list.back().push_back(dims);
+    }
+
+    // `linalg.reshape` is inserted only if necessary, i.e. when the rank can
+    // be reduced.
+    if (new_shape.size() < operand_shape.size()) {
+      SmallVector<ArrayRef<AffineExpr>, 4> reassociation_maps;
+      for (const auto& dims : collapsed_dims_list)
+        reassociation_maps.push_back(dims);
+      auto new_memref_type = MemRefType::get(
+          new_shape, operand_type.getElementType(),
+          makeStridedLinearLayoutMap(new_strides, operand_offset,
+                                     rewriter.getContext()));
+      operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
+                                                   operand_adaptor.operand(),
+                                                   reassociation_maps);
+    }
+    return std::make_pair(operand, broadcast_dims);
+  }
+
+  ArrayAttr getIndexingMapsAttr(xla_lhlo::BroadcastInDimOp op,
+                                ArrayRef<int64_t> broadcastDims,
+                                ArrayRef<int64_t> resultShape,
+                                MemRefType operandType, Builder* b) const {
+    unsigned nloops = resultShape.size();
+
+    // The input is a scalar, i.e. this is a scalar broadcast op.
+    if (operandType.getRank() == 0) {
+      return b->getAffineMapArrayAttr(
+          {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
+           b->getMultiDimIdentityMap(nloops)});
+    }
+    auto operandShape = operandType.getShape();
+    SmallVector<AffineExpr, 4> dimExprs;
+    dimExprs.reserve(nloops);
+
+    for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
+      int size = broadcastDim.value();
+      bool expansion_needed =
+          operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
+      if (expansion_needed) {
+        op.emitOpError(
+            "BroadcastInDimOp lowering to Linalg does not support size-1 "
+            "dimensions expansion.");
+      }
+      dimExprs.push_back(b->getAffineDimExpr(size));
+    }
+    return b->getAffineMapArrayAttr(
+        {AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
+         b->getMultiDimIdentityMap(nloops)});
+  }
+};
+
 /// Pattern for the special case where reshape is adding or removing a dimension
 /// of size 1. These can be lowered to a linalg.generic op.
 ///
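
A note on the empty-collapsed_dims_list branch in InsertReshapeIfNecessary above: it covers operands whose dimensions are all size 1. As a hypothetical illustration (shapes chosen for this note, not taken from the tests): broadcasting memref<1x1xf32> into memref<4x8xf32> with broadcast_dimensions = [0, 1] needs expansion for every entry, so the loop keeps nothing; the fallback then reshapes the operand to a rank-1, single-element memref, with all original dimensions in one reassociation group:

    %reshaped = linalg.reshape %operand [affine_map<(d0, d1) -> (d0, d1)>]
        : memref<1x1xf32> into memref<1xf32>

After this, broadcast_dims = [0], and since the operand's single dimension (size 1) is smaller than result dimension 0 (size 4), matchAndRewrite takes the degenerate load-and-yield path exercised by the one-to-many test.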
@@ -639,9 +813,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                            OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
-                   BroadcastInDimConverter<xla_lhlo::BroadcastInDimOp>,
                    ConstConverter,
                    IotaConverter,
+                   LhloBroadcastInDimConverter,
                    PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
                    PointwiseToLinalgConverter<xla_lhlo::AddOp>,
                    PointwiseToLinalgConverter<xla_lhlo::AndOp>,
@@ -742,7 +916,7 @@ namespace xla_hlo {
 void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
   patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
-                   BroadcastInDimConverter<xla_hlo::BroadcastInDimOp, false>,
+                   HloBroadcastInDimConverter,
                    PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
                    PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
                    PointwiseToLinalgConverter<xla_hlo::AndOp, false>,