Merge pull request #43069 from dfki-ehna:fix_hlo_legalize_to_lhlo_test
PiperOrigin-RevId: 330907602 Change-Id: Ieb31888480f36974ffac34c2aaca4431074d77e8
This commit is contained in:
commit
622d9a212a
@ -287,11 +287,36 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality
|
||||
// is provided by mlir buffer assignment, so use the pattern from there.
|
||||
using HloToLhloReturnOpConverter =
|
||||
BufferAssignmentReturnOpConverter<mhlo::ReturnOp, lmhlo::TerminatorOp,
|
||||
lmhlo::CopyOp>;
|
||||
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
|
||||
struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::ReturnOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = op.getLoc();
|
||||
auto& entry_block = op.getParentRegion()->front();
|
||||
auto num_arguments = entry_block.getNumArguments();
|
||||
if (operands.size() > num_arguments) {
|
||||
return op.emitError(
|
||||
"The number of operands that need Copy operations is more "
|
||||
"than the number of target function arguments.");
|
||||
}
|
||||
|
||||
// The index of the first output block argument.
|
||||
auto dest_arg_idx = num_arguments - operands.size();
|
||||
|
||||
// Create a lmhlo.copy for each operand of mhlo.return.
|
||||
for (Value operand : operands) {
|
||||
rewriter.create<lmhlo::CopyOp>(loc, operand,
|
||||
entry_block.getArgument(dest_arg_idx));
|
||||
++dest_arg_idx;
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class HloToLhloTensorLoadOpConverter
|
||||
: public BaseOpConversion<mlir::TensorLoadOp> {
|
||||
@ -429,6 +454,13 @@ struct HloLegalizeToLhlo
|
||||
isMemRefType);
|
||||
});
|
||||
|
||||
auto kind = results_escape_function
|
||||
? BufferAssignmentTypeConverter::KeepAsFunctionResult
|
||||
: BufferAssignmentTypeConverter::AppendToArgumentsList;
|
||||
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
|
||||
kind);
|
||||
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
|
||||
|
||||
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
||||
populateWithBufferAssignmentOpConversionPatterns<
|
||||
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
|
||||
|
@ -1,8 +1,5 @@
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
|
||||
// TODO(herhut): unbreak the test after upstream API changes.
|
||||
// XFAIL: *
|
||||
|
||||
|
||||
// BOTH-LABEL: func @attrs
|
||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user