[XLA][MLIR] Lower xla_lhlo.ReduceWindowOp to parallel loops.
PiperOrigin-RevId: 302892920 Change-Id: Idefaa2f01979748cb30c80820823373851393b34
This commit is contained in:
parent
b857f0aaec
commit
a8aa33c45e
|
@ -125,3 +125,77 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
||||||
// CHECK: loop.yield
|
// CHECK: loop.yield
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @reduce_window(%arg: memref<112x112xf32>,
|
||||||
|
%init: memref<f32>,
|
||||||
|
%result: memref<56x56xf32>) {
|
||||||
|
"xla_lhlo.reduce_window"(%arg, %init, %result) ( {
|
||||||
|
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||||
|
"xla_lhlo.maximum"(%lhs, %rhs, %res)
|
||||||
|
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
"xla_lhlo.terminator"() : () -> ()
|
||||||
|
}) {
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 2]> : tensor<2xi64>
|
||||||
|
} : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @reduce_window(
|
||||||
|
// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>,
|
||||||
|
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
|
||||||
|
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) {
|
||||||
|
// CHECK-DAG: [[IN_BOUNDS:%.*]] = constant 1 : i1
|
||||||
|
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
|
||||||
|
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
|
||||||
|
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
|
||||||
|
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
|
||||||
|
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
|
||||||
|
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
|
||||||
|
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
|
||||||
|
// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
||||||
|
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
|
||||||
|
// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel
|
||||||
|
// CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]])
|
||||||
|
// CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]])
|
||||||
|
// CHECK-SAME: init ([[INIT]]) -> f32 {
|
||||||
|
|
||||||
|
// CHECK: [[START_I:%.*]] = muli [[I]], [[C2]] : index
|
||||||
|
// CHECK: [[OFFSET_I:%.*]] = subi [[IW]], [[C0]] : index
|
||||||
|
// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index
|
||||||
|
// CHECK: [[INDEX_I_FITS:%.*]] = cmpi "ult", [[INDEX_I]], [[C112]]
|
||||||
|
// CHECK: [[IN_BOUNDS_0:%.*]] = and [[INDEX_I_FITS]], [[IN_BOUNDS]]
|
||||||
|
|
||||||
|
// CHECK: [[START_J:%.*]] = muli [[J]], [[C2]] : index
|
||||||
|
// CHECK: [[OFFSET_J:%.*]] = subi [[JW]], [[C0]] : index
|
||||||
|
// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index
|
||||||
|
// CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]]
|
||||||
|
// CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]]
|
||||||
|
|
||||||
|
// CHECK: [[ELEM_TO_REDUCE:%.*]] = loop.if [[IN_BOUNDS_1]] -> (f32) {
|
||||||
|
// CHECK: [[OPERAND_ELEM:%.*]] =
|
||||||
|
// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
|
||||||
|
// CHECK: loop.yield [[OPERAND_ELEM]] : f32
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: loop.yield [[INIT]] : f32
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||||
|
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||||
|
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||||
|
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||||
|
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||||
|
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||||
|
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||||
|
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||||
|
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: loop.yield
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
|
||||||
|
// CHECK: loop.yield
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: return
|
||||||
|
// CHECK: }
|
||||||
|
|
|
@ -29,6 +29,50 @@ namespace mlir {
|
||||||
namespace xla_lhlo {
|
namespace xla_lhlo {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Converts a block with LHLO ops and with signature:
|
||||||
|
// ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||||
|
// into a reduction operator of loop.reduce by doing buffer allocation for
|
||||||
|
// scalar arguments and the result of `loop.reduce` to make it compatible with
|
||||||
|
// LHLO ops.
|
||||||
|
void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
|
||||||
|
Block* lhlo_block,
|
||||||
|
ConversionPatternRewriter* rewriter) {
|
||||||
|
Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
|
||||||
|
rewriter->setInsertionPointToStart(&loop_reduce_op_body);
|
||||||
|
|
||||||
|
// Allocate buffers to hold arguments of reduction operator block to stay
|
||||||
|
// compatible with the LHLO dialect ops in the reduction body.
|
||||||
|
Value elem_arg = lhlo_block->getArgument(0);
|
||||||
|
Value elem_buf =
|
||||||
|
rewriter->create<AllocOp>(loc, elem_arg.getType().cast<MemRefType>());
|
||||||
|
rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(0), elem_buf);
|
||||||
|
Value acc_arg = lhlo_block->getArgument(1);
|
||||||
|
Value acc_buf =
|
||||||
|
rewriter->create<AllocOp>(loc, acc_arg.getType().cast<MemRefType>());
|
||||||
|
rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(1), acc_buf);
|
||||||
|
|
||||||
|
// Clone the ops from `xla_lhlo.reduce` into reduction operator block.
|
||||||
|
BlockAndValueMapping mapping;
|
||||||
|
mapping.map(lhlo_block->getArguments(),
|
||||||
|
ValueRange{elem_buf, acc_buf, acc_buf});
|
||||||
|
for (auto& nested : lhlo_block->without_terminator()) {
|
||||||
|
auto clone = rewriter->clone(nested, mapping);
|
||||||
|
mapping.map(nested.getResults(), clone->getResults());
|
||||||
|
}
|
||||||
|
Value acc_result = rewriter->create<LoadOp>(loc, acc_buf);
|
||||||
|
rewriter->create<loop::ReduceReturnOp>(loc, acc_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to
|
||||||
|
// extract dimension at runtime.
|
||||||
|
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
|
||||||
|
size_t dim_index, int64_t dim,
|
||||||
|
ConversionPatternRewriter* rewriter) {
|
||||||
|
return dim == ShapedType::kDynamicSize
|
||||||
|
? rewriter->create<DimOp>(loc, shaped_value, dim_index).getResult()
|
||||||
|
: rewriter->create<ConstantIndexOp>(loc, dim);
|
||||||
|
}
|
||||||
|
|
||||||
// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp.
|
// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp.
|
||||||
// The outper `ParallelOp` refers to the parallel loops if there are
|
// The outper `ParallelOp` refers to the parallel loops if there are
|
||||||
// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp`
|
// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp`
|
||||||
|
@ -42,7 +86,7 @@ namespace {
|
||||||
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
|
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
|
||||||
// : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
|
// : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
|
||||||
//
|
//
|
||||||
// is converted into:
|
// is roughly converted into:
|
||||||
//
|
//
|
||||||
// %init = load %init_buf[] : memref<f32>
|
// %init = load %init_buf[] : memref<f32>
|
||||||
// loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
|
// loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
|
||||||
|
@ -67,15 +111,15 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||||
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;
|
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> args,
|
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
// TODO(b/137624192) Implement variadic reduce.
|
// TODO(b/137624192) Implement variadic reduce.
|
||||||
if (xla_reduce_op.out().size() != 1) return failure();
|
if (xla_reduce_op.out().size() != 1) return failure();
|
||||||
|
|
||||||
loop::ReduceOp reduce_op =
|
loop::ReduceOp reduce_op =
|
||||||
CreateParallelLoopsWithReduceOp(xla_reduce_op, args, &rewriter);
|
CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
|
||||||
ConvertReductionOperator(xla_reduce_op,
|
ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
|
||||||
&reduce_op.reductionOperator().front(), &rewriter);
|
&xla_reduce_op.body().front(), &rewriter);
|
||||||
rewriter.replaceOp(xla_reduce_op, llvm::None);
|
rewriter.replaceOp(xla_reduce_op, llvm::None);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -100,8 +144,8 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||||
// } : f32
|
// } : f32
|
||||||
// loop.yield
|
// loop.yield
|
||||||
// }
|
// }
|
||||||
loop::ReduceOp CreateParallelLoopsWithReduceOp(
|
loop::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||||
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> args,
|
xla_lhlo::ReduceOp xla_reduce_op,
|
||||||
ConversionPatternRewriter* rewriter) const {
|
ConversionPatternRewriter* rewriter) const {
|
||||||
auto loc = xla_reduce_op.getLoc();
|
auto loc = xla_reduce_op.getLoc();
|
||||||
DenseSet<int> reducing_dims;
|
DenseSet<int> reducing_dims;
|
||||||
|
@ -114,20 +158,13 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||||
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
|
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
|
||||||
SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
|
SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
|
||||||
auto operand_shape = operand.getType().cast<MemRefType>().getShape();
|
auto operand_shape = operand.getType().cast<MemRefType>().getShape();
|
||||||
Type index_type = rewriter->getIndexType();
|
|
||||||
for (auto dim : llvm::enumerate(operand_shape)) {
|
for (auto dim : llvm::enumerate(operand_shape)) {
|
||||||
const bool is_reducing_dim = reducing_dims.count(dim.index());
|
const bool is_reducing_dim = reducing_dims.count(dim.index());
|
||||||
|
|
||||||
Value ub =
|
Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
|
||||||
dim.value() == ShapedType::kDynamicSize
|
rewriter);
|
||||||
? rewriter->create<DimOp>(loc, operand, dim.index()).getResult()
|
Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
|
||||||
: rewriter->create<mlir::ConstantOp>(
|
Value step = rewriter->create<ConstantIndexOp>(loc, 1);
|
||||||
loc, index_type,
|
|
||||||
rewriter->getIntegerAttr(index_type, dim.value()));
|
|
||||||
Value lb = rewriter->create<mlir::ConstantOp>(
|
|
||||||
loc, index_type, rewriter->getIntegerAttr(index_type, 0));
|
|
||||||
Value step = rewriter->create<mlir::ConstantOp>(
|
|
||||||
loc, index_type, rewriter->getIntegerAttr(index_type, 1));
|
|
||||||
(is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
|
(is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
|
||||||
(is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
|
(is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
|
||||||
(is_reducing_dim ? reduce_step : parallel_step).push_back(step);
|
(is_reducing_dim ? reduce_step : parallel_step).push_back(step);
|
||||||
|
@ -153,8 +190,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||||
out_indices.push_back(iv);
|
out_indices.push_back(iv);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
out_indices.push_back(rewriter->create<mlir::ConstantOp>(
|
out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
|
||||||
loc, index_type, rewriter->getIntegerAttr(index_type, 0)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
|
rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
|
||||||
|
@ -175,39 +211,209 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||||
loc, *xla_reduce_op.operands().begin(), indices);
|
loc, *xla_reduce_op.operands().begin(), indices);
|
||||||
return rewriter->create<loop::ReduceOp>(loc, elem);
|
return rewriter->create<loop::ReduceOp>(loc, elem);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Converts `xla_lhlo.reduce` reduction operator into `loop.reduce` op by
|
// Pseudocode:
|
||||||
// doing buffer allocation for scalar arguments and the result of
|
// for each index O in output
|
||||||
// `loop.reduce` to make it compatible with LHLO ops.
|
// accumulator = neutral_value
|
||||||
void ConvertReductionOperator(xla_lhlo::ReduceOp xla_reduce_op,
|
// in_bounds = true
|
||||||
Block* loop_reduce_op_body,
|
// for each index W in window
|
||||||
ConversionPatternRewriter* rewriter) const {
|
// for each dimension i from 0 to rank - 1
|
||||||
rewriter->setInsertionPointToStart(loop_reduce_op_body);
|
// index = O[i] * stride[i] + W[i] - pad_low[i]
|
||||||
|
// in_bounds = inbounds && (index `ult` shape[i])
|
||||||
|
// I[i] = index
|
||||||
|
// if (in_bounds)
|
||||||
|
// value = input[I]
|
||||||
|
// else
|
||||||
|
// value = neutral_value
|
||||||
|
// accumulator = reduction_operator(output[O], value)
|
||||||
|
// output[O] = accumulator
|
||||||
|
//
|
||||||
|
// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a
|
||||||
|
// loop::ReduceOp.
|
||||||
|
// The outper `ParallelOp` refers to the parallel loops that traverese output
|
||||||
|
// buffer. The inner `ParalleOp` refers to the reduction loops that traverse
|
||||||
|
// reduction windows and `ReduceOp` contains the reduction operator.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// func @reduce_window(%arg: memref<112x112xf32>,
|
||||||
|
// %init: memref<f32>,
|
||||||
|
// %result: memref<56x56xf32>) {
|
||||||
|
// "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
|
||||||
|
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||||
|
// "xla_lhlo.maximum"(%lhs, %rhs, %res)
|
||||||
|
// : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
// "xla_lhlo.terminator"() : () -> ()
|
||||||
|
// }) {
|
||||||
|
// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
// window_dimensions = dense<[3, 3]> : tensor<2xi64>,
|
||||||
|
// window_strides = dense<[2, 2]> : tensor<2xi64>
|
||||||
|
// } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// is roughly converted into:
|
||||||
|
//
|
||||||
|
// %neutral_elem = load %init_buf[] : memref<f32>
|
||||||
|
// loop.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
|
||||||
|
// %result = loop.parallel (%iw, %jw) = (%c0, %c0)
|
||||||
|
// to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 {
|
||||||
|
// %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
|
||||||
|
// %elem = load %operand[%computed_i, %computed_j]
|
||||||
|
// %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
|
||||||
|
// loop.reduce(%elem_to_reduce) : f32 {
|
||||||
|
// ^bb0(%arg7: f32, %arg8: f32):
|
||||||
|
// <LHLO ops>
|
||||||
|
// }
|
||||||
|
// loop.yield
|
||||||
|
// }
|
||||||
|
// store %result, %output_buffer[%i, %j] : memref<56x56xf32>
|
||||||
|
// loop.yield
|
||||||
|
// }
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
class ReduceWindowOpConverter
|
||||||
|
: public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;
|
||||||
|
|
||||||
// Allocate buffers to hold arguments of reduction operator block to stay
|
LogicalResult matchAndRewrite(
|
||||||
// compatible with the LHLO dialect ops in the reduction body.
|
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
|
||||||
auto loc = xla_reduce_op.getLoc();
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
Value elem_arg = xla_reduce_op.body().front().getArgument(0);
|
loop::ParallelOp output_loop, window_loop;
|
||||||
Value elem_buf =
|
std::tie(output_loop, window_loop) =
|
||||||
rewriter->create<AllocOp>(loc, elem_arg.getType().cast<MemRefType>());
|
CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
|
||||||
rewriter->create<StoreOp>(loc, loop_reduce_op_body->getArgument(0),
|
&rewriter);
|
||||||
elem_buf);
|
|
||||||
Value acc_arg = xla_reduce_op.body().front().getArgument(1);
|
|
||||||
Value acc_buf =
|
|
||||||
rewriter->create<AllocOp>(loc, acc_arg.getType().cast<MemRefType>());
|
|
||||||
rewriter->create<StoreOp>(loc, loop_reduce_op_body->getArgument(1),
|
|
||||||
acc_buf);
|
|
||||||
|
|
||||||
// Clone the ops from `xla_lhlo.reduce` into reduction operator block.
|
loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
|
||||||
BlockAndValueMapping mapping;
|
xla_reduce_window_op, output_loop, window_loop, &rewriter);
|
||||||
mapping.map(xla_reduce_op.body().front().getArguments(),
|
|
||||||
ValueRange{elem_buf, acc_buf, acc_buf});
|
ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
|
||||||
for (auto& nested : xla_reduce_op.body().front().without_terminator()) {
|
&xla_reduce_window_op.body().front(), &rewriter);
|
||||||
auto clone = rewriter->clone(nested, mapping);
|
rewriter.replaceOp(xla_reduce_window_op, llvm::None);
|
||||||
mapping.map(nested.getResults(), clone->getResults());
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::pair<loop::ParallelOp, loop::ParallelOp>
|
||||||
|
CreateParallelLoopsToTraverseOutputAndWindow(
|
||||||
|
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
|
||||||
|
ConversionPatternRewriter* rewriter) const {
|
||||||
|
auto loc = xla_reduce_window_op.getLoc();
|
||||||
|
Value init_value =
|
||||||
|
rewriter->create<LoadOp>(loc, xla_reduce_window_op.init_value());
|
||||||
|
|
||||||
|
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
|
||||||
|
|
||||||
|
// Create an outer parallel loop that spans the output of ReduceWindowOp.
|
||||||
|
Value xla_output = xla_reduce_window_op.out();
|
||||||
|
auto output_shape = xla_output.getType().cast<MemRefType>().getShape();
|
||||||
|
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
|
||||||
|
for (auto dim : llvm::enumerate(output_shape)) {
|
||||||
|
parallel_upper.push_back(GetStaticOrDynamicDim(
|
||||||
|
loc, xla_output, dim.index(), dim.value(), rewriter));
|
||||||
|
parallel_lower.push_back(zero);
|
||||||
|
parallel_step.push_back(one);
|
||||||
}
|
}
|
||||||
Value acc_result = rewriter->create<LoadOp>(loc, acc_buf);
|
auto output_loop = rewriter->create<loop::ParallelOp>(
|
||||||
rewriter->create<loop::ReduceReturnOp>(loc, acc_result);
|
loc, parallel_lower, parallel_upper, parallel_step);
|
||||||
|
|
||||||
|
// Create a nested loop that traverses the window.
|
||||||
|
rewriter->setInsertionPointToStart(output_loop.getBody());
|
||||||
|
SmallVector<Value, 2> window_lower, window_upper, window_step;
|
||||||
|
for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
|
||||||
|
window_step.push_back(one);
|
||||||
|
window_lower.push_back(zero);
|
||||||
|
window_upper.push_back(
|
||||||
|
rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
|
||||||
|
}
|
||||||
|
auto window_loop = rewriter->create<loop::ParallelOp>(
|
||||||
|
loc, window_lower, window_upper, window_step, init_value);
|
||||||
|
|
||||||
|
Value reduction_result = *window_loop.getResults().begin();
|
||||||
|
auto output_ivs = output_loop.getInductionVars();
|
||||||
|
rewriter->create<StoreOp>(
|
||||||
|
loc, reduction_result, xla_output,
|
||||||
|
llvm::makeArrayRef(output_ivs.begin(), output_ivs.end()));
|
||||||
|
return std::make_pair(output_loop, window_loop);
|
||||||
|
}
|
||||||
|
|
||||||
|
loop::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||||
|
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
|
||||||
|
loop::ParallelOp output_loop, loop::ParallelOp window_loop,
|
||||||
|
ConversionPatternRewriter* rewriter) const {
|
||||||
|
rewriter->setInsertionPointToStart(window_loop.getBody());
|
||||||
|
auto loc = xla_reduce_window_op.getLoc();
|
||||||
|
|
||||||
|
if (!xla_reduce_window_op.window_strides().hasValue()) {
|
||||||
|
xla_reduce_window_op.emitOpError("No window strides specified.");
|
||||||
|
}
|
||||||
|
if (!xla_reduce_window_op.padding().hasValue()) {
|
||||||
|
xla_reduce_window_op.emitOpError("No padding specified.");
|
||||||
|
}
|
||||||
|
if (xla_reduce_window_op.base_dilations().hasValue() ||
|
||||||
|
xla_reduce_window_op.window_dilations().hasValue()) {
|
||||||
|
xla_reduce_window_op.emitRemark(
|
||||||
|
"Lowering to parallel loops does not support `base_dilations` or "
|
||||||
|
"`window_dilations` attributes yet. The attributes will be ignored.");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value xla_operand = xla_reduce_window_op.operand();
|
||||||
|
auto xla_operand_type = xla_operand.getType().cast<MemRefType>();
|
||||||
|
auto xla_operand_shape = xla_operand_type.getShape();
|
||||||
|
|
||||||
|
auto output_ivs = llvm::to_vector<2>(output_loop.getInductionVars());
|
||||||
|
auto window_ivs = llvm::to_vector<2>(window_loop.getInductionVars());
|
||||||
|
auto window_strides = xla_reduce_window_op.window_strides().getValue();
|
||||||
|
auto padding = xla_reduce_window_op.padding().getValue();
|
||||||
|
|
||||||
|
SmallVector<Value, 2> operand_indices;
|
||||||
|
// `in_bounds` is false when the element in the reduce window is in the
|
||||||
|
// padding area, true otherwise.
|
||||||
|
Value in_bounds = rewriter->create<mlir::ConstantOp>(
|
||||||
|
loc, rewriter->getI1Type(),
|
||||||
|
rewriter->getIntegerAttr(rewriter->getI1Type(), 1));
|
||||||
|
for (unsigned i = 0, e = output_loop.getNumLoops(); i < e; ++i) {
|
||||||
|
auto stride = window_strides.getValue<llvm::APInt>(i);
|
||||||
|
auto pad_low = padding.getValue<llvm::APInt>({i, 0});
|
||||||
|
|
||||||
|
Value stride_val =
|
||||||
|
rewriter->create<ConstantIndexOp>(loc, stride.getSExtValue());
|
||||||
|
Value pad_low_val =
|
||||||
|
rewriter->create<ConstantIndexOp>(loc, pad_low.getSExtValue());
|
||||||
|
|
||||||
|
Value center = rewriter->create<MulIOp>(loc, output_ivs[i], stride_val);
|
||||||
|
Value offset = rewriter->create<SubIOp>(loc, window_ivs[i], pad_low_val);
|
||||||
|
Value index = rewriter->create<AddIOp>(loc, center, offset);
|
||||||
|
operand_indices.push_back(index);
|
||||||
|
Value upper_bound = GetStaticOrDynamicDim(loc, xla_operand, i,
|
||||||
|
xla_operand_shape[i], rewriter);
|
||||||
|
// We must check whether 0 <= index_i < shape_i, as otherwise we are in
|
||||||
|
// the pad and then we have to use the neutral element for reduction.
|
||||||
|
// Equivalently, it can be computed as the unsigned comparison index_i <
|
||||||
|
// shape_i, since a negative value wraps to a large positive value.
|
||||||
|
in_bounds = rewriter->create<mlir::AndOp>(
|
||||||
|
loc, in_bounds,
|
||||||
|
rewriter->create<CmpIOp>(loc, CmpIPredicate::ult, index,
|
||||||
|
upper_bound));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto elem_or_init =
|
||||||
|
rewriter->create<loop::IfOp>(loc, xla_operand_type.getElementType(),
|
||||||
|
in_bounds, /*withElseRegion=*/true);
|
||||||
|
|
||||||
|
OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
|
||||||
|
Value elem = then_builder.create<mlir::LoadOp>(
|
||||||
|
loc, xla_reduce_window_op.operand(), operand_indices);
|
||||||
|
then_builder.create<loop::YieldOp>(loc, elem);
|
||||||
|
|
||||||
|
OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
|
||||||
|
else_builder.create<loop::YieldOp>(loc, *window_loop.initVals().begin());
|
||||||
|
|
||||||
|
return rewriter->create<loop::ReduceOp>(loc,
|
||||||
|
*elem_or_init.results().begin());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -217,12 +423,14 @@ struct LhloLegalizeToParallelLoops
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<ReduceOpConverter>(func.getContext());
|
patterns.insert<ReduceOpConverter, ReduceWindowOpConverter>(
|
||||||
|
func.getContext());
|
||||||
|
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||||
loop::LoopOpsDialect, XlaLhloDialect>();
|
loop::LoopOpsDialect, XlaLhloDialect>();
|
||||||
target.addIllegalOp<xla_lhlo::ReduceOp>();
|
target.addIllegalOp<xla_lhlo::ReduceOp>();
|
||||||
|
target.addIllegalOp<xla_lhlo::ReduceWindowOp>();
|
||||||
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
|
Loading…
Reference in New Issue