Merge pull request #43069 from dfki-ehna:fix_hlo_legalize_to_lhlo_test
PiperOrigin-RevId: 330907602 Change-Id: Ieb31888480f36974ffac34c2aaca4431074d77e8
This commit is contained in:
commit
622d9a212a
@ -287,11 +287,36 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality
|
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
|
||||||
// is provided by mlir buffer assignment, so use the pattern from there.
|
struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
|
||||||
using HloToLhloReturnOpConverter =
|
public:
|
||||||
BufferAssignmentReturnOpConverter<mhlo::ReturnOp, lmhlo::TerminatorOp,
|
using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
|
||||||
lmhlo::CopyOp>;
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::ReturnOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto& entry_block = op.getParentRegion()->front();
|
||||||
|
auto num_arguments = entry_block.getNumArguments();
|
||||||
|
if (operands.size() > num_arguments) {
|
||||||
|
return op.emitError(
|
||||||
|
"The number of operands that need Copy operations is more "
|
||||||
|
"than the number of target function arguments.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// The index of the first output block argument.
|
||||||
|
auto dest_arg_idx = num_arguments - operands.size();
|
||||||
|
|
||||||
|
// Create a lmhlo.copy for each operand of mhlo.return.
|
||||||
|
for (Value operand : operands) {
|
||||||
|
rewriter.create<lmhlo::CopyOp>(loc, operand,
|
||||||
|
entry_block.getArgument(dest_arg_idx));
|
||||||
|
++dest_arg_idx;
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class HloToLhloTensorLoadOpConverter
|
class HloToLhloTensorLoadOpConverter
|
||||||
: public BaseOpConversion<mlir::TensorLoadOp> {
|
: public BaseOpConversion<mlir::TensorLoadOp> {
|
||||||
@ -429,6 +454,13 @@ struct HloLegalizeToLhlo
|
|||||||
isMemRefType);
|
isMemRefType);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
auto kind = results_escape_function
|
||||||
|
? BufferAssignmentTypeConverter::KeepAsFunctionResult
|
||||||
|
: BufferAssignmentTypeConverter::AppendToArgumentsList;
|
||||||
|
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
|
||||||
|
kind);
|
||||||
|
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
|
||||||
|
|
||||||
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
||||||
populateWithBufferAssignmentOpConversionPatterns<
|
populateWithBufferAssignmentOpConversionPatterns<
|
||||||
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
|
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
|
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
|
||||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
|
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
|
||||||
// TODO(herhut): unbreak the test after upstream API changes.
|
|
||||||
// XFAIL: *
|
|
||||||
|
|
||||||
|
|
||||||
// BOTH-LABEL: func @attrs
|
// BOTH-LABEL: func @attrs
|
||||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user