Merge pull request #43069 from dfki-ehna:fix_hlo_legalize_to_lhlo_test

PiperOrigin-RevId: 330907602
Change-Id: Ieb31888480f36974ffac34c2aaca4431074d77e8
This commit is contained in:
TensorFlower Gardener 2020-09-10 04:15:03 -07:00
commit 622d9a212a
2 changed files with 37 additions and 8 deletions

View File

@ -287,11 +287,36 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
}
};
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality
// is provided by mlir buffer assignment, so use the pattern from there.
using HloToLhloReturnOpConverter =
BufferAssignmentReturnOpConverter<mhlo::ReturnOp, lmhlo::TerminatorOp,
lmhlo::CopyOp>;
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
public:
using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
auto& entry_block = op.getParentRegion()->front();
auto num_arguments = entry_block.getNumArguments();
if (operands.size() > num_arguments) {
return op.emitError(
"The number of operands that need Copy operations is more "
"than the number of target function arguments.");
}
// The index of the first output block argument.
auto dest_arg_idx = num_arguments - operands.size();
// Create a lmhlo.copy for each operand of mhlo.return.
for (Value operand : operands) {
rewriter.create<lmhlo::CopyOp>(loc, operand,
entry_block.getArgument(dest_arg_idx));
++dest_arg_idx;
}
rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
return success();
}
};
class HloToLhloTensorLoadOpConverter
: public BaseOpConversion<mlir::TensorLoadOp> {
@ -429,6 +454,13 @@ struct HloLegalizeToLhlo
isMemRefType);
});
auto kind = results_escape_function
? BufferAssignmentTypeConverter::KeepAsFunctionResult
: BufferAssignmentTypeConverter::AppendToArgumentsList;
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
kind);
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,

View File

@ -1,8 +1,5 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
// TODO(herhut): unbreak the test after upstream API changes.
// XFAIL: *
// BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {