[XLA][MLIR] Insert linalg.reshape when lowering LHLO BroadcastInDimOp.
If size-1 dimensions require expansion, we insert linalg.reshape to get rid of them. PiperOrigin-RevId: 314918927 Change-Id: I3e9cd380b93421858343ba299e659c0396ca7bcf
This commit is contained in:
parent
079520f252
commit
967782ff2a
@ -228,32 +228,54 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)>
|
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
|
||||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
|
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
// CHECK-LABEL: func @broadcast_in_dim_with_expansion
|
// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
|
||||||
func @broadcast_in_dim_with_expansion(%operand: memref<5x7x1xf32>,
|
func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
|
||||||
%result: memref<7x10x6x4x5xf32>) {
|
%result: memref<5x10xf32>) {
|
||||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||||
broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
|
broadcast_dimensions = dense<[0]> : tensor<1xi64>
|
||||||
} : (memref<5x7x1xf32>, memref<7x10x6x4x5xf32>) -> ()
|
} : (memref<5xf32>, memref<5x10xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// CHECK-NOT: linalg.reshape
|
||||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||||
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||||
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1, d2) -> ()>
|
// CHECK-DAG: #[[REASSOCIATION:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0)>
|
||||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
// CHECK-LABEL: func @broadcast_in_dim_scalar
|
// CHECK-LABEL: func @static_broadcast_in_dim_expansion
|
||||||
func @broadcast_in_dim_scalar(%operand: memref<f32>,
|
func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
|
||||||
%result: memref<7x10x6xf32>) {
|
%result: memref<5x10x100xf32>) {
|
||||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||||
broadcast_dimensions = dense<[]> : tensor<0xi64>
|
broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
|
||||||
} : (memref<f32>, memref<7x10x6xf32>) -> ()
|
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]]
|
||||||
|
// CHECK-SAME: memref<1x5xf32> into memref<5xf32>
|
||||||
|
// CHECK: linalg.generic {{{.*}}indexing_maps =
|
||||||
|
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]]
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||||
|
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1) -> ()>
|
||||||
|
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-LABEL: func @static_broadcast_in_dim_scalar
|
||||||
|
func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
|
||||||
|
%result: memref<5x10xf32>) {
|
||||||
|
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||||
|
broadcast_dimensions = dense<[]> : tensor<0xi64>
|
||||||
|
} : (memref<f32>, memref<5x10xf32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK-NOT: linalg.reshape
|
||||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]]
|
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]]
|
||||||
// CHECK-NEXT: ^bb0(%[[CONST:.*]]: f32, %[[RESULT:.*]]: f32):
|
// CHECK-NEXT: ^bb0(%[[CONST:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||||
// CHECK-NEXT: linalg.yield %[[CONST]] : f32
|
// CHECK-NEXT: linalg.yield %[[CONST]] : f32
|
||||||
@ -262,19 +284,39 @@ func @broadcast_in_dim_scalar(%operand: memref<f32>,
|
|||||||
|
|
||||||
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
|
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
|
||||||
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
|
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
|
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
|
||||||
func @broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, %result: memref<1x5xf32>) {
|
func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
|
||||||
|
%result: memref<1x5xf32>) {
|
||||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||||
broadcast_dimensions = dense<[0]> : tensor<1xi64>
|
broadcast_dimensions = dense<[0]> : tensor<1xi64>
|
||||||
} : (memref<1xf32>, memref<1x5xf32>) -> ()
|
} : (memref<1xf32>, memref<1x5xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// CHECK-NOT: linalg.reshape
|
||||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||||
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
|
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
|
||||||
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
|
||||||
|
func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
|
||||||
|
%result: memref<5x5xf32>) {
|
||||||
|
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||||
|
broadcast_dimensions = dense<[1]> : tensor<1xi64>
|
||||||
|
} : (memref<1xf32>, memref<5x5xf32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK-NOT: linalg.reshape
|
||||||
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
|
// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]]
|
||||||
|
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
|
||||||
|
// CHECK-NEXT: ^bb0(%{{.+}}: f32):
|
||||||
|
// CHECK-NEXT: linalg.yield %[[VALUE]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @constant
|
// CHECK-LABEL: func @constant
|
||||||
func @constant(%value: memref<i32>) {
|
func @constant(%value: memref<i32>) {
|
||||||
"xla_lhlo.constant"(%value) {
|
"xla_lhlo.constant"(%value) {
|
||||||
|
@ -271,16 +271,17 @@ class BroadcastConverter
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename OpTy, bool isLHLO = true>
|
class HloBroadcastInDimConverter
|
||||||
class BroadcastInDimConverter
|
: public DataMovementOpConverter<HloBroadcastInDimConverter,
|
||||||
: public DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>,
|
xla_hlo::BroadcastInDimOp, false> {
|
||||||
OpTy, isLHLO> {
|
|
||||||
public:
|
public:
|
||||||
using DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>, OpTy,
|
using DataMovementOpConverter<HloBroadcastInDimConverter,
|
||||||
isLHLO>::DataMovementOpConverter;
|
xla_hlo::BroadcastInDimOp,
|
||||||
|
false>::DataMovementOpConverter;
|
||||||
|
|
||||||
static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) {
|
static ArrayAttr getIndexingMapsAttr(xla_hlo::BroadcastInDimOp broadcastOp,
|
||||||
auto resultType = getXLAOpResultType<isLHLO>(broadcastOp);
|
Builder* b) {
|
||||||
|
auto resultType = getXLAOpResultType<false>(broadcastOp);
|
||||||
auto operandType =
|
auto operandType =
|
||||||
broadcastOp.operand().getType().template cast<ShapedType>();
|
broadcastOp.operand().getType().template cast<ShapedType>();
|
||||||
unsigned nloops = resultType.getRank();
|
unsigned nloops = resultType.getRank();
|
||||||
@ -302,8 +303,6 @@ class BroadcastInDimConverter
|
|||||||
int size = broadcastDim.value().getSExtValue();
|
int size = broadcastDim.value().getSExtValue();
|
||||||
bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
|
bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
|
||||||
resultType.getShape()[size] != 1;
|
resultType.getShape()[size] != 1;
|
||||||
// TODO(pifon): Add support for args with dynamic shapes for the case
|
|
||||||
// when a dimension of size 1 is broadcasted into dim of size N.
|
|
||||||
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
|
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
|
||||||
: b->getAffineDimExpr(size));
|
: b->getAffineDimExpr(size));
|
||||||
}
|
}
|
||||||
@ -314,6 +313,181 @@ class BroadcastInDimConverter
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class LhloBroadcastInDimConverter
|
||||||
|
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
xla_lhlo::BroadcastInDimOpOperandAdaptor operand_adaptor(args);
|
||||||
|
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
|
||||||
|
auto result_shape = result_type.getShape();
|
||||||
|
|
||||||
|
auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);
|
||||||
|
|
||||||
|
Value operand = std::get<0>(operand_and_dims);
|
||||||
|
auto broadcast_dims = std::get<1>(operand_and_dims);
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto nloops = result_type.getRank();
|
||||||
|
auto operand_type = operand.getType().cast<MemRefType>();
|
||||||
|
|
||||||
|
// For a degenerate case, i.e. broadcasting with expansion of
|
||||||
|
// memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
|
||||||
|
// Instead the value is loaded and used directly in `linalg.yield`.
|
||||||
|
if (operand_type.getRank() == 1 &&
|
||||||
|
operand_type.getDimSize(0) <
|
||||||
|
result_type.getDimSize(broadcast_dims.front())) {
|
||||||
|
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value val =
|
||||||
|
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
|
||||||
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
|
||||||
|
rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(1),
|
||||||
|
rewriter.getAffineMapArrayAttr(
|
||||||
|
{rewriter.getMultiDimIdentityMap(nloops)}),
|
||||||
|
GetNParallelLoopsAttrs(nloops, &rewriter),
|
||||||
|
/*doc=*/nullptr, /*library_call=*/nullptr);
|
||||||
|
|
||||||
|
auto* region = &linalgOp.region();
|
||||||
|
auto* block = rewriter.createBlock(region, region->end());
|
||||||
|
block->addArgument(result_type.getElementType());
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(block);
|
||||||
|
rewriter.create<linalg::YieldOp>(loc, val);
|
||||||
|
} else {
|
||||||
|
ArrayAttr indexingMapsAttr = getIndexingMapsAttr(
|
||||||
|
op, broadcast_dims, result_shape, operand_type, &rewriter);
|
||||||
|
|
||||||
|
OpBuilder::InsertionGuard linalgOpGuard(rewriter);
|
||||||
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, llvm::None,
|
||||||
|
llvm::makeArrayRef({operand, operand_adaptor.output()}),
|
||||||
|
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1),
|
||||||
|
indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter),
|
||||||
|
/*doc=*/nullptr, /*library_call=*/nullptr);
|
||||||
|
|
||||||
|
auto* region = &linalgOp.region();
|
||||||
|
auto* block = rewriter.createBlock(region, region->end());
|
||||||
|
block->addArguments(operand_type.getElementType());
|
||||||
|
block->addArgument(result_type.getElementType());
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(block);
|
||||||
|
rewriter.create<linalg::YieldOp>(loc, block->getArgument(0));
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, llvm::None);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
|
||||||
|
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
|
||||||
|
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
xla_lhlo::BroadcastInDimOpOperandAdaptor operand_adaptor(args);
|
||||||
|
Value operand = operand_adaptor.operand();
|
||||||
|
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
|
||||||
|
auto operand_shape = operand_type.getShape();
|
||||||
|
|
||||||
|
Value result = operand_adaptor.output();
|
||||||
|
auto result_type = result.getType().cast<MemRefType>();
|
||||||
|
auto result_shape = result_type.getShape();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> operand_strides;
|
||||||
|
int64_t operand_offset;
|
||||||
|
if (failed(getStridesAndOffset(operand_type, operand_strides,
|
||||||
|
operand_offset))) {
|
||||||
|
op.emitOpError() << "Failed to get offset and strides.";
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
|
||||||
|
SmallVector<SmallVector<AffineExpr, 2>, 4> collapsed_dims_list;
|
||||||
|
SmallVector<AffineExpr, 2> collapsed_dims;
|
||||||
|
for (const auto& item :
|
||||||
|
enumerate(op.broadcast_dimensions().getIntValues())) {
|
||||||
|
size_t index = item.index();
|
||||||
|
int dim = item.value().getSExtValue();
|
||||||
|
|
||||||
|
collapsed_dims.push_back(rewriter.getAffineDimExpr(index));
|
||||||
|
|
||||||
|
bool expansion_needed =
|
||||||
|
operand_shape[index] == 1 && result_shape[dim] != 1;
|
||||||
|
if (expansion_needed) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
new_shape.push_back(operand_shape[index]);
|
||||||
|
new_strides.push_back(operand_strides[index]);
|
||||||
|
broadcast_dims.push_back(dim);
|
||||||
|
|
||||||
|
collapsed_dims_list.push_back(collapsed_dims);
|
||||||
|
collapsed_dims.clear();
|
||||||
|
}
|
||||||
|
// If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
|
||||||
|
// and all dimensions need expansion. Such memref will be reshaped to a 1D
|
||||||
|
// memref with a single element. New shape and strides needs to be updated
|
||||||
|
// accordingly.
|
||||||
|
if (collapsed_dims_list.empty()) {
|
||||||
|
collapsed_dims_list.push_back({});
|
||||||
|
new_shape.push_back(1);
|
||||||
|
new_strides.push_back(1);
|
||||||
|
broadcast_dims.push_back(0);
|
||||||
|
}
|
||||||
|
for (const auto& dims : collapsed_dims) {
|
||||||
|
collapsed_dims_list.back().push_back(dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
|
||||||
|
// reduced.
|
||||||
|
if (new_shape.size() < operand_shape.size()) {
|
||||||
|
SmallVector<ArrayRef<AffineExpr>, 4> reassociation_maps;
|
||||||
|
for (const auto& dims : collapsed_dims_list)
|
||||||
|
reassociation_maps.push_back(dims);
|
||||||
|
auto new_memref_type = MemRefType::get(
|
||||||
|
new_shape, operand_type.getElementType(),
|
||||||
|
makeStridedLinearLayoutMap(new_strides, operand_offset,
|
||||||
|
rewriter.getContext()));
|
||||||
|
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
|
||||||
|
operand_adaptor.operand(),
|
||||||
|
reassociation_maps);
|
||||||
|
}
|
||||||
|
return std::make_pair(operand, broadcast_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayAttr getIndexingMapsAttr(xla_lhlo::BroadcastInDimOp op,
|
||||||
|
ArrayRef<int64_t> broadcastDims,
|
||||||
|
ArrayRef<int64_t> resultShape,
|
||||||
|
MemRefType operandType, Builder* b) const {
|
||||||
|
unsigned nloops = resultShape.size();
|
||||||
|
|
||||||
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
||||||
|
if (operandType.getRank() == 0) {
|
||||||
|
return b->getAffineMapArrayAttr(
|
||||||
|
{AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
|
||||||
|
b->getMultiDimIdentityMap(nloops)});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto operandShape = operandType.getShape();
|
||||||
|
SmallVector<AffineExpr, 4> dimExprs;
|
||||||
|
dimExprs.reserve(nloops);
|
||||||
|
|
||||||
|
for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
|
||||||
|
int size = broadcastDim.value();
|
||||||
|
bool expansion_needed =
|
||||||
|
operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
|
||||||
|
if (expansion_needed) {
|
||||||
|
op.emitOpError(
|
||||||
|
"BroadcastInDimOp lowering to Linalg does not support size-1 "
|
||||||
|
"dimensions expansion.");
|
||||||
|
}
|
||||||
|
dimExprs.push_back(b->getAffineDimExpr(size));
|
||||||
|
}
|
||||||
|
return b->getAffineMapArrayAttr(
|
||||||
|
{AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
|
||||||
|
b->getMultiDimIdentityMap(nloops)});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Pattern for the special case where reshape is adding or removing a dimension
|
/// Pattern for the special case where reshape is adding or removing a dimension
|
||||||
/// of size 1. These can be lowered to a linalg.generic op.
|
/// of size 1. These can be lowered to a linalg.generic op.
|
||||||
///
|
///
|
||||||
@ -639,9 +813,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
|
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
|
||||||
BroadcastInDimConverter<xla_lhlo::BroadcastInDimOp>,
|
|
||||||
ConstConverter,
|
ConstConverter,
|
||||||
IotaConverter,
|
IotaConverter,
|
||||||
|
LhloBroadcastInDimConverter,
|
||||||
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
|
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
|
||||||
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
|
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
|
||||||
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
|
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
|
||||||
@ -742,7 +916,7 @@ namespace xla_hlo {
|
|||||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
|
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
|
||||||
BroadcastInDimConverter<xla_hlo::BroadcastInDimOp, false>,
|
HloBroadcastInDimConverter,
|
||||||
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
|
||||||
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
|
||||||
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
|
||||||
|
Loading…
Reference in New Issue
Block a user