Support escaping result memrefs in lhlo_fuse_linalg.

So far, the root computation to fuse into has been identified by the fact that it writes into a function argument. Now, writing into a buffer that is returned from the function also qualifies.

PiperOrigin-RevId: 316677942
Change-Id: I7c3912419606555946c9111d12c4086d086d9456
This commit is contained in:
Stephan Herhut 2020-06-16 07:34:49 -07:00 committed by TensorFlower Gardener
parent 2537f3d413
commit 770251a700
3 changed files with 82 additions and 16 deletions

View File

@ -377,6 +377,7 @@ cc_library(
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,

View File

@ -1,13 +1,12 @@
// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP
// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() {temp = true} : memref<6x6xf32>
%temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
@ -19,7 +18,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
dealloc %temp_result : memref<6x6xf32>
"xla_lhlo.terminator"() : () -> ()
return
}
// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
@ -53,10 +52,12 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg1: memref<100xf32>,
%arg2: memref<100x10xf32>) {
%0 = alloc() {temp = true} : memref<100x10xf32>
%0 = alloc() : memref<100x10xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
@ -66,7 +67,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
}: memref<100xf32>, memref<100x10xf32>
%1 = alloc() {temp = true} : memref<100x10xf32>
%1 = alloc() : memref<100x10xf32>
linalg.generic {
args_in = 2 : i64,
args_out = 1 : i64,
@ -126,11 +127,13 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
// PLOOP: linalg.generic
// PLOOP: exp
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
%temp_result = alloc() {temp = true} : memref<6x6x6x6xf32>
%temp_result = alloc() : memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
@ -142,7 +145,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
linalg.yield %out : f32
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
dealloc %temp_result : memref<6x6x6x6xf32>
"xla_lhlo.terminator"() : () -> ()
return
}
// CHECK-LABEL: func @fusion_4d
// CHECK: %[[C1:.*]] = constant 1
@ -177,3 +180,57 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
// Two chained pointwise linalg.generic ops (addf, then mulf) whose final
// output buffer is allocated inside the function body and returned to the
// caller, instead of being passed in as a function argument.
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
// Intermediate buffer; it is deallocated before the return below.
%temp_result = alloc() : memref<6x6xf32>
// Producer: elementwise %summand_1 + %summand_2, written to %temp_result.
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
// Buffer that escapes the function through the return operand.
%result = alloc() : memref<6x6xf32>
// Consumer: elementwise %temp_result * %multiplier, written to %result.
// Because %result is returned, this op writes into a returned buffer and
// so qualifies as the root to fuse into.
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
dealloc %temp_result : memref<6x6xf32>
return %result : memref<6x6xf32>
}
// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/FoldUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
@ -52,10 +53,17 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
// The fusion in Linalg is currently possible only when the consumer op is
// tiled. In order to greedily fuse the ops, we have to start from the tiled
// root linalg ops, i.e. linalg ops that write to output buffers of the
// function.
llvm::SmallDenseSet<Value> func_args;
// function or are returned in case of escaping allocations.
llvm::SmallDenseSet<Value> result_buffers;
for (auto func_arg : func.getArguments()) {
func_args.insert(func_arg);
result_buffers.insert(func_arg);
}
for (auto& block : func.getBlocks()) {
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
if (!returnOp) continue;
for (auto operand : returnOp.getOperands()) {
result_buffers.insert(operand);
}
}
MLIRContext* ctx = func.getContext();
OpBuilder b(func);
@ -68,7 +76,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
}
auto op = cast<LinalgOp>(generic_op.getOperation());
for (const Value result : op.getOutputBuffers()) {
if (!func_args.count(result)) continue;
if (!result_buffers.count(result)) continue;
if (tileGenericOp(op, tile_sizes, &b)) {
generic_op.erase();
return;