diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir index 3317d24d820..1e375e142f7 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir @@ -125,3 +125,77 @@ func @dynamic_reduce(%arg: memref, // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] // CHECK: loop.yield + +// ----- + +func @reduce_window(%arg: memref<112x112xf32>, + %init: memref, + %result: memref<56x56xf32>) { + "xla_lhlo.reduce_window"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.maximum"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }) { + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + window_dimensions = dense<[3, 3]> : tensor<2xi64>, + window_strides = dense<[2, 2]> : tensor<2xi64> + } : (memref<112x112xf32>, memref, memref<56x56xf32>) -> () + return +} +// CHECK-LABEL: func @reduce_window( +// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) { +// CHECK-DAG: [[IN_BOUNDS:%.*]] = constant 1 : i1 +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C2:%.*]] = constant 2 : index +// CHECK-DAG: [[C3:%.*]] = constant 3 : index +// CHECK-DAG: [[C56:%.*]] = constant 56 : index +// CHECK-DAG: [[C112:%.*]] = constant 112 : index +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref +// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { +// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel +// CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]]) +// CHECK-SAME: init ([[INIT]]) -> f32 { + +// CHECK: [[START_I:%.*]] = muli [[I]], [[C2]] : index +// CHECK: [[OFFSET_I:%.*]] = subi [[IW]], [[C0]] : index +// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index +// CHECK: [[INDEX_I_FITS:%.*]] = cmpi "ult", [[INDEX_I]], [[C112]] +// CHECK: [[IN_BOUNDS_0:%.*]] = and [[INDEX_I_FITS]], [[IN_BOUNDS]] + +// CHECK: [[START_J:%.*]] = muli [[J]], [[C2]] : index +// CHECK: [[OFFSET_J:%.*]] = subi [[JW]], [[C0]] : index +// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index +// CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]] +// CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]] + +// CHECK: [[ELEM_TO_REDUCE:%.*]] = loop.if [[IN_BOUNDS_1]] -> (f32) { +// CHECK: [[OPERAND_ELEM:%.*]] = +// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]] +// CHECK: loop.yield [[OPERAND_ELEM]] : f32 +// CHECK: } else { +// CHECK: loop.yield [[INIT]] : f32 +// CHECK: } + +// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: } +// CHECK: loop.yield +// CHECK: } +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] +// CHECK: loop.yield +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index f2ae7227a23..1250db08ee5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -29,6 +29,50 @@ namespace mlir { namespace xla_lhlo { namespace { +// Converts a block with LHLO ops and with signature: +// ^bb(%lhs: memref, %rhs: memref, %res: memref): +// into a reduction operator of loop.reduce by doing buffer allocation for +// scalar arguments and the result of `loop.reduce` to make it compatible with +// LHLO ops. +void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op, + Block* lhlo_block, + ConversionPatternRewriter* rewriter) { + Block& loop_reduce_op_body = reduce_op.reductionOperator().front(); + rewriter->setInsertionPointToStart(&loop_reduce_op_body); + + // Allocate buffers to hold arguments of reduction operator block to stay + // compatible with the LHLO dialect ops in the reduction body. + Value elem_arg = lhlo_block->getArgument(0); + Value elem_buf = + rewriter->create(loc, elem_arg.getType().cast()); + rewriter->create(loc, loop_reduce_op_body.getArgument(0), elem_buf); + Value acc_arg = lhlo_block->getArgument(1); + Value acc_buf = + rewriter->create(loc, acc_arg.getType().cast()); + rewriter->create(loc, loop_reduce_op_body.getArgument(1), acc_buf); + + // Clone the ops from `xla_lhlo.reduce` into reduction operator block. + BlockAndValueMapping mapping; + mapping.map(lhlo_block->getArguments(), + ValueRange{elem_buf, acc_buf, acc_buf}); + for (auto& nested : lhlo_block->without_terminator()) { + auto clone = rewriter->clone(nested, mapping); + mapping.map(nested.getResults(), clone->getResults()); + } + Value acc_result = rewriter->create(loc, acc_buf); + rewriter->create(loc, acc_result); +} + +// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to +// extract dimension at runtime. +Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value, + size_t dim_index, int64_t dim, + ConversionPatternRewriter* rewriter) { + return dim == ShapedType::kDynamicSize + ? rewriter->create(loc, shaped_value, dim_index).getResult() + : rewriter->create(loc, dim); +} + // Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. // The outper `ParallelOp` refers to the parallel loops if there are // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` @@ -42,7 +86,7 @@ namespace { // } ) {dimensions = dense<[1]> : tensor<1xi64>} // : (memref<100x10x5xf32>, memref, memref<100x5xf32>) -> () // -// is converted into: +// is roughly converted into: // // %init = load %init_buf[] : memref // loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { @@ -67,15 +111,15 @@ class ReduceOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ReduceOp xla_reduce_op, ArrayRef args, + xla_lhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { // TODO(b/137624192) Implement variadic reduce. if (xla_reduce_op.out().size() != 1) return failure(); loop::ReduceOp reduce_op = - CreateParallelLoopsWithReduceOp(xla_reduce_op, args, &rewriter); - ConvertReductionOperator(xla_reduce_op, - &reduce_op.reductionOperator().front(), &rewriter); + CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); + ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, + &xla_reduce_op.body().front(), &rewriter); rewriter.replaceOp(xla_reduce_op, llvm::None); return success(); } @@ -100,8 +144,8 @@ class ReduceOpConverter : public OpConversionPattern { // } : f32 // loop.yield // } - loop::ReduceOp CreateParallelLoopsWithReduceOp( - xla_lhlo::ReduceOp xla_reduce_op, ArrayRef args, + loop::ReduceOp CreateReduceOpInNestedParallelLoops( + xla_lhlo::ReduceOp xla_reduce_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_op.getLoc(); DenseSet reducing_dims; @@ -114,20 +158,13 @@ class ReduceOpConverter : public OpConversionPattern { SmallVector parallel_lower, parallel_upper, parallel_step; SmallVector reduce_lower, reduce_upper, reduce_step; auto operand_shape = operand.getType().cast().getShape(); - Type index_type = rewriter->getIndexType(); for (auto dim : llvm::enumerate(operand_shape)) { const bool is_reducing_dim = reducing_dims.count(dim.index()); - Value ub = - dim.value() == ShapedType::kDynamicSize - ? rewriter->create(loc, operand, dim.index()).getResult() - : rewriter->create( - loc, index_type, - rewriter->getIntegerAttr(index_type, dim.value())); - Value lb = rewriter->create( - loc, index_type, rewriter->getIntegerAttr(index_type, 0)); - Value step = rewriter->create( - loc, index_type, rewriter->getIntegerAttr(index_type, 1)); + Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(), + rewriter); + Value lb = rewriter->create(loc, 0); + Value step = rewriter->create(loc, 1); (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb); (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub); (is_reducing_dim ? reduce_step : parallel_step).push_back(step); @@ -153,8 +190,7 @@ class ReduceOpConverter : public OpConversionPattern { out_indices.push_back(iv); } } else { - out_indices.push_back(rewriter->create( - loc, index_type, rewriter->getIntegerAttr(index_type, 0))); + out_indices.push_back(rewriter->create(loc, 0)); } rewriter->create(loc, reduction_result, out, out_indices); @@ -175,39 +211,209 @@ class ReduceOpConverter : public OpConversionPattern { loc, *xla_reduce_op.operands().begin(), indices); return rewriter->create(loc, elem); } +}; - // Converts `xla_lhlo.reduce` reduction operator into `loop.reduce` op by - // doing buffer allocation for scalar arguments and the result of - // `loop.reduce` to make it compatible with LHLO ops. - void ConvertReductionOperator(xla_lhlo::ReduceOp xla_reduce_op, - Block* loop_reduce_op_body, - ConversionPatternRewriter* rewriter) const { - rewriter->setInsertionPointToStart(loop_reduce_op_body); +// Pseudocode: +// for each index O in output +// accumulator = neutral_value +// in_bounds = true +// for each index W in window +// for each dimension i from 0 to rank - 1 +// index = O[i] * stride[i] + W[i] - pad_low[i] +// in_bounds = inbounds && (index `ult` shape[i]) +// I[i] = index +// if (in_bounds) +// value = input[I] +// else +// value = neutral_value +// accumulator = reduction_operator(output[O], value) +// output[O] = accumulator +// +// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a +// loop::ReduceOp. +// The outper `ParallelOp` refers to the parallel loops that traverese output +// buffer. The inner `ParalleOp` refers to the reduction loops that traverse +// reduction windows and `ReduceOp` contains the reduction operator. +// +// Example: +// +// func @reduce_window(%arg: memref<112x112xf32>, +// %init: memref, +// %result: memref<56x56xf32>) { +// "xla_lhlo.reduce_window"(%arg, %init, %result) ( { +// ^bb0(%lhs: memref, %rhs: memref, %res: memref): +// "xla_lhlo.maximum"(%lhs, %rhs, %res) +// : (memref, memref, memref) -> () +// "xla_lhlo.terminator"() : () -> () +// }) { +// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, +// window_dimensions = dense<[3, 3]> : tensor<2xi64>, +// window_strides = dense<[2, 2]> : tensor<2xi64> +// } : (memref<112x112xf32>, memref, memref<56x56xf32>) -> () +// return +// } +// +// is roughly converted into: +// +// %neutral_elem = load %init_buf[] : memref +// loop.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { +// %result = loop.parallel (%iw, %jw) = (%c0, %c0) +// to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 { +// %in_bounds = +// %elem = load %operand[%computed_i, %computed_j] +// %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32 +// loop.reduce(%elem_to_reduce) : f32 { +// ^bb0(%arg7: f32, %arg8: f32): +// +// } +// loop.yield +// } +// store %result, %output_buffer[%i, %j] : memref<56x56xf32> +// loop.yield +// } +// return +// } +class ReduceWindowOpConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; - // Allocate buffers to hold arguments of reduction operator block to stay - // compatible with the LHLO dialect ops in the reduction body. - auto loc = xla_reduce_op.getLoc(); - Value elem_arg = xla_reduce_op.body().front().getArgument(0); - Value elem_buf = - rewriter->create(loc, elem_arg.getType().cast()); - rewriter->create(loc, loop_reduce_op_body->getArgument(0), - elem_buf); - Value acc_arg = xla_reduce_op.body().front().getArgument(1); - Value acc_buf = - rewriter->create(loc, acc_arg.getType().cast()); - rewriter->create(loc, loop_reduce_op_body->getArgument(1), - acc_buf); + LogicalResult matchAndRewrite( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, + ConversionPatternRewriter& rewriter) const final { + loop::ParallelOp output_loop, window_loop; + std::tie(output_loop, window_loop) = + CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, + &rewriter); - // Clone the ops from `xla_lhlo.reduce` into reduction operator block. - BlockAndValueMapping mapping; - mapping.map(xla_reduce_op.body().front().getArguments(), - ValueRange{elem_buf, acc_buf, acc_buf}); - for (auto& nested : xla_reduce_op.body().front().without_terminator()) { - auto clone = rewriter->clone(nested, mapping); - mapping.map(nested.getResults(), clone->getResults()); + loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( + xla_reduce_window_op, output_loop, window_loop, &rewriter); + + ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, + &xla_reduce_window_op.body().front(), &rewriter); + rewriter.replaceOp(xla_reduce_window_op, llvm::None); + return success(); + } + + private: + std::pair + CreateParallelLoopsToTraverseOutputAndWindow( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, + ConversionPatternRewriter* rewriter) const { + auto loc = xla_reduce_window_op.getLoc(); + Value init_value = + rewriter->create(loc, xla_reduce_window_op.init_value()); + + Value zero = rewriter->create(loc, 0); + Value one = rewriter->create(loc, 1); + + // Create an outer parallel loop that spans the output of ReduceWindowOp. + Value xla_output = xla_reduce_window_op.out(); + auto output_shape = xla_output.getType().cast().getShape(); + SmallVector parallel_lower, parallel_upper, parallel_step; + for (auto dim : llvm::enumerate(output_shape)) { + parallel_upper.push_back(GetStaticOrDynamicDim( + loc, xla_output, dim.index(), dim.value(), rewriter)); + parallel_lower.push_back(zero); + parallel_step.push_back(one); } - Value acc_result = rewriter->create(loc, acc_buf); - rewriter->create(loc, acc_result); + auto output_loop = rewriter->create( + loc, parallel_lower, parallel_upper, parallel_step); + + // Create a nested loop that traverses the window. + rewriter->setInsertionPointToStart(output_loop.getBody()); + SmallVector window_lower, window_upper, window_step; + for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) { + window_step.push_back(one); + window_lower.push_back(zero); + window_upper.push_back( + rewriter->create(loc, window_dim.getSExtValue())); + } + auto window_loop = rewriter->create( + loc, window_lower, window_upper, window_step, init_value); + + Value reduction_result = *window_loop.getResults().begin(); + auto output_ivs = output_loop.getInductionVars(); + rewriter->create( + loc, reduction_result, xla_output, + llvm::makeArrayRef(output_ivs.begin(), output_ivs.end())); + return std::make_pair(output_loop, window_loop); + } + + loop::ReduceOp CreateReduceOpInNestedParallelLoops( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, + loop::ParallelOp output_loop, loop::ParallelOp window_loop, + ConversionPatternRewriter* rewriter) const { + rewriter->setInsertionPointToStart(window_loop.getBody()); + auto loc = xla_reduce_window_op.getLoc(); + + if (!xla_reduce_window_op.window_strides().hasValue()) { + xla_reduce_window_op.emitOpError("No window strides specified."); + } + if (!xla_reduce_window_op.padding().hasValue()) { + xla_reduce_window_op.emitOpError("No padding specified."); + } + if (xla_reduce_window_op.base_dilations().hasValue() || + xla_reduce_window_op.window_dilations().hasValue()) { + xla_reduce_window_op.emitRemark( + "Lowering to parallel loops does not support `base_dilations` or " + "`window_dilations` attributes yet. The attributes will be ignored."); + } + + Value xla_operand = xla_reduce_window_op.operand(); + auto xla_operand_type = xla_operand.getType().cast(); + auto xla_operand_shape = xla_operand_type.getShape(); + + auto output_ivs = llvm::to_vector<2>(output_loop.getInductionVars()); + auto window_ivs = llvm::to_vector<2>(window_loop.getInductionVars()); + auto window_strides = xla_reduce_window_op.window_strides().getValue(); + auto padding = xla_reduce_window_op.padding().getValue(); + + SmallVector operand_indices; + // `in_bounds` is false when the element in the reduce window is in the + // padding area, true otherwise. + Value in_bounds = rewriter->create( + loc, rewriter->getI1Type(), + rewriter->getIntegerAttr(rewriter->getI1Type(), 1)); + for (unsigned i = 0, e = output_loop.getNumLoops(); i < e; ++i) { + auto stride = window_strides.getValue(i); + auto pad_low = padding.getValue({i, 0}); + + Value stride_val = + rewriter->create(loc, stride.getSExtValue()); + Value pad_low_val = + rewriter->create(loc, pad_low.getSExtValue()); + + Value center = rewriter->create(loc, output_ivs[i], stride_val); + Value offset = rewriter->create(loc, window_ivs[i], pad_low_val); + Value index = rewriter->create(loc, center, offset); + operand_indices.push_back(index); + Value upper_bound = GetStaticOrDynamicDim(loc, xla_operand, i, + xla_operand_shape[i], rewriter); + // We must check whether 0 <= index_i < shape_i, as otherwise we are in + // the pad and then we have to use the neutral element for reduction. + // Equivalently, it can be computed as the unsigned comparison index_i < + // shape_i, since a negative value wraps to a large positive value. + in_bounds = rewriter->create( + loc, in_bounds, + rewriter->create(loc, CmpIPredicate::ult, index, + upper_bound)); + } + + auto elem_or_init = + rewriter->create(loc, xla_operand_type.getElementType(), + in_bounds, /*withElseRegion=*/true); + + OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); + Value elem = then_builder.create( + loc, xla_reduce_window_op.operand(), operand_indices); + then_builder.create(loc, elem); + + OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); + else_builder.create(loc, *window_loop.initVals().begin()); + + return rewriter->create(loc, + *elem_or_init.results().begin()); } }; @@ -217,12 +423,14 @@ struct LhloLegalizeToParallelLoops auto func = getFunction(); OwningRewritePatternList patterns; - patterns.insert(func.getContext()); + patterns.insert( + func.getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns, nullptr))) { signalPassFailure();