diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 456d0c9fab5..1a67ab2c6f0 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -191,6 +191,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -277,6 +278,7 @@ tf_cc_binary( "//tensorflow/core:lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:Support", ], ) @@ -294,6 +296,7 @@ tf_cc_binary( "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc index 7f4810e3f54..f0997701d73 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" @@ -54,6 +55,21 @@ struct FusionOpRemoverPass : FusionOpRemoverPassBase { } }; +template +bool HasEffectsOnValue(mlir::Value value, mlir::Operation* op) { + auto mem_effects_interface = + mlir::dyn_cast_or_null(op); + if (!mem_effects_interface) { + return false; + } + llvm::SmallVector effects; + mem_effects_interface.getEffects(effects); + return llvm::any_of(effects, + [op](const mlir::MemoryEffects::EffectInstance& effect) { + return mlir::isa(effect.getEffect()); + }); +} + struct StoreForwardingPass : StoreForwardingPassBase { mlir::StoreOp findStore(mlir::Operation* op, std::function matches) { @@ -87,10 +103,9 @@ struct StoreForwardingPass : StoreForwardingPassBase { while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { defOp = subviewOp.source().getDefiningOp(); } - if (auto allocOp = mlir::dyn_cast_or_null(defOp)) { - return allocOp.getOperation(); - } - return nullptr; + return HasEffectsOnValue(memref, defOp) + ? defOp + : nullptr; } // Retrieves AllocOp from the cache or actually looks for it. @@ -101,7 +116,7 @@ struct StoreForwardingPass : StoreForwardingPassBase { if (allocOpIt != memrefToAllocOp->end()) { return allocOpIt->second; } - auto allocOp = SearchAllocOp(memref); + mlir::Operation* allocOp = SearchAllocOp(memref); memrefToAllocOp->insert({memref, allocOp}); return allocOp; } @@ -169,13 +184,18 @@ struct DeadTempBufferRemovalPass void runOnFunction() override { llvm::SmallVector dead_ops; - getFunction().walk([&](mlir::AllocOp allocOp) { - if (!operationConsideredDead(allocOp)) { + getFunction().walk([&](mlir::Operation* op) { + if (op->getNumResults() != 1 || + !HasEffectsOnValue(op->getResult(0), + op)) { + return; + } + if (!operationConsideredDead(op)) { return; } // TODO(herhut): There should be a generic helper for this. - recursiveErase(allocOp, &dead_ops); + recursiveErase(op, &dead_ops); }); for (auto op : dead_ops) { op->erase(); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo index 2603b925c76..20a6e2aa710 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo @@ -1,4 +1,4 @@ -// RUN: xla-gpu-opt %s | FileCheck %s +// RUN: xla-gpu-opt %s --print-ir-after-all | FileCheck %s HloModule AddMultiply ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir index 15329d6181a..58132f4ea45 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir @@ -6,6 +6,18 @@ func @dead() { %0 = alloc() : memref<42xi32> %c0 = constant 0 : i32 %c12 = constant 12 : index + // CHECK-NOT: store + store %c0, %0[%c12] : memref<42xi32> + return +} + +// CHECK-LABEL: @dead_alloca +func @dead_alloca() { + // CHECK-NOT: alloca + %0 = alloc() : memref<42xi32> + %c0 = constant 0 : i32 + %c12 = constant 12 : index + // CHECK-NOT: store store %c0, %0[%c12] : memref<42xi32> return } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir index 263ef290bdf..8b993bb56a5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir @@ -13,6 +13,19 @@ func @forward() -> f32 { return %1 : f32 } +// CHECK-LABEL: @forward_alloca +func @forward_alloca() -> f32 { + %0 = alloca() : memref<1024xf32> + %c42 = constant 24 : index + // CHECK: %[[CST:.*]] = constant 1.0 + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK-NOT: load + %1 = load %0[%c42] : memref<1024xf32> + // CHECK: return %[[CST]] + return %1 : f32 +} + // CHECK-LABEL: @wrong_index func @wrong_index() -> f32 { %0 = alloc() : memref<1024xf32>