[XLA][MLIR] Lower xla_lhlo.ReduceWindowOp to parallel loops.

PiperOrigin-RevId: 302892920
Change-Id: Idefaa2f01979748cb30c80820823373851393b34
Alexander Belyaev 2020-03-25 08:18:38 -07:00 committed by TensorFlower Gardener
parent b857f0aaec
commit a8aa33c45e
2 changed files with 332 additions and 50 deletions


@@ -125,3 +125,77 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: loop.yield
// -----
func @reduce_window(%arg: memref<112x112xf32>,
                    %init: memref<f32>,
                    %result: memref<56x56xf32>) {
  "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
      "xla_lhlo.maximum"(%lhs, %rhs, %res)
          : (memref<f32>, memref<f32>, memref<f32>) -> ()
      "xla_lhlo.terminator"() : () -> ()
    }) {
    padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
    window_dimensions = dense<[3, 3]> : tensor<2xi64>,
    window_strides = dense<[2, 2]> : tensor<2xi64>
  } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
  return
}
// CHECK-LABEL: func @reduce_window(
// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>,
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) {
// CHECK-DAG: [[IN_BOUNDS:%.*]] = constant 1 : i1
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel
// CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]])
// CHECK-SAME: init ([[INIT]]) -> f32 {
// CHECK: [[START_I:%.*]] = muli [[I]], [[C2]] : index
// CHECK: [[OFFSET_I:%.*]] = subi [[IW]], [[C0]] : index
// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index
// CHECK: [[INDEX_I_FITS:%.*]] = cmpi "ult", [[INDEX_I]], [[C112]]
// CHECK: [[IN_BOUNDS_0:%.*]] = and [[INDEX_I_FITS]], [[IN_BOUNDS]]
// CHECK: [[START_J:%.*]] = muli [[J]], [[C2]] : index
// CHECK: [[OFFSET_J:%.*]] = subi [[JW]], [[C0]] : index
// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index
// CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]]
// CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]]
// CHECK: [[ELEM_TO_REDUCE:%.*]] = loop.if [[IN_BOUNDS_1]] -> (f32) {
// CHECK: [[OPERAND_ELEM:%.*]] =
// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
// CHECK: loop.yield [[OPERAND_ELEM]] : f32
// CHECK: } else {
// CHECK: loop.yield [[INIT]] : f32
// CHECK: }
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: loop.yield
// CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: loop.yield
// CHECK: }
// CHECK: return
// CHECK: }


@@ -29,6 +29,50 @@ namespace mlir {
namespace xla_lhlo {
namespace {

// Converts a block with LHLO ops and with signature:
//   ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// into a reduction operator of `loop.reduce`, allocating buffers for the
// scalar arguments and for the result so that the memref-based LHLO ops can
// be used inside the `loop.reduce` body.
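//
// For example, for an f32 `xla_lhlo.maximum` reduction the rewritten
// `loop.reduce` body looks roughly like this (a sketch; SSA names are
// illustrative):
//
//   loop.reduce(%elem_to_reduce) : f32 {
//   ^bb0(%elem: f32, %acc: f32):
//     %elem_buf = alloc() : memref<f32>
//     store %elem, %elem_buf[] : memref<f32>
//     %acc_buf = alloc() : memref<f32>
//     store %acc, %acc_buf[] : memref<f32>
//     "xla_lhlo.maximum"(%elem_buf, %acc_buf, %acc_buf)
//         : (memref<f32>, memref<f32>, memref<f32>) -> ()
//     %acc_result = load %acc_buf[] : memref<f32>
//     loop.reduce.return %acc_result : f32
//   }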
void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
                                Block* lhlo_block,
                                ConversionPatternRewriter* rewriter) {
  Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
  rewriter->setInsertionPointToStart(&loop_reduce_op_body);

  // Allocate buffers to hold the arguments of the reduction operator block,
  // to stay compatible with the LHLO dialect ops in the reduction body.
  Value elem_arg = lhlo_block->getArgument(0);
  Value elem_buf =
      rewriter->create<AllocOp>(loc, elem_arg.getType().cast<MemRefType>());
  rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(0), elem_buf);
  Value acc_arg = lhlo_block->getArgument(1);
  Value acc_buf =
      rewriter->create<AllocOp>(loc, acc_arg.getType().cast<MemRefType>());
  rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(1), acc_buf);

  // Clone the ops from the LHLO block into the reduction operator block.
  BlockAndValueMapping mapping;
  mapping.map(lhlo_block->getArguments(),
              ValueRange{elem_buf, acc_buf, acc_buf});
  for (auto& nested : lhlo_block->without_terminator()) {
    auto clone = rewriter->clone(nested, mapping);
    mapping.map(nested.getResults(), clone->getResults());
  }
  Value acc_result = rewriter->create<LoadOp>(loc, acc_buf);
  rewriter->create<loop::ReduceReturnOp>(loc, acc_result);
}

// Returns the result of a ConstantOp if `dim` is static; otherwise emits a
// DimOp to extract the dimension at runtime.
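//
// E.g. for a memref<?x5xf32> operand this emits `dim %operand, 0` for the
// dynamic dimension 0 and `constant 5 : index` for the static dimension 1.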
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
                            size_t dim_index, int64_t dim,
                            ConversionPatternRewriter* rewriter) {
  return dim == ShapedType::kDynamicSize
             ? rewriter->create<DimOp>(loc, shaped_value, dim_index)
                   .getResult()
             : rewriter->create<ConstantIndexOp>(loc, dim);
}

// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
@@ -42,7 +86,7 @@ namespace {
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
//     : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
//
// is roughly converted into:
//
// %init = load %init_buf[] : memref<f32>
// loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
@@ -67,15 +111,15 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
  using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/137624192) Implement variadic reduce.
    if (xla_reduce_op.out().size() != 1) return failure();

    loop::ReduceOp reduce_op =
        CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
    ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
                               &xla_reduce_op.body().front(), &rewriter);
    rewriter.replaceOp(xla_reduce_op, llvm::None);
    return success();
  }
@@ -100,8 +144,8 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
  //      } : f32
  //      loop.yield
  //    }
  loop::ReduceOp CreateReduceOpInNestedParallelLoops(
      xla_lhlo::ReduceOp xla_reduce_op,
      ConversionPatternRewriter* rewriter) const {
    auto loc = xla_reduce_op.getLoc();
    DenseSet<int> reducing_dims;
@@ -114,20 +158,13 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
    SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
    SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
    auto operand_shape = operand.getType().cast<MemRefType>().getShape();
    for (auto dim : llvm::enumerate(operand_shape)) {
      const bool is_reducing_dim = reducing_dims.count(dim.index());

      Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
                                       rewriter);
      Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
      Value step = rewriter->create<ConstantIndexOp>(loc, 1);
      (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
      (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
      (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
@@ -153,8 +190,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
        out_indices.push_back(iv);
      }
    } else {
      out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
    }

    rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
@@ -175,39 +211,209 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
        loc, *xla_reduce_op.operands().begin(), indices);
    return rewriter->create<loop::ReduceOp>(loc, elem);
  }
};

// Pseudocode:
// for each index O in output
//   accumulator = neutral_value
//   in_bounds = true
//   for each index W in window
//     for each dimension i from 0 to rank - 1
//       index = O[i] * stride[i] + W[i] - pad_low[i]
//       in_bounds = in_bounds && (index `ult` shape[i])
//       I[i] = index
//     if (in_bounds)
//       value = input[I]
//     else
//       value = neutral_value
//     accumulator = reduction_operator(accumulator, value)
//   output[O] = accumulator
//
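// For instance, for the 112x112 input of the example below (window 3x3,
// stride 2, pad_low 0): output index O = (55, 55) with window index
// W = (2, 2) gives index = 55 * 2 + 2 - 0 = 112, and `112 ult 112` is false,
// so this window position lies in the pad and the neutral value is used.
//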
// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a
// loop::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the
// output buffer. The inner `ParallelOp` refers to the reduction loops that
// traverse the reduction windows, and `ReduceOp` contains the reduction
// operator.
//
// Example:
//
// func @reduce_window(%arg: memref<112x112xf32>,
//                     %init: memref<f32>,
//                     %result: memref<56x56xf32>) {
//   "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
//     ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//       "xla_lhlo.maximum"(%lhs, %rhs, %res)
//           : (memref<f32>, memref<f32>, memref<f32>) -> ()
//       "xla_lhlo.terminator"() : () -> ()
//     }) {
//     padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
//     window_dimensions = dense<[3, 3]> : tensor<2xi64>,
//     window_strides = dense<[2, 2]> : tensor<2xi64>
//   } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
//   return
// }
//
// is roughly converted into:
//
// %neutral_elem = load %init_buf[] : memref<f32>
// loop.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
//   %result = loop.parallel (%iw, %jw) = (%c0, %c0)
//       to (%c3, %c3) step (%c1, %c1) init (%neutral_elem) -> f32 {
//     %in_bounds = <COMPUTE IF INDEX IS INSIDE THE OPERAND, I.E. NOT IN PAD>
//     %elem = load %operand[%computed_i, %computed_j]
//     %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
//     loop.reduce(%elem_or_neutral) : f32 {
//     ^bb0(%arg7: f32, %arg8: f32):
//       <LHLO ops>
//     }
//     loop.yield
//   }
//   store %result, %output_buffer[%i, %j] : memref<56x56xf32>
//   loop.yield
// }
// return
// }
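//
// (The `select` above is shorthand: as the test in this change shows, the
// lowering actually materializes a `loop.if` that yields either the loaded
// element or the neutral element.)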
class ReduceWindowOpConverter
    : public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    loop::ParallelOp output_loop, window_loop;
    std::tie(output_loop, window_loop) =
        CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
                                                     &rewriter);
    loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
        xla_reduce_window_op, output_loop, window_loop, &rewriter);

    ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
                               &xla_reduce_window_op.body().front(),
                               &rewriter);
    rewriter.replaceOp(xla_reduce_window_op, llvm::None);
    return success();
  }

 private:
  std::pair<loop::ParallelOp, loop::ParallelOp>
  CreateParallelLoopsToTraverseOutputAndWindow(
      xla_lhlo::ReduceWindowOp xla_reduce_window_op,
      ConversionPatternRewriter* rewriter) const {
    auto loc = xla_reduce_window_op.getLoc();
    Value init_value =
        rewriter->create<LoadOp>(loc, xla_reduce_window_op.init_value());
    Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
    Value one = rewriter->create<ConstantIndexOp>(loc, 1);

    // Create an outer parallel loop that spans the output of ReduceWindowOp.
    Value xla_output = xla_reduce_window_op.out();
    auto output_shape = xla_output.getType().cast<MemRefType>().getShape();
    SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
    for (auto dim : llvm::enumerate(output_shape)) {
      parallel_upper.push_back(GetStaticOrDynamicDim(
          loc, xla_output, dim.index(), dim.value(), rewriter));
      parallel_lower.push_back(zero);
      parallel_step.push_back(one);
    }
    auto output_loop = rewriter->create<loop::ParallelOp>(
        loc, parallel_lower, parallel_upper, parallel_step);
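    // E.g. for the 56x56 output of the example above, this creates
    // loop.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1).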

    // Create a nested loop that traverses the window.
    rewriter->setInsertionPointToStart(output_loop.getBody());
    SmallVector<Value, 2> window_lower, window_upper, window_step;
    for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
      window_step.push_back(one);
      window_lower.push_back(zero);
      window_upper.push_back(
          rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
    }
    auto window_loop = rewriter->create<loop::ParallelOp>(
        loc, window_lower, window_upper, window_step, init_value);

    Value reduction_result = *window_loop.getResults().begin();
    auto output_ivs = output_loop.getInductionVars();
    rewriter->create<StoreOp>(
        loc, reduction_result, xla_output,
        llvm::makeArrayRef(output_ivs.begin(), output_ivs.end()));
    return std::make_pair(output_loop, window_loop);
  }
  loop::ReduceOp CreateReduceOpInNestedParallelLoops(
      xla_lhlo::ReduceWindowOp xla_reduce_window_op,
      loop::ParallelOp output_loop, loop::ParallelOp window_loop,
      ConversionPatternRewriter* rewriter) const {
    rewriter->setInsertionPointToStart(window_loop.getBody());
    auto loc = xla_reduce_window_op.getLoc();

    if (!xla_reduce_window_op.window_strides().hasValue()) {
      xla_reduce_window_op.emitOpError("No window strides specified.");
    }
    if (!xla_reduce_window_op.padding().hasValue()) {
      xla_reduce_window_op.emitOpError("No padding specified.");
    }
    if (xla_reduce_window_op.base_dilations().hasValue() ||
        xla_reduce_window_op.window_dilations().hasValue()) {
      xla_reduce_window_op.emitRemark(
          "Lowering to parallel loops does not support `base_dilations` or "
          "`window_dilations` attributes yet. The attributes will be "
          "ignored.");
    }
    Value xla_operand = xla_reduce_window_op.operand();
    auto xla_operand_type = xla_operand.getType().cast<MemRefType>();
    auto xla_operand_shape = xla_operand_type.getShape();

    auto output_ivs = llvm::to_vector<2>(output_loop.getInductionVars());
    auto window_ivs = llvm::to_vector<2>(window_loop.getInductionVars());
    auto window_strides = xla_reduce_window_op.window_strides().getValue();
    auto padding = xla_reduce_window_op.padding().getValue();

    SmallVector<Value, 2> operand_indices;
    // `in_bounds` is false when the element in the reduce window is in the
    // padding area, true otherwise.
    Value in_bounds = rewriter->create<mlir::ConstantOp>(
        loc, rewriter->getI1Type(),
        rewriter->getIntegerAttr(rewriter->getI1Type(), 1));
    for (unsigned i = 0, e = output_loop.getNumLoops(); i < e; ++i) {
      auto stride = window_strides.getValue<llvm::APInt>(i);
      auto pad_low = padding.getValue<llvm::APInt>({i, 0});

      Value stride_val =
          rewriter->create<ConstantIndexOp>(loc, stride.getSExtValue());
      Value pad_low_val =
          rewriter->create<ConstantIndexOp>(loc, pad_low.getSExtValue());

      Value center = rewriter->create<MulIOp>(loc, output_ivs[i], stride_val);
      Value offset = rewriter->create<SubIOp>(loc, window_ivs[i], pad_low_val);
      Value index = rewriter->create<AddIOp>(loc, center, offset);
      operand_indices.push_back(index);

      Value upper_bound = GetStaticOrDynamicDim(
          loc, xla_operand, i, xla_operand_shape[i], rewriter);
      // We must check whether 0 <= index_i < shape_i, as otherwise we are in
      // the pad and have to use the neutral element for reduction.
      // Equivalently, this can be computed as the single unsigned comparison
      // `index_i ult shape_i`, since a negative index wraps around to a large
      // positive value.
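      // E.g. with a 64-bit index type, index_i = -1 is reinterpreted as
      // 2^64 - 1, and (2^64 - 1) `ult` 112 is false, so negative indices are
      // correctly flagged as out of bounds.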
      in_bounds = rewriter->create<mlir::AndOp>(
          loc, in_bounds,
          rewriter->create<CmpIOp>(loc, CmpIPredicate::ult, index,
                                   upper_bound));
    }
    auto elem_or_init =
        rewriter->create<loop::IfOp>(loc, xla_operand_type.getElementType(),
                                     in_bounds, /*withElseRegion=*/true);

    OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
    Value elem = then_builder.create<mlir::LoadOp>(
        loc, xla_reduce_window_op.operand(), operand_indices);
    then_builder.create<loop::YieldOp>(loc, elem);

    OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
    else_builder.create<loop::YieldOp>(loc, *window_loop.initVals().begin());

    return rewriter->create<loop::ReduceOp>(loc,
                                            *elem_or_init.results().begin());
  }
};
@@ -217,12 +423,14 @@ struct LhloLegalizeToParallelLoops
  auto func = getFunction();

  OwningRewritePatternList patterns;
  patterns.insert<ReduceOpConverter, ReduceWindowOpConverter>(
      func.getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                         loop::LoopOpsDialect, XlaLhloDialect>();
  target.addIllegalOp<xla_lhlo::ReduceOp>();
  target.addIllegalOp<xla_lhlo::ReduceWindowOp>();

  if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
    signalPassFailure();