[MLIR_GPU] Generalize StoreForwarding and DeadTempBufferRemoval passes.

Instead of only supporting AllocOp, they now use the side-effects interface to
detect allocations.

PiperOrigin-RevId: 336160536
Change-Id: I5356503ed17e7379bb27160b3b6ff1b240e978da
This commit is contained in:
Alexander Belyaev 2020-10-08 14:04:26 -07:00 committed by TensorFlower Gardener
parent 2b31ba7d0a
commit abaca545db
5 changed files with 57 additions and 9 deletions

View File

@ -191,6 +191,7 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
@ -277,6 +278,7 @@ tf_cc_binary(
"//tensorflow/core:lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:Support",
],
)
@ -294,6 +296,7 @@ tf_cc_binary(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:Support",
],
)

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
@ -54,6 +55,21 @@ struct FusionOpRemoverPass : FusionOpRemoverPassBase<FusionOpRemoverPass> {
}
};
// Returns true if `op` implements MemoryEffectOpInterface and reports at least
// one memory effect of type `EffectTy` (e.g. mlir::MemoryEffects::Allocate).
// Returns false when `op` is null or does not implement the interface.
//
// NOTE(review): `value` is currently not consulted -- an effect of type
// `EffectTy` on any value makes this return true. Callers in this file appear
// to pass a memref that may be a subview of the op's result, so filtering on
// `value` here could change their behavior; confirm intent before tightening.
template <typename EffectTy>
bool HasEffectsOnValue(mlir::Value value, mlir::Operation* op) {
  auto mem_effects_interface =
      mlir::dyn_cast_or_null<mlir::MemoryEffectOpInterface>(op);
  if (!mem_effects_interface) {
    return false;
  }
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mem_effects_interface.getEffects(effects);
  // The predicate only inspects the effect itself, so nothing is captured.
  return llvm::any_of(effects,
                      [](const mlir::MemoryEffects::EffectInstance& effect) {
                        return mlir::isa<EffectTy>(effect.getEffect());
                      });
}
struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
mlir::StoreOp findStore(mlir::Operation* op,
std::function<bool(mlir::StoreOp)> matches) {
@ -87,10 +103,9 @@ struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
defOp = subviewOp.source().getDefiningOp();
}
if (auto allocOp = mlir::dyn_cast_or_null<mlir::AllocOp>(defOp)) {
return allocOp.getOperation();
}
return nullptr;
return HasEffectsOnValue<mlir::MemoryEffects::Allocate>(memref, defOp)
? defOp
: nullptr;
}
// Retrieves AllocOp from the cache or actually looks for it.
@ -101,7 +116,7 @@ struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
if (allocOpIt != memrefToAllocOp->end()) {
return allocOpIt->second;
}
auto allocOp = SearchAllocOp(memref);
mlir::Operation* allocOp = SearchAllocOp(memref);
memrefToAllocOp->insert({memref, allocOp});
return allocOp;
}
@ -169,13 +184,18 @@ struct DeadTempBufferRemovalPass
void runOnFunction() override {
llvm::SmallVector<mlir::Operation*, 8> dead_ops;
getFunction().walk([&](mlir::AllocOp allocOp) {
if (!operationConsideredDead(allocOp)) {
getFunction().walk([&](mlir::Operation* op) {
if (op->getNumResults() != 1 ||
!HasEffectsOnValue<mlir::MemoryEffects::Allocate>(op->getResult(0),
op)) {
return;
}
if (!operationConsideredDead(op)) {
return;
}
// TODO(herhut): There should be a generic helper for this.
recursiveErase(allocOp, &dead_ops);
recursiveErase(op, &dead_ops);
});
for (auto op : dead_ops) {
op->erase();

View File

@ -1,4 +1,4 @@
// RUN: xla-gpu-opt %s | FileCheck %s
// RUN: xla-gpu-opt %s --print-ir-after-all | FileCheck %s
HloModule AddMultiply
ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {

View File

@ -6,6 +6,18 @@ func @dead() {
%0 = alloc() : memref<42xi32>
%c0 = constant 0 : i32
%c12 = constant 12 : index
// CHECK-NOT: store
store %c0, %0[%c12] : memref<42xi32>
return
}
// CHECK-LABEL: @dead_alloca
func @dead_alloca() {
  // A stack allocation that is only ever written to: both the buffer and the
  // store into it are expected to be removed. The buffer must be created with
  // `alloca` so that the directive below actually exercises the alloca path.
  // CHECK-NOT: alloca
  %0 = alloca() : memref<42xi32>
  %c0 = constant 0 : i32
  %c12 = constant 12 : index
  // CHECK-NOT: store
  store %c0, %0[%c12] : memref<42xi32>
  return
}

View File

@ -13,6 +13,19 @@ func @forward() -> f32 {
return %1 : f32
}
// CHECK-LABEL: @forward_alloca
func @forward_alloca() -> f32 {
  // Stack-allocated buffer: a value is stored at one index and read back from
  // the same index, so the stored constant is expected to be returned directly.
  %buf = alloca() : memref<1024xf32>
  %idx = constant 24 : index
  // CHECK: %[[CST:.*]] = constant 1.0
  %cst = constant 1.0 : f32
  store %cst, %buf[%idx] : memref<1024xf32>
  // CHECK-NOT: load
  %val = load %buf[%idx] : memref<1024xf32>
  // CHECK: return %[[CST]]
  return %val : f32
}
// CHECK-LABEL: @wrong_index
func @wrong_index() -> f32 {
%0 = alloc() : memref<1024xf32>