[MLIR][XLA] Refactor lhlo_legalize_to_parallel_loops.
Extract the logic that applies the LHLO IR of a block with memref arguments to Values whose types are the element types of the respective input memrefs. PiperOrigin-RevId: 304992790 Change-Id: I23ec0c13a300527a9544e900df3a506542deedd2
This commit is contained in:
parent
e853fd96f1
commit
9fbe2a4de4
@ -31,11 +31,12 @@ func @reduce(%arg: memref<100x10x5xf32>,
|
||||
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: loop.yield
|
||||
@ -71,11 +72,12 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
|
||||
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: loop.reduce.return [[ACC_RESULT]]
|
||||
// CHECK: }
|
||||
// CHECK: loop.yield
|
||||
@ -114,11 +116,12 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
|
||||
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: loop.yield
|
||||
@ -185,11 +188,12 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
||||
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: loop.yield
|
||||
|
@ -29,38 +29,49 @@ namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace {
|
||||
|
||||
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
|
||||
// single output buffer to make it compatible with `operands` that have element
|
||||
// types of the respective buffers. Returns the computed value.
|
||||
//
|
||||
// Example. For `operands` with (f32, i32) types and a block with LHLO ops and
|
||||
// with signature:
|
||||
// ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
|
||||
// <LHLO_ops>
|
||||
//
|
||||
// inserts the necessary alloc and store ops to compute and return a result of
|
||||
// `i1` type.
//
// All new ops are created through `b`. One buffer is allocated per argument
// of `lhlo_block`, `operands[i]` is stored into the i-th buffer, the ops of
// the block (without its terminator) are cloned against those buffers, and
// the returned value is loaded from the last buffer. No dealloc op is created
// by this function.
// NOTE(review): presumably `operands` has exactly one entry fewer than
// `lhlo_block` has arguments, so that the last buffer is output-only; nothing
// here checks that -- confirm against callers.
|
||||
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
|
||||
Block* lhlo_block, OpBuilder* b) {
|
||||
// One buffer per block argument, inputs and output alike. Every argument
// type is cast to MemRefType, so all block arguments must be memrefs.
SmallVector<Value, 2> arg_bufs;
|
||||
for (auto arg_type : lhlo_block->getArgumentTypes()) {
|
||||
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
|
||||
}
|
||||
// Store each operand value into the buffer at the same position.
for (auto operand : llvm::enumerate(operands)) {
|
||||
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
|
||||
}
|
||||
// Clone the ops from `lhlo_block`, skipping its terminator, with the block
// arguments remapped to the buffers allocated above.
|
||||
BlockAndValueMapping mapping;
|
||||
mapping.map(lhlo_block->getArguments(), arg_bufs);
|
||||
for (auto& nested : lhlo_block->without_terminator()) {
|
||||
auto clone = b->clone(nested, mapping);
|
||||
// Map the original op's results to the clone's results for later lookups.
mapping.map(nested.getResults(), clone->getResults());
|
||||
}
|
||||
// The buffer of the last block argument holds the result; read it back.
return b->create<LoadOp>(loc, arg_bufs.back());
|
||||
}
|
||||
|
||||
// Converts a block with LHLO ops and with signature:
|
||||
// ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
// into a reduction operator of loop.reduce by doing buffer allocation for
|
||||
// scalar arguments and the result of `loop.reduce` to make it compatible with
|
||||
// LHLO ops.
|
||||
void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
|
||||
Block* lhlo_block,
|
||||
ConversionPatternRewriter* rewriter) {
|
||||
Block* lhlo_block, OpBuilder* b) {
|
||||
Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
|
||||
rewriter->setInsertionPointToStart(&loop_reduce_op_body);
|
||||
|
||||
// Allocate buffers to hold arguments of reduction operator block to stay
|
||||
// compatible with the LHLO dialect ops in the reduction body.
|
||||
Value elem_arg = lhlo_block->getArgument(0);
|
||||
Value elem_buf =
|
||||
rewriter->create<AllocOp>(loc, elem_arg.getType().cast<MemRefType>());
|
||||
rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(0), elem_buf);
|
||||
Value acc_arg = lhlo_block->getArgument(1);
|
||||
Value acc_buf =
|
||||
rewriter->create<AllocOp>(loc, acc_arg.getType().cast<MemRefType>());
|
||||
rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(1), acc_buf);
|
||||
|
||||
// Clone the ops from `xla_lhlo.reduce` into reduction operator block.
|
||||
BlockAndValueMapping mapping;
|
||||
mapping.map(lhlo_block->getArguments(),
|
||||
ValueRange{elem_buf, acc_buf, acc_buf});
|
||||
for (auto& nested : lhlo_block->without_terminator()) {
|
||||
auto clone = rewriter->clone(nested, mapping);
|
||||
mapping.map(nested.getResults(), clone->getResults());
|
||||
}
|
||||
Value acc_result = rewriter->create<LoadOp>(loc, acc_buf);
|
||||
rewriter->create<loop::ReduceReturnOp>(loc, acc_result);
|
||||
OpBuilder::InsertionGuard guard(*b);
|
||||
b->setInsertionPointToStart(&loop_reduce_op_body);
|
||||
b->create<loop::ReduceReturnOp>(
|
||||
loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
|
||||
lhlo_block, b));
|
||||
}
|
||||
|
||||
// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to
|
||||
@ -79,9 +90,8 @@ struct MappedIvs {
|
||||
SmallVector<Value, 2> ivs;
|
||||
};
|
||||
|
||||
MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ArrayRef<BlockArgument> ivs,
|
||||
ArrayRef<BlockArgument> window_ivs,
|
||||
OpBuilder* b) {
|
||||
MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs,
|
||||
ValueRange window_ivs, OpBuilder* b) {
|
||||
MappedIvs mapped_ivs;
|
||||
|
||||
if (!op.window_strides().hasValue()) {
|
||||
@ -256,7 +266,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
SmallVector<Value, 1> out_indices;
|
||||
if (outer != nullptr) {
|
||||
out_indices.reserve(outer.getNumLoops());
|
||||
for (auto& iv : outer.getInductionVars()) {
|
||||
for (Value iv : outer.getInductionVars()) {
|
||||
out_indices.push_back(iv);
|
||||
}
|
||||
} else {
|
||||
@ -268,12 +278,16 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
// Load the element to reduce.
|
||||
SmallVector<Value, 2> indices;
|
||||
indices.reserve(operand_shape.size());
|
||||
Block::args_iterator outer_ivs_it =
|
||||
outer ? outer.getInductionVars().begin() : nullptr;
|
||||
Block::args_iterator inner_ivs_it = inner.getInductionVars().begin();
|
||||
for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
|
||||
indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
|
||||
: *outer_ivs_it++);
|
||||
|
||||
if (outer) {
|
||||
auto inner_ivs_it = inner.getInductionVars().begin();
|
||||
auto outer_ivs_it = outer.getInductionVars().begin();
|
||||
for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
|
||||
indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
|
||||
: *outer_ivs_it++);
|
||||
}
|
||||
} else {
|
||||
indices = ValueRange(inner.getInductionVars());
|
||||
}
|
||||
|
||||
rewriter->setInsertionPointToStart(inner.getBody());
|
||||
@ -395,9 +409,8 @@ class ReduceWindowOpConverter
|
||||
|
||||
Value reduction_result = *window_loop.getResults().begin();
|
||||
auto output_ivs = output_loop.getInductionVars();
|
||||
rewriter->create<StoreOp>(
|
||||
loc, reduction_result, xla_output,
|
||||
llvm::makeArrayRef(output_ivs.begin(), output_ivs.end()));
|
||||
rewriter->create<StoreOp>(loc, reduction_result, xla_output,
|
||||
ValueRange{output_ivs});
|
||||
return std::make_pair(output_loop, window_loop);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user