Merge pull request #36242 from dfki-ehna:buffer_assignment
PiperOrigin-RevId: 292238439
Change-Id: I856b7f2ce788d4ba65bc7a16ba31aba8fdf6c4db
commit 21548e548d
@@ -11,6 +11,44 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  return
}

// CHECK-LABEL: func @func_op
func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
  %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
  // CHECK-NEXT: "xla_lhlo.max"(%arg0, %arg1, %[[MAX_RESULT]])
  // CHECK-NEXT: "xla_lhlo.copy"(%[[MAX_RESULT]], %arg2)
  // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
  return %0 : tensor<4xf32>
  // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}

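// For reference, a lowered form of @func_op that would satisfy the CHECK
// lines above (an illustrative sketch reconstructed from the test
// expectations, not output captured from the pass):
//
//   func @func_op(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
//     %0 = alloc() {temp = true} : memref<4xf32>
//     "xla_lhlo.max"(%arg0, %arg1, %0) {name = "maximum.47"}
//         : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//     "xla_lhlo.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
//     dealloc %0 : memref<4xf32>
//     "xla_lhlo.terminator"() : () -> ()
//   }
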
// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
  // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
  // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
  // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
  // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
  %1 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
  // CHECK-NEXT: "xla_lhlo.max"(%arg0, %arg1, %[[MAX_RESULT]])
  %2 = xla_hlo.add %arg0, %1 {name = "maximum.47"} : tensor<4xf32>
  // CHECK-NEXT: "xla_lhlo.add"(%arg0, %[[MAX_RESULT]], %[[ADD_RESULT]])
  %3 = xla_hlo.min %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
  // CHECK-NEXT: "xla_lhlo.min"(%arg0, %arg1, %[[MIN_RESULT]])
  %4 = xla_hlo.sub %arg1, %3 {name = "maximum.47"} : tensor<4xf32>
  // CHECK-NEXT: "xla_lhlo.sub"(%arg1, %[[MIN_RESULT]], %[[SUB_RESULT]])
  %5 = xla_hlo.mul %2, %4 {name = "maximum.47"} : tensor<4xf32>
  // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
  // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
  // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
  // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
  // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
  // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %arg2)
  // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
  return %5 : tensor<4xf32>
  // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}

// CHECK-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
             %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -120,7 +158,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.convert"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: return
  // CHECK: xla_lhlo.terminator
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}
@@ -39,54 +39,49 @@ namespace {

constexpr StringRef kTempBufferAttr = "temp";

Value GetTensorStoreOrReturnMemRef(Value value) {
/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc.
Operation* FindInsertionPointForCopy(Value value) {
  for (const auto& user : value.getUsers()) {
    if (auto dealloc = dyn_cast<DeallocOp>(user)) {
      return user;
    }
  }
  return nullptr;
}

Value GetTensorStore(Value value) {
  for (const auto& user : value.getUsers()) {
    if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
      if (tensor_store.getOperand(0) == value) {
        return tensor_store.getOperand(1);
      }
    }
    if (auto return_op = dyn_cast<xla_hlo::ReturnOp>(user)) {
      if (return_op.getOperand(0) == value) {
        auto block = return_op.getOperation()->getBlock();
        return *block->args_rbegin();
      }
    }
  }
  return nullptr;
}

Operation* GetLastUse(Value value) {
  Operation* last = value.getDefiningOp();
  for (auto& user : value.getUses()) {
    Operation* user_op = user.getOwner();
    if (!user_op->isBeforeInBlock(last)) {
      last = user_op;
    }
  }
  return last;
}

Value InsertAllocAndDealloc(Location loc, Value result,
                            ConversionPatternRewriter* rewriter) {
  auto result_type = result.getType().dyn_cast<ShapedType>();
  if (!result_type || !result_type.hasStaticShape()) {
    emitError(loc,
              "tensor to buffer conversion expects statically shaped results");
    result.getDefiningOp()->emitOpError()
        << "tensor to buffer conversion expects statically shaped results";
  }
  auto memref_type =
      MemRefType::get(result_type.getShape(), result_type.getElementType());

  Operation* last = GetLastUse(result);

  Operation* op = result.getDefiningOp();
  auto block = op->getBlock();

  OpBuilder allocBuilder(op);
  allocBuilder.setInsertionPointToStart(block);  // Inserting at the beginning
  auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);

  alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));

  allocBuilder.setInsertionPoint(op->getBlock(),
                                 std::next(Block::iterator(last)));
  allocBuilder.setInsertionPoint(block, std::prev(block->end()));
  allocBuilder.create<DeallocOp>(loc, alloc);

  return alloc;
}

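// Placement produced by InsertAllocAndDealloc, sketched in IR form
// (illustrative only; %buf stands for the temporary created above):
//
//   %buf = alloc() {temp = true} : memref<4xf32>   // inserted at the start of the block
//   ...                                            // original ops, including all uses of %buf
//   dealloc %buf : memref<4xf32>                   // inserted just before the block terminator
//   "xla_lhlo.terminator"() : () -> ()
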
@@ -95,7 +90,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
/// function to store that values held in the tensor.
Value GetBufferForResultValue(Location loc, Value result,
                              ConversionPatternRewriter* rewriter) {
  if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) {
  if (auto existing_memref = GetTensorStore(result)) {
    return existing_memref;
  }
  return InsertAllocAndDealloc(loc, result, rewriter);
@@ -110,11 +105,6 @@ class HloToLhloOpConverter : public ConversionPattern {
  PatternMatchResult matchAndRewrite(
      Operation* op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    if (op->getParentRegion()->getBlocks().size() != 1) {
      emitError(op->getLoc(),
                "tensor to buffer conversion expects a single block in the "
                "region containing the operation");
    }
    const auto& original_results = op->getResults();
    SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
    for (auto result : original_results) {
@@ -128,7 +118,7 @@ class HloToLhloOpConverter : public ConversionPattern {
  }
};

struct HloToLHloReduceConverter
struct HloToLHloReduceOpConverter
    : public OpConversionPattern<xla_hlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
@@ -140,9 +130,9 @@ struct HloToLHloReduceConverter
    // TODO(b/137624192) Implement variadic reduce.
    if (op.getNumResults() != 1) return matchFailure();
    if (op.getParentRegion()->getBlocks().size() != 1) {
      emitError(loc,
                "tensor to buffer conversion expects a single block in the "
                "region containing the operation");
      op.emitOpError() << "tensor to buffer conversion expects a single block "
                          "in the region containing the operation";
      return matchFailure();
    }
    const auto& original_results = op.getResults();
    SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
@@ -183,11 +173,10 @@ struct HloToLHloReduceConverter
  }
};

class HloToLhloTensorLoadConverter : public ConversionPattern {
class HloToLhloTensorLoadOpConverter : public ConversionPattern {
 public:
  explicit HloToLhloTensorLoadConverter(MLIRContext* context)
  explicit HloToLhloTensorLoadOpConverter(MLIRContext* context)
      : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation* op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
@@ -197,9 +186,9 @@ class HloToLhloTensorLoadConverter : public ConversionPattern {
};

// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreConverter : public ConversionPattern {
class HloToLhloTensorStoreOpConverter : public ConversionPattern {
 public:
  explicit HloToLhloTensorStoreConverter(MLIRContext* context)
  explicit HloToLhloTensorStoreOpConverter(MLIRContext* context)
      : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
@@ -210,19 +199,6 @@ class HloToLhloTensorStoreConverter : public ConversionPattern {
  }
};

// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloReturnConverter : public OpConversionPattern<xla_hlo::ReturnOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  PatternMatchResult matchAndRewrite(
      xla_hlo::ReturnOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    rewriter.eraseOp(op);
    return matchSuccess();
  }
};

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
@@ -263,26 +239,147 @@ class HloToLhloReturnConverter : public OpConversionPattern<xla_hlo::ReturnOp> {
// return
// }
// }
struct HloLegalizeToLhlo : public FunctionPass<HloLegalizeToLhlo> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
//
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
//   %1 = xla_hlo.add %arg0, %0 {name = "maximum.47"} : tensor<4xf32>
//   return %1 : tensor<4xf32>
// }
//
// Transformed function with an extra argument for the result. The types have
// been converted from tensor to memref.
//
// func @func_op(%arg0: memref<4xf32>,
//               %arg1: memref<4xf32>,
//               %arg2: memref<4xf32>) {
//   %0 = alloc() {temp = true} : memref<4xf32>
//   %1 = alloc() {temp = true} : memref<4xf32>
//   "xla_lhlo.max"(%arg0, %arg1, %1) {name = "maximum.47"} :
//       (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   "xla_lhlo.add"(%arg0, %1, %0) {name = "maximum.47"} :
//       (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   dealloc %1 : memref<4xf32>
//   "xla_lhlo.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
//   dealloc %0 : memref<4xf32>
//   "xla_lhlo.terminator"() : () -> ()
// }

    auto func = getFunction();
    populateHLOToLHLOConversionPattern(func.getContext(), &patterns);
    if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
  void runOnModule() override {
    OwningRewritePatternList patterns;
    auto& context = getContext();
    ConversionTarget target(context);
    target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
    target.addLegalDialect<StandardOpsDialect>();
    target.addLegalOp<ModuleOp>();
    target.addIllegalOp<mlir::ReturnOp>();
    target.addIllegalOp<mlir::TensorLoadOp>();
    target.addIllegalOp<mlir::TensorStoreOp>();
    target.addLegalOp<ModuleTerminatorOp>();
    target.addIllegalDialect<xla_hlo::XlaHloDialect>();
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      auto inputs = op.getType().getInputs();
      return std::all_of(inputs.begin(), inputs.end(),
                         [](Type input) { return input.isa<MemRefType>(); });
    });

    auto module = getModule();
    populateHLOToLHLOConversionPattern(module.getContext(), &patterns);

    if (failed(applyFullConversion(module, target, patterns, nullptr))) {
      signalPassFailure();
    }
  }
};

Type ConvertType(Type t) {
  if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
    return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  }
  return t;
}

}  // namespace

/// Transforms FuncOp arguments and results from tensors to buffers. Tensor
/// results are converted to memrefs and appended to the argument list.
class HloToLhloFuncOpConverter : public OpConversionPattern<FuncOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  PatternMatchResult matchAndRewrite(
      FuncOp funcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    if (funcOp.getBody().getBlocks().size() > 1) {
      funcOp.emitOpError() << "tensor to buffer conversion expects a single "
                              "block in the region containing the operation";
      return matchFailure();
    }

    auto funcType = funcOp.getType();

    TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
    for (auto argType : llvm::enumerate(funcType.getInputs())) {
      conversion.addInputs(argType.index(), ConvertType(argType.value()));
    }
    for (auto resType : funcType.getResults()) {
      conversion.addInputs(ConvertType(resType));
    }
    rewriter.updateRootInPlace(funcOp, [&] {
      funcOp.setType(
          rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
      rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
    });
    return matchSuccess();
  }
};

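// Taken together with ConvertType, the signature rewrite above turns a
// function such as (illustrative sketch; @f is a placeholder name)
//
//   func @f(%arg0: tensor<4xf32>) -> tensor<4xf32>
//
// into one with the signature
//
//   func @f(%arg0: memref<4xf32>, %arg1: memref<4xf32>)
//
// i.e. tensor arguments become memrefs and each tensor result is appended as
// an extra memref argument, so the converted function returns no values.
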
/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each
/// result to the corresponding buffer argument.
class StdToLhloReturnOpConverter : public OpConversionPattern<mlir::ReturnOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  PatternMatchResult matchAndRewrite(
      mlir::ReturnOp returnOp, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto numReturnValues = returnOp.getNumOperands();
    auto funcOp = returnOp.getParentOfType<FuncOp>();
    auto numFuncArgs = funcOp.getNumArguments();
    auto loc = returnOp.getLoc();

    for (auto operand : llvm::enumerate(operands)) {
      auto returnArgNumber = numFuncArgs - numReturnValues + operand.index();
      auto dstBuffer = funcOp.getArgument(returnArgNumber);
      if (dstBuffer == operand.value()) {
        continue;
      }

      auto dealloc = FindInsertionPointForCopy(operand.value());

      if (dealloc == nullptr) {
        returnOp.emitOpError()
            << "Missing dealloc for operand " << operand.index();
        return matchFailure();
      }
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(dealloc);
      rewriter.create<xla_lhlo::CopyOp>(loc, llvm::None, operand.value(),
                                        funcOp.getArgument(returnArgNumber));
    }
    rewriter.replaceOpWithNewOp<xla_lhlo::TerminatorOp>(returnOp);
    return matchSuccess();
  }
};

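// Concretely, for a single bufferized result this pattern rewrites
// (illustrative sketch; %0 and %arg2 are placeholders)
//
//   return %0 : memref<4xf32>
//
// into
//
//   "xla_lhlo.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
//   "xla_lhlo.terminator"() : () -> ()
//
// with the copy inserted immediately before the dealloc of %0, as located by
// FindInsertionPointForCopy, and the copy skipped entirely when the returned
// value already is the destination buffer.
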
void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                        OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<
  patterns->insert<
      HloToLHloReduceOpConverter,
      HloToLhloFuncOpConverter,
      HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>,
      HloToLhloOpConverter<xla_hlo::AddOp, xla_lhlo::AddOp>,
      HloToLhloOpConverter<xla_hlo::AndOp, xla_lhlo::AndOp>,
@@ -306,13 +403,14 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
      HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
      HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
      HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
      HloToLHloReduceConverter, HloToLhloReturnConverter,
      HloToLhloTensorLoadConverter, HloToLhloTensorStoreConverter
      HloToLhloTensorLoadOpConverter,
      HloToLhloTensorStoreOpConverter,
      StdToLhloReturnOpConverter
  >(context);
  // clang-format on
}

std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToLhloPass() {
std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
  return absl::make_unique<HloLegalizeToLhlo>();
}

@@ -53,7 +53,7 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToStdPass();

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToLhloPass();
std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass();

}  // namespace xla_hlo