[MLIR_GPU] Generalize StoreForwarding and DeadTempBufferRemoval passes.
Instead of supporting only AllocOp, they now use the side-effects interface to detect allocations. PiperOrigin-RevId: 336160536 Change-Id: I5356503ed17e7379bb27160b3b6ff1b240e978da
This commit is contained in:
parent
2b31ba7d0a
commit
abaca545db
@ -191,6 +191,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:SCFTransforms",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
@ -277,6 +278,7 @@ tf_cc_binary(
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
@ -294,6 +296,7 @@ tf_cc_binary(
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
@ -54,6 +55,21 @@ struct FusionOpRemoverPass : FusionOpRemoverPassBase<FusionOpRemoverPass> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename EffectTy>
|
||||
bool HasEffectsOnValue(mlir::Value value, mlir::Operation* op) {
|
||||
auto mem_effects_interface =
|
||||
mlir::dyn_cast_or_null<mlir::MemoryEffectOpInterface>(op);
|
||||
if (!mem_effects_interface) {
|
||||
return false;
|
||||
}
|
||||
llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
|
||||
mem_effects_interface.getEffects(effects);
|
||||
return llvm::any_of(effects,
|
||||
[op](const mlir::MemoryEffects::EffectInstance& effect) {
|
||||
return mlir::isa<EffectTy>(effect.getEffect());
|
||||
});
|
||||
}
|
||||
|
||||
struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
|
||||
mlir::StoreOp findStore(mlir::Operation* op,
|
||||
std::function<bool(mlir::StoreOp)> matches) {
|
||||
@ -87,10 +103,9 @@ struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
|
||||
while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
|
||||
defOp = subviewOp.source().getDefiningOp();
|
||||
}
|
||||
if (auto allocOp = mlir::dyn_cast_or_null<mlir::AllocOp>(defOp)) {
|
||||
return allocOp.getOperation();
|
||||
}
|
||||
return nullptr;
|
||||
return HasEffectsOnValue<mlir::MemoryEffects::Allocate>(memref, defOp)
|
||||
? defOp
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
// Retrieves AllocOp from the cache or actually looks for it.
|
||||
@ -101,7 +116,7 @@ struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
|
||||
if (allocOpIt != memrefToAllocOp->end()) {
|
||||
return allocOpIt->second;
|
||||
}
|
||||
auto allocOp = SearchAllocOp(memref);
|
||||
mlir::Operation* allocOp = SearchAllocOp(memref);
|
||||
memrefToAllocOp->insert({memref, allocOp});
|
||||
return allocOp;
|
||||
}
|
||||
@ -169,13 +184,18 @@ struct DeadTempBufferRemovalPass
|
||||
|
||||
void runOnFunction() override {
|
||||
llvm::SmallVector<mlir::Operation*, 8> dead_ops;
|
||||
getFunction().walk([&](mlir::AllocOp allocOp) {
|
||||
if (!operationConsideredDead(allocOp)) {
|
||||
getFunction().walk([&](mlir::Operation* op) {
|
||||
if (op->getNumResults() != 1 ||
|
||||
!HasEffectsOnValue<mlir::MemoryEffects::Allocate>(op->getResult(0),
|
||||
op)) {
|
||||
return;
|
||||
}
|
||||
if (!operationConsideredDead(op)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(herhut): There should be a generic helper for this.
|
||||
recursiveErase(allocOp, &dead_ops);
|
||||
recursiveErase(op, &dead_ops);
|
||||
});
|
||||
for (auto op : dead_ops) {
|
||||
op->erase();
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: xla-gpu-opt %s | FileCheck %s
|
||||
// RUN: xla-gpu-opt %s --print-ir-after-all | FileCheck %s
|
||||
HloModule AddMultiply
|
||||
|
||||
ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
|
||||
|
@ -6,6 +6,18 @@ func @dead() {
|
||||
%0 = alloc() : memref<42xi32>
|
||||
%c0 = constant 0 : i32
|
||||
%c12 = constant 12 : index
|
||||
// CHECK-NOT: store
|
||||
store %c0, %0[%c12] : memref<42xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @dead_alloca
|
||||
func @dead_alloca() {
  // The stack buffer below is only ever written to, so neither the allocation
  // nor the store is expected to appear in the output. The buffer has to be
  // created with `alloca` for the first directive to exercise anything.
  // CHECK-NOT: alloca
  %0 = alloca() : memref<42xi32>
  %c0 = constant 0 : i32
  %c12 = constant 12 : index
  // CHECK-NOT: store
  store %c0, %0[%c12] : memref<42xi32>
  return
}
|
||||
|
@ -13,6 +13,19 @@ func @forward() -> f32 {
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @forward_alloca
|
||||
func @forward_alloca() -> f32 {
  // A constant is written to a stack buffer and read back from the same index;
  // the directives below expect the constant to be returned directly, with no
  // read of the buffer remaining.
  %buf = alloca() : memref<1024xf32>
  %idx = constant 24 : index
  // CHECK: %[[CST:.*]] = constant 1.0
  %one = constant 1.0 : f32
  store %one, %buf[%idx] : memref<1024xf32>
  // CHECK-NOT: load
  %res = load %buf[%idx] : memref<1024xf32>
  // CHECK: return %[[CST]]
  return %res : f32
}
|
||||
|
||||
// CHECK-LABEL: @wrong_index
|
||||
func @wrong_index() -> f32 {
|
||||
%0 = alloc() : memref<1024xf32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user