[MLIR][XLA] Refactor lhlo_legalize_to_parallel_loops.

Extract the logic to apply lhlo IR in a block with memref arguments to the Values with element types of the respective input memrefs.

PiperOrigin-RevId: 304992790
Change-Id: I23ec0c13a300527a9544e900df3a506542deedd2
This commit is contained in:
Alexander Belyaev 2020-04-06 03:31:55 -07:00 committed by TensorFlower Gardener
parent e853fd96f1
commit 9fbe2a4de4
2 changed files with 67 additions and 50 deletions

View File

@ -31,11 +31,12 @@ func @reduce(%arg: memref<100x10x5xf32>,
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: loop.yield
@ -71,11 +72,12 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]]
// CHECK: }
// CHECK: loop.yield
@ -114,11 +116,12 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: loop.yield
@ -185,11 +188,12 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: loop.yield

View File

@ -29,38 +29,49 @@ namespace mlir {
namespace xla_lhlo {
namespace {
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
// single output buffer to make it compatible with `operands` that have element
// types of the respective buffers. Returns the computed value.
//
// Example. For `operands` with (f32, i32) types and a block with LHLO ops and
// with signature:
// ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
// <LHLO_ops>
//
// inserts necessary alloc and store ops to compute and return result that has
// `i1` type.
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
Block* lhlo_block, OpBuilder* b) {
SmallVector<Value, 2> arg_bufs;
for (auto arg_type : lhlo_block->getArgumentTypes()) {
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
}
for (auto operand : llvm::enumerate(operands)) {
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
}
// Clone the ops from `lhlo_block`.
BlockAndValueMapping mapping;
mapping.map(lhlo_block->getArguments(), arg_bufs);
for (auto& nested : lhlo_block->without_terminator()) {
auto clone = b->clone(nested, mapping);
mapping.map(nested.getResults(), clone->getResults());
}
return b->create<LoadOp>(loc, arg_bufs.back());
}
// Converts a block with LHLO ops and with signature:
// ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// into a reduction operator of loop.reduce by doing buffer allocation for
// scalar arguments and the result of `loop.reduce` to make it compatible with
// LHLO ops.
void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
Block* lhlo_block,
ConversionPatternRewriter* rewriter) {
Block* lhlo_block, OpBuilder* b) {
Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
rewriter->setInsertionPointToStart(&loop_reduce_op_body);
// Allocate buffers to hold arguments of reduction operator block to stay
// compatible with the LHLO dialect ops in the reduction body.
Value elem_arg = lhlo_block->getArgument(0);
Value elem_buf =
rewriter->create<AllocOp>(loc, elem_arg.getType().cast<MemRefType>());
rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(0), elem_buf);
Value acc_arg = lhlo_block->getArgument(1);
Value acc_buf =
rewriter->create<AllocOp>(loc, acc_arg.getType().cast<MemRefType>());
rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(1), acc_buf);
// Clone the ops from `xla_lhlo.reduce` into reduction operator block.
BlockAndValueMapping mapping;
mapping.map(lhlo_block->getArguments(),
ValueRange{elem_buf, acc_buf, acc_buf});
for (auto& nested : lhlo_block->without_terminator()) {
auto clone = rewriter->clone(nested, mapping);
mapping.map(nested.getResults(), clone->getResults());
}
Value acc_result = rewriter->create<LoadOp>(loc, acc_buf);
rewriter->create<loop::ReduceReturnOp>(loc, acc_result);
OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(&loop_reduce_op_body);
b->create<loop::ReduceReturnOp>(
loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
lhlo_block, b));
}
// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to
@ -79,9 +90,8 @@ struct MappedIvs {
SmallVector<Value, 2> ivs;
};
MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ArrayRef<BlockArgument> ivs,
ArrayRef<BlockArgument> window_ivs,
OpBuilder* b) {
MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs,
ValueRange window_ivs, OpBuilder* b) {
MappedIvs mapped_ivs;
if (!op.window_strides().hasValue()) {
@ -256,7 +266,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
SmallVector<Value, 1> out_indices;
if (outer != nullptr) {
out_indices.reserve(outer.getNumLoops());
for (auto& iv : outer.getInductionVars()) {
for (Value iv : outer.getInductionVars()) {
out_indices.push_back(iv);
}
} else {
@ -268,12 +278,16 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// Load the element to reduce.
SmallVector<Value, 2> indices;
indices.reserve(operand_shape.size());
Block::args_iterator outer_ivs_it =
outer ? outer.getInductionVars().begin() : nullptr;
Block::args_iterator inner_ivs_it = inner.getInductionVars().begin();
for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
: *outer_ivs_it++);
if (outer) {
auto inner_ivs_it = inner.getInductionVars().begin();
auto outer_ivs_it = outer.getInductionVars().begin();
for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
: *outer_ivs_it++);
}
} else {
indices = ValueRange(inner.getInductionVars());
}
rewriter->setInsertionPointToStart(inner.getBody());
@ -395,9 +409,8 @@ class ReduceWindowOpConverter
Value reduction_result = *window_loop.getResults().begin();
auto output_ivs = output_loop.getInductionVars();
rewriter->create<StoreOp>(
loc, reduction_result, xla_output,
llvm::makeArrayRef(output_ivs.begin(), output_ivs.end()));
rewriter->create<StoreOp>(loc, reduction_result, xla_output,
ValueRange{output_ivs});
return std::make_pair(output_loop, window_loop);
}