[MLIR][XLA] Expose parameters of LhloFuseLinalg pass using llvm flags.
Adds flags: "tile-to-parallel-loops-for-linalg-fusion": "Tiles GenericOp consumer to parallel loops before linalg fusion"; "tile-sizes-for-linalg-fusion": "Tile sizes by which to tile linalg generic before linalg fusion". PiperOrigin-RevId: 295955774 Change-Id: Ia0aa12821d19b1710668d3336dc1278e02411ee5
This commit is contained in:
parent
5dab22191d
commit
768717c917
@ -1,32 +1,57 @@
|
||||
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
|
||||
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
|
||||
// RUN: tf-opt -lhlo-fuse-linalg -tile-sizes-for-linalg-fusion=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure
|
||||
// RUN: tf-opt -lhlo-fuse-linalg -tile-to-parallel-loops-for-linalg-fusion %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure
|
||||
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
||||
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%temp_result = alloc() {temp = true} : memref<2x2xf32>
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||
%temp_result = alloc() {temp = true} : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
|
||||
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
|
||||
dealloc %temp_result : memref<2x2xf32>
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
// CHECK-LABEL: func @fusion
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: loop.for
|
||||
// CHECK: loop.for
|
||||
// CHECK-NOT: loop.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: addf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: mulf
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||
// CHECK-NOT: loop.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: addf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: mulf
|
||||
|
||||
// TILED-LABEL: func @fusion
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
// TILED-DAG: %[[C3:.*]] = constant 3
|
||||
// TILED-NOT: linalg.generic
|
||||
// TILED: loop.for {{.*}} step %[[C2]]
|
||||
// TILED: loop.for {{.*}} step %[[C3]]
|
||||
// TILED-NOT: loop.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: addf
|
||||
// TILED: linalg.generic
|
||||
// TILED: mulf
|
||||
|
||||
// PLOOP-LABEL: func @fusion
|
||||
// PLOOP-NOT: linalg.generic
|
||||
// PLOOP: loop.parallel
|
||||
// PLOOP-NOT: loop.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: addf
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: mulf
|
||||
|
||||
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||
%arg1: memref<100xf32>,
|
||||
@ -67,12 +92,36 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: loop.for
|
||||
// CHECK: loop.for
|
||||
// CHECK-NOT: loop.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: subf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: exp
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||
// CHECK-NOT: loop.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: subf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: exp
|
||||
|
||||
// TILED-LABEL: func @fusion_of_three
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
// TILED-DAG: %[[C3:.*]] = constant 3
|
||||
// TILED-NOT: linalg.generic
|
||||
// TILED: loop.for {{.*}} step %[[C2]]
|
||||
// TILED: loop.for {{.*}} step %[[C3]]
|
||||
// TILED-NOT: loop.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: linalg.generic
|
||||
// TILED: subf
|
||||
// TILED: linalg.generic
|
||||
// TILED: exp
|
||||
|
||||
// PLOOP-LABEL: func @fusion_of_three
|
||||
// PLOOP-NOT: linalg.generic
|
||||
// PLOOP: loop.parallel
|
||||
// PLOOP-NOT: loop.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: subf
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: exp
|
||||
|
@ -22,6 +22,20 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project
|
||||
|
||||
// Command-line flag "tile-to-parallel-loops-for-linalg-fusion" (boolean,
// initialized to false). When set, tileGenericOp tiles through
// linalg::tileLinalgOpToParallelLoops instead of linalg::tileLinalgOp.
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> tile_to_parallel_loops_for_linalg_fusion(
|
||||
"tile-to-parallel-loops-for-linalg-fusion",
|
||||
llvm::cl::desc(
|
||||
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// Command-line flag "tile-sizes-for-linalg-fusion": a comma-separated list of
// unsigned tile sizes that may be given zero or more times (e.g. "=2,3" in the
// lit test). When the list is empty, the pass falls back to a tile size of 1
// for each of the generic op's inputs and outputs.
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<unsigned> tile_sizes_for_linalg_fusion(
|
||||
"tile-sizes-for-linalg-fusion",
|
||||
llvm::cl::desc(
|
||||
"Tile sizes by which to tile linalg generic before linalg fusion"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace {
|
||||
@ -50,13 +64,16 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
|
||||
OpBuilder b(func);
|
||||
OperationFolder folder(func.getContext());
|
||||
func.walk([&](linalg::GenericOp generic_op) {
|
||||
const SmallVector<int64_t, 2> tile_sizes(
|
||||
generic_op.getNumInputsAndOutputs(), 1);
|
||||
SmallVector<int64_t, 2> tile_sizes(tile_sizes_for_linalg_fusion.begin(),
|
||||
tile_sizes_for_linalg_fusion.end());
|
||||
if (tile_sizes.empty()) {
|
||||
tile_sizes =
|
||||
SmallVector<int64_t, 2>(generic_op.getNumInputsAndOutputs(), 1);
|
||||
}
|
||||
auto op = cast<LinalgOp>(generic_op.getOperation());
|
||||
for (const Value result : op.getOutputBuffers()) {
|
||||
if (!func_args.count(result)) continue;
|
||||
if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{},
|
||||
&folder)) {
|
||||
if (tileGenericOp(op, tile_sizes, &b, &folder)) {
|
||||
generic_op.erase();
|
||||
return;
|
||||
}
|
||||
@ -83,6 +100,18 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
|
||||
}
|
||||
for (auto* e : erase_set) e->erase();
|
||||
}
|
||||
|
||||
private:
|
||||
// Tiles `op` by `tile_sizes` using builder `b` and folder `folder`.
// The tile_to_parallel_loops_for_linalg_fusion flag selects the tiling entry
// point: linalg::tileLinalgOpToParallelLoops when set, linalg::tileLinalgOp
// otherwise. Returns true iff the selected call produced a value.
bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b,
|
||||
OperationFolder* folder) {
|
||||
// Both branches pass the same tile sizes, an empty permutation, and the
// same folder; only the callee differs.
auto tiled_generic_op =
|
||||
tile_to_parallel_loops_for_linalg_fusion
|
||||
? linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes,
|
||||
/*permutation=*/{}, folder)
|
||||
: linalg::tileLinalgOp(*b, op, tile_sizes,
|
||||
/*permutation=*/{}, folder);
|
||||
// Only success/failure is reported; the tiled op itself is not returned.
return tiled_generic_op.hasValue();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user