Support escaping result memrefs in lhlo_fuse_linalg.
So far, we have identified the root computation to fuse into by it writing into a function argument. Now writing into a buffer that is returned also qualifies. PiperOrigin-RevId: 316677942 Change-Id: I7c3912419606555946c9111d12c4086d086d9456
This commit is contained in:
parent
2537f3d413
commit
770251a700
|
@ -377,6 +377,7 @@ cc_library(
|
|||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
|
||||
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED
|
||||
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP
|
||||
|
||||
// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
|
||||
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
|
||||
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||
%temp_result = alloc() {temp = true} : memref<6x6xf32>
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
|
@ -19,7 +18,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
|||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
|
@ -53,10 +52,12 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
|||
// PLOOP: linalg.generic
|
||||
// PLOOP: mulf
|
||||
|
||||
// -----
|
||||
|
||||
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||
%arg1: memref<100xf32>,
|
||||
%arg2: memref<100x10xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<100x10xf32>
|
||||
%0 = alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
args_in = 1 : i64,
|
||||
args_out = 1 : i64,
|
||||
|
@ -66,7 +67,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
linalg.yield %arg3 : f32
|
||||
}: memref<100xf32>, memref<100x10xf32>
|
||||
%1 = alloc() {temp = true} : memref<100x10xf32>
|
||||
%1 = alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
args_in = 2 : i64,
|
||||
args_out = 1 : i64,
|
||||
|
@ -126,11 +127,13 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
// PLOOP: linalg.generic
|
||||
// PLOOP: exp
|
||||
|
||||
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||
%temp_result = alloc() {temp = true} : memref<6x6x6x6xf32>
|
||||
%temp_result = alloc() : memref<6x6x6x6xf32>
|
||||
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
|
@ -142,7 +145,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
|||
linalg.yield %out : f32
|
||||
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
|
||||
dealloc %temp_result : memref<6x6x6x6xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion_4d
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
|
@ -177,3 +180,57 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
|||
// PLOOP: addf
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: mulf
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
%result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
|
||||
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
return %result : memref<6x6xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fusion
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: addf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: mulf
|
||||
|
||||
// TILED-LABEL: func @fusion
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
// TILED-DAG: %[[C3:.*]] = constant 3
|
||||
// TILED-NOT: linalg.generic
|
||||
// TILED: scf.for {{.*}} step %[[C2]]
|
||||
// TILED: scf.for {{.*}} step %[[C3]]
|
||||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: addf
|
||||
// TILED: linalg.generic
|
||||
// TILED: mulf
|
||||
|
||||
// PLOOP-LABEL: func @fusion
|
||||
// PLOOP-NOT: linalg.generic
|
||||
// PLOOP: scf.parallel
|
||||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: addf
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: mulf
|
||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Transforms/FoldUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
|
@ -52,10 +53,17 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
|||
// The fusion in Linalg is currently possible only when the consumer op is
|
||||
// tiled. In order to greedily fuse the ops, we have to start from the tiled
|
||||
// root linalg ops, i.e. linalg ops that write to output buffers of the
|
||||
// function.
|
||||
llvm::SmallDenseSet<Value> func_args;
|
||||
// function or are returned in case of escaping allocations.
|
||||
llvm::SmallDenseSet<Value> result_buffers;
|
||||
for (auto func_arg : func.getArguments()) {
|
||||
func_args.insert(func_arg);
|
||||
result_buffers.insert(func_arg);
|
||||
}
|
||||
for (auto& block : func.getBlocks()) {
|
||||
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
|
||||
if (!returnOp) continue;
|
||||
for (auto operand : returnOp.getOperands()) {
|
||||
result_buffers.insert(operand);
|
||||
}
|
||||
}
|
||||
MLIRContext* ctx = func.getContext();
|
||||
OpBuilder b(func);
|
||||
|
@ -68,7 +76,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
|||
}
|
||||
auto op = cast<LinalgOp>(generic_op.getOperation());
|
||||
for (const Value result : op.getOutputBuffers()) {
|
||||
if (!func_args.count(result)) continue;
|
||||
if (!result_buffers.count(result)) continue;
|
||||
if (tileGenericOp(op, tile_sizes, &b)) {
|
||||
generic_op.erase();
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue