[MLIR_GPU] Generalize StoreForwarding and DeadTempBufferRemoval passes.
Instead of supporting only AllocOp, they now use the side-effects interface to detect allocations. PiperOrigin-RevId: 336160536 Change-Id: I5356503ed17e7379bb27160b3b6ff1b240e978da
This commit is contained in:
parent
2b31ba7d0a
commit
abaca545db
@ -191,6 +191,7 @@ cc_library(
|
|||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:SCFTransforms",
|
"@llvm-project//mlir:SCFTransforms",
|
||||||
|
"@llvm-project//mlir:SideEffects",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
],
|
],
|
||||||
@ -277,6 +278,7 @@ tf_cc_binary(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:SideEffects",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -294,6 +296,7 @@ tf_cc_binary(
|
|||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:MlirOptLib",
|
"@llvm-project//mlir:MlirOptLib",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:SideEffects",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||||
|
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||||
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
|
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
|
||||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
@ -54,6 +55,21 @@ struct FusionOpRemoverPass : FusionOpRemoverPassBase<FusionOpRemoverPass> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename EffectTy>
|
||||||
|
bool HasEffectsOnValue(mlir::Value value, mlir::Operation* op) {
|
||||||
|
auto mem_effects_interface =
|
||||||
|
mlir::dyn_cast_or_null<mlir::MemoryEffectOpInterface>(op);
|
||||||
|
if (!mem_effects_interface) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
|
||||||
|
mem_effects_interface.getEffects(effects);
|
||||||
|
return llvm::any_of(effects,
|
||||||
|
[op](const mlir::MemoryEffects::EffectInstance& effect) {
|
||||||
|
return mlir::isa<EffectTy>(effect.getEffect());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
|
struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
|
||||||
mlir::StoreOp findStore(mlir::Operation* op,
|
mlir::StoreOp findStore(mlir::Operation* op,
|
||||||
std::function<bool(mlir::StoreOp)> matches) {
|
std::function<bool(mlir::StoreOp)> matches) {
|
||||||
@ -87,10 +103,9 @@ struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
|
|||||||
while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
|
while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
|
||||||
defOp = subviewOp.source().getDefiningOp();
|
defOp = subviewOp.source().getDefiningOp();
|
||||||
}
|
}
|
||||||
if (auto allocOp = mlir::dyn_cast_or_null<mlir::AllocOp>(defOp)) {
|
return HasEffectsOnValue<mlir::MemoryEffects::Allocate>(memref, defOp)
|
||||||
return allocOp.getOperation();
|
? defOp
|
||||||
}
|
: nullptr;
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieves AllocOp from the cache or actually looks for it.
|
// Retrieves AllocOp from the cache or actually looks for it.
|
||||||
@ -101,7 +116,7 @@ struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
|
|||||||
if (allocOpIt != memrefToAllocOp->end()) {
|
if (allocOpIt != memrefToAllocOp->end()) {
|
||||||
return allocOpIt->second;
|
return allocOpIt->second;
|
||||||
}
|
}
|
||||||
auto allocOp = SearchAllocOp(memref);
|
mlir::Operation* allocOp = SearchAllocOp(memref);
|
||||||
memrefToAllocOp->insert({memref, allocOp});
|
memrefToAllocOp->insert({memref, allocOp});
|
||||||
return allocOp;
|
return allocOp;
|
||||||
}
|
}
|
||||||
@ -169,13 +184,18 @@ struct DeadTempBufferRemovalPass
|
|||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
llvm::SmallVector<mlir::Operation*, 8> dead_ops;
|
llvm::SmallVector<mlir::Operation*, 8> dead_ops;
|
||||||
getFunction().walk([&](mlir::AllocOp allocOp) {
|
getFunction().walk([&](mlir::Operation* op) {
|
||||||
if (!operationConsideredDead(allocOp)) {
|
if (op->getNumResults() != 1 ||
|
||||||
|
!HasEffectsOnValue<mlir::MemoryEffects::Allocate>(op->getResult(0),
|
||||||
|
op)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!operationConsideredDead(op)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(herhut): There should be a generic helper for this.
|
// TODO(herhut): There should be a generic helper for this.
|
||||||
recursiveErase(allocOp, &dead_ops);
|
recursiveErase(op, &dead_ops);
|
||||||
});
|
});
|
||||||
for (auto op : dead_ops) {
|
for (auto op : dead_ops) {
|
||||||
op->erase();
|
op->erase();
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: xla-gpu-opt %s | FileCheck %s
|
// RUN: xla-gpu-opt %s --print-ir-after-all | FileCheck %s
|
||||||
HloModule AddMultiply
|
HloModule AddMultiply
|
||||||
|
|
||||||
ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
|
ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
|
||||||
|
@ -6,6 +6,18 @@ func @dead() {
|
|||||||
%0 = alloc() : memref<42xi32>
|
%0 = alloc() : memref<42xi32>
|
||||||
%c0 = constant 0 : i32
|
%c0 = constant 0 : i32
|
||||||
%c12 = constant 12 : index
|
%c12 = constant 12 : index
|
||||||
|
// CHECK-NOT: store
|
||||||
|
store %c0, %0[%c12] : memref<42xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same scenario as @dead, but the temporary buffer is produced by `alloca`
// instead of `alloc`, so that the non-AllocOp allocation path is exercised.
// CHECK-LABEL: @dead_alloca
func @dead_alloca() {
  // CHECK-NOT: alloca
  %0 = alloca() : memref<42xi32>
  %c0 = constant 0 : i32
  %c12 = constant 12 : index
  // CHECK-NOT: store
  store %c0, %0[%c12] : memref<42xi32>
  return
}
|
||||||
|
@ -13,6 +13,19 @@ func @forward() -> f32 {
|
|||||||
return %1 : f32
|
return %1 : f32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @forward_alloca
|
||||||
|
func @forward_alloca() -> f32 {
|
||||||
|
%0 = alloca() : memref<1024xf32>
|
||||||
|
%c42 = constant 24 : index
|
||||||
|
// CHECK: %[[CST:.*]] = constant 1.0
|
||||||
|
%c1 = constant 1.0 : f32
|
||||||
|
store %c1, %0[%c42] : memref<1024xf32>
|
||||||
|
// CHECK-NOT: load
|
||||||
|
%1 = load %0[%c42] : memref<1024xf32>
|
||||||
|
// CHECK: return %[[CST]]
|
||||||
|
return %1 : f32
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @wrong_index
|
// CHECK-LABEL: @wrong_index
|
||||||
func @wrong_index() -> f32 {
|
func @wrong_index() -> f32 {
|
||||||
%0 = alloc() : memref<1024xf32>
|
%0 = alloc() : memref<1024xf32>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user