Support escaping result memrefs in lhlo_fuse_linalg.
So far, the root computation to fuse into has been identified by the fact that it writes into a function argument. Now, writing into a buffer that is returned from the function also qualifies. PiperOrigin-RevId: 316677942 Change-Id: I7c3912419606555946c9111d12c4086d086d9456
This commit is contained in:
parent
2537f3d413
commit
770251a700
|
@ -377,6 +377,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:LinalgOps",
|
"@llvm-project//mlir:LinalgOps",
|
||||||
"@llvm-project//mlir:LinalgTransforms",
|
"@llvm-project//mlir:LinalgTransforms",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
"@llvm-project//mlir:TransformUtils",
|
"@llvm-project//mlir:TransformUtils",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
|
// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
|
||||||
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED
|
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
|
||||||
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP
|
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
|
||||||
|
|
||||||
|
|
||||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||||
%temp_result = alloc() {temp = true} : memref<6x6xf32>
|
%temp_result = alloc() : memref<6x6xf32>
|
||||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||||
%out = addf %summand_1_in, %summand_2_in : f32
|
%out = addf %summand_1_in, %summand_2_in : f32
|
||||||
|
@ -19,7 +18,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
linalg.yield %out : f32
|
linalg.yield %out : f32
|
||||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||||
dealloc %temp_result : memref<6x6xf32>
|
dealloc %temp_result : memref<6x6xf32>
|
||||||
"xla_lhlo.terminator"() : () -> ()
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @fusion
|
// CHECK-LABEL: func @fusion
|
||||||
// CHECK: %[[C1:.*]] = constant 1
|
// CHECK: %[[C1:.*]] = constant 1
|
||||||
|
@ -53,10 +52,12 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
// PLOOP: linalg.generic
|
// PLOOP: linalg.generic
|
||||||
// PLOOP: mulf
|
// PLOOP: mulf
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
%arg1: memref<100xf32>,
|
%arg1: memref<100xf32>,
|
||||||
%arg2: memref<100x10xf32>) {
|
%arg2: memref<100x10xf32>) {
|
||||||
%0 = alloc() {temp = true} : memref<100x10xf32>
|
%0 = alloc() : memref<100x10xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
args_in = 1 : i64,
|
args_in = 1 : i64,
|
||||||
args_out = 1 : i64,
|
args_out = 1 : i64,
|
||||||
|
@ -66,7 +67,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||||
linalg.yield %arg3 : f32
|
linalg.yield %arg3 : f32
|
||||||
}: memref<100xf32>, memref<100x10xf32>
|
}: memref<100xf32>, memref<100x10xf32>
|
||||||
%1 = alloc() {temp = true} : memref<100x10xf32>
|
%1 = alloc() : memref<100x10xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
args_in = 2 : i64,
|
args_in = 2 : i64,
|
||||||
args_out = 1 : i64,
|
args_out = 1 : i64,
|
||||||
|
@ -126,11 +127,13 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
// PLOOP: linalg.generic
|
// PLOOP: linalg.generic
|
||||||
// PLOOP: exp
|
// PLOOP: exp
|
||||||
|
|
||||||
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
// -----
|
||||||
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
|
||||||
|
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||||
|
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||||
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||||
%temp_result = alloc() {temp = true} : memref<6x6x6x6xf32>
|
%temp_result = alloc() : memref<6x6x6x6xf32>
|
||||||
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
|
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
|
||||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||||
%out = addf %summand_1_in, %summand_2_in : f32
|
%out = addf %summand_1_in, %summand_2_in : f32
|
||||||
|
@ -142,7 +145,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
||||||
linalg.yield %out : f32
|
linalg.yield %out : f32
|
||||||
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
|
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
|
||||||
dealloc %temp_result : memref<6x6x6x6xf32>
|
dealloc %temp_result : memref<6x6x6x6xf32>
|
||||||
"xla_lhlo.terminator"() : () -> ()
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @fusion_4d
|
// CHECK-LABEL: func @fusion_4d
|
||||||
// CHECK: %[[C1:.*]] = constant 1
|
// CHECK: %[[C1:.*]] = constant 1
|
||||||
|
@ -177,3 +180,57 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
||||||
// PLOOP: addf
|
// PLOOP: addf
|
||||||
// PLOOP: linalg.generic
|
// PLOOP: linalg.generic
|
||||||
// PLOOP: mulf
|
// PLOOP: mulf
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||||
|
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
|
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
||||||
|
%temp_result = alloc() : memref<6x6xf32>
|
||||||
|
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||||
|
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||||
|
%out = addf %summand_1_in, %summand_2_in : f32
|
||||||
|
linalg.yield %out : f32
|
||||||
|
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||||
|
%result = alloc() : memref<6x6xf32>
|
||||||
|
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
|
||||||
|
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||||
|
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||||
|
linalg.yield %out : f32
|
||||||
|
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||||
|
dealloc %temp_result : memref<6x6xf32>
|
||||||
|
return %result : memref<6x6xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fusion
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1
|
||||||
|
// CHECK-NOT: linalg.generic
|
||||||
|
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||||
|
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||||
|
// CHECK-NOT: scf.for
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: addf
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: mulf
|
||||||
|
|
||||||
|
// TILED-LABEL: func @fusion
|
||||||
|
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||||
|
// TILED-DAG: %[[C3:.*]] = constant 3
|
||||||
|
// TILED-NOT: linalg.generic
|
||||||
|
// TILED: scf.for {{.*}} step %[[C2]]
|
||||||
|
// TILED: scf.for {{.*}} step %[[C3]]
|
||||||
|
// TILED-NOT: scf.for
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: addf
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: mulf
|
||||||
|
|
||||||
|
// PLOOP-LABEL: func @fusion
|
||||||
|
// PLOOP-NOT: linalg.generic
|
||||||
|
// PLOOP: scf.parallel
|
||||||
|
// PLOOP-NOT: scf.parallel
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: addf
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: mulf
|
||||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "mlir/Transforms/FoldUtils.h" // from @llvm-project
|
#include "mlir/Transforms/FoldUtils.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||||
|
@ -52,10 +53,17 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
||||||
// The fusion in Linalg is currently possible only when the consumer op is
|
// The fusion in Linalg is currently possible only when the consumer op is
|
||||||
// tiled. In order to greedily fuse the ops, we have to start from the tiled
|
// tiled. In order to greedily fuse the ops, we have to start from the tiled
|
||||||
// root linalg ops, i.e. linalg ops that write to output buffers of the
|
// root linalg ops, i.e. linalg ops that write to output buffers of the
|
||||||
// function.
|
// function or to buffers that are returned (escaping allocations).
|
||||||
llvm::SmallDenseSet<Value> func_args;
|
llvm::SmallDenseSet<Value> result_buffers;
|
||||||
for (auto func_arg : func.getArguments()) {
|
for (auto func_arg : func.getArguments()) {
|
||||||
func_args.insert(func_arg);
|
result_buffers.insert(func_arg);
|
||||||
|
}
|
||||||
|
for (auto& block : func.getBlocks()) {
|
||||||
|
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
|
||||||
|
if (!returnOp) continue;
|
||||||
|
for (auto operand : returnOp.getOperands()) {
|
||||||
|
result_buffers.insert(operand);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
MLIRContext* ctx = func.getContext();
|
MLIRContext* ctx = func.getContext();
|
||||||
OpBuilder b(func);
|
OpBuilder b(func);
|
||||||
|
@ -68,7 +76,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
||||||
}
|
}
|
||||||
auto op = cast<LinalgOp>(generic_op.getOperation());
|
auto op = cast<LinalgOp>(generic_op.getOperation());
|
||||||
for (const Value result : op.getOutputBuffers()) {
|
for (const Value result : op.getOutputBuffers()) {
|
||||||
if (!func_args.count(result)) continue;
|
if (!result_buffers.count(result)) continue;
|
||||||
if (tileGenericOp(op, tile_sizes, &b)) {
|
if (tileGenericOp(op, tile_sizes, &b)) {
|
||||||
generic_op.erase();
|
generic_op.erase();
|
||||||
return;
|
return;
|
||||||
|
|
Loading…
Reference in New Issue