diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir
index 1e375e142f7..ff4f1d940bf 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir
@@ -31,11 +31,12 @@ func @reduce(%arg: memref<100x10x5xf32>,
 // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
 // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
 // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 // CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
+// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
+// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
-// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
-// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
+// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
+// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
 // CHECK: loop.reduce.return [[ACC_RESULT]] : f32
 // CHECK: }
 // CHECK: loop.yield
@@ -71,11 +72,12 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
 // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
 // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
 // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 // CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
+// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
+// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
-// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
-// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
+// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
+// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
 // CHECK: loop.reduce.return [[ACC_RESULT]]
 // CHECK: }
 // CHECK: loop.yield
@@ -114,11 +116,12 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
 // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
 // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
 // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 // CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
+// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
+// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
-// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
-// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
+// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
+// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
 // CHECK: loop.reduce.return [[ACC_RESULT]] : f32
 // CHECK: }
 // CHECK: loop.yield
@@ -185,11 +188,12 @@ func @reduce_window(%arg: memref<112x112xf32>,
 // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
 // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
 // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 // CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
+// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
+// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
-// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
-// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
+// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
+// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
 // CHECK: loop.reduce.return [[ACC_RESULT]] : f32
 // CHECK: }
 // CHECK: loop.yield
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
index 2781b955327..806fe5d6f61 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
@@ -29,38 +29,49 @@ namespace mlir {
 namespace xla_lhlo {
 namespace {
 
+// Clones and adapts the code in `lhlo_block` that works on buffers and has a
+// single output buffer to make it compatible with `operands` that have element
+// types of the respective buffers. Returns the computed value.
+//
+// Example. For `operands` with (f32, i32) types and a block with LHLO ops and
+// with signature:
+//   ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
+//     <LHLO_ops>
+//
+// inserts necessary alloc and store ops to compute and return result that has
+// `i1` type.
+Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
+                                Block* lhlo_block, OpBuilder* b) {
+  SmallVector<Value, 2> arg_bufs;
+  for (auto arg_type : lhlo_block->getArgumentTypes()) {
+    arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
+  }
+  for (auto operand : llvm::enumerate(operands)) {
+    b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
+  }
+  // Clone the ops from `lhlo_block`.
+  BlockAndValueMapping mapping;
+  mapping.map(lhlo_block->getArguments(), arg_bufs);
+  for (auto& nested : lhlo_block->without_terminator()) {
+    auto clone = b->clone(nested, mapping);
+    mapping.map(nested.getResults(), clone->getResults());
+  }
+  return b->create<LoadOp>(loc, arg_bufs.back());
+}
+
 // Converts a block with LHLO ops and with signature:
 //   ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
 // into a reduction operator of loop.reduce by doing buffer allocation for
 // scalar arguments and the result of `loop.reduce` to make it compatible with
 // LHLO ops.
 void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
-                                Block* lhlo_block,
-                                ConversionPatternRewriter* rewriter) {
+                                Block* lhlo_block, OpBuilder* b) {
   Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
-  rewriter->setInsertionPointToStart(&loop_reduce_op_body);
-
-  // Allocate buffers to hold arguments of reduction operator block to stay
-  // compatible with the LHLO dialect ops in the reduction body.
-  Value elem_arg = lhlo_block->getArgument(0);
-  Value elem_buf =
-      rewriter->create<AllocOp>(loc, elem_arg.getType().cast<MemRefType>());
-  rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(0), elem_buf);
-  Value acc_arg = lhlo_block->getArgument(1);
-  Value acc_buf =
-      rewriter->create<AllocOp>(loc, acc_arg.getType().cast<MemRefType>());
-  rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(1), acc_buf);
-
-  // Clone the ops from `xla_lhlo.reduce` into reduction operator block.
-  BlockAndValueMapping mapping;
-  mapping.map(lhlo_block->getArguments(),
-              ValueRange{elem_buf, acc_buf, acc_buf});
-  for (auto& nested : lhlo_block->without_terminator()) {
-    auto clone = rewriter->clone(nested, mapping);
-    mapping.map(nested.getResults(), clone->getResults());
-  }
-  Value acc_result = rewriter->create<LoadOp>(loc, acc_buf);
-  rewriter->create<loop::ReduceReturnOp>(loc, acc_result);
+  OpBuilder::InsertionGuard guard(*b);
+  b->setInsertionPointToStart(&loop_reduce_op_body);
+  b->create<loop::ReduceReturnOp>(
+      loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
+                                     lhlo_block, b));
 }
 
 // Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to
@@ -79,9 +90,8 @@ struct MappedIvs {
   SmallVector<Value, 2> ivs;
 };
 
-MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ArrayRef<Value> ivs,
-                              ArrayRef<Value> window_ivs,
-                              OpBuilder* b) {
+MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs,
+                              ValueRange window_ivs, OpBuilder* b) {
   MappedIvs mapped_ivs;
 
   if (!op.window_strides().hasValue()) {
@@ -256,7 +266,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
     SmallVector<Value, 2> out_indices;
     if (outer != nullptr) {
       out_indices.reserve(outer.getNumLoops());
-      for (auto& iv : outer.getInductionVars()) {
+      for (Value iv : outer.getInductionVars()) {
        out_indices.push_back(iv);
       }
     } else {
@@ -268,12 +278,16 @@
     // Load the element to reduce.
     SmallVector<Value, 2> indices;
     indices.reserve(operand_shape.size());
-    Block::args_iterator outer_ivs_it =
-        outer ? outer.getInductionVars().begin() : nullptr;
-    Block::args_iterator inner_ivs_it = inner.getInductionVars().begin();
-    for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
-      indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
-                                               : *outer_ivs_it++);
+
+    if (outer) {
+      auto inner_ivs_it = inner.getInductionVars().begin();
+      auto outer_ivs_it = outer.getInductionVars().begin();
+      for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
+        indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
+                                                 : *outer_ivs_it++);
+      }
+    } else {
+      indices = ValueRange(inner.getInductionVars());
     }
 
     rewriter->setInsertionPointToStart(inner.getBody());
@@ -395,9 +409,8 @@ class ReduceWindowOpConverter
     Value reduction_result = *window_loop.getResults().begin();
     auto output_ivs = output_loop.getInductionVars();
-    rewriter->create<StoreOp>(
-        loc, reduction_result, xla_output,
-        llvm::makeArrayRef(output_ivs.begin(), output_ivs.end()));
+    rewriter->create<StoreOp>(loc, reduction_result, xla_output,
+                              ValueRange{output_ivs});
 
     return std::make_pair(output_loop, window_loop);
   }