[MLIR][XLA] Expose parameters of LhloFuseLinalg pass using llvm flags.
Adds flags: "tile-to-parallel-loops-for-linalg-fusion": "Tiles GenericOp consumer to parallel loops before linalg fusion" "tile-sizes-for-linalg-fusion": "Tile sizes by which to tile linalg generic before linalg fusion"), PiperOrigin-RevId: 295955774 Change-Id: Ia0aa12821d19b1710668d3336dc1278e02411ee5
This commit is contained in:
parent
5dab22191d
commit
768717c917
@ -1,32 +1,57 @@
|
|||||||
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
|
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
|
||||||
|
// RUN: tf-opt -lhlo-fuse-linalg -tile-sizes-for-linalg-fusion=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure
|
||||||
|
// RUN: tf-opt -lhlo-fuse-linalg -tile-to-parallel-loops-for-linalg-fusion %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure
|
||||||
|
|
||||||
|
|
||||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||||
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||||
%temp_result = alloc() {temp = true} : memref<2x2xf32>
|
%temp_result = alloc() {temp = true} : memref<6x6xf32>
|
||||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||||
%out = addf %summand_1_in, %summand_2_in : f32
|
%out = addf %summand_1_in, %summand_2_in : f32
|
||||||
linalg.yield %out : f32
|
linalg.yield %out : f32
|
||||||
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
|
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||||
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
|
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
|
||||||
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||||
linalg.yield %out : f32
|
linalg.yield %out : f32
|
||||||
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
|
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||||
dealloc %temp_result : memref<2x2xf32>
|
dealloc %temp_result : memref<6x6xf32>
|
||||||
"xla_lhlo.terminator"() : () -> ()
|
"xla_lhlo.terminator"() : () -> ()
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @fusion
|
// CHECK-LABEL: func @fusion
|
||||||
// CHECK-NOT: linalg.generic
|
// CHECK: %[[C1:.*]] = constant 1
|
||||||
// CHECK: loop.for
|
// CHECK-NOT: linalg.generic
|
||||||
// CHECK: loop.for
|
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||||
// CHECK-NOT: loop.for
|
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||||
// CHECK: linalg.generic
|
// CHECK-NOT: loop.for
|
||||||
// CHECK: addf
|
// CHECK: linalg.generic
|
||||||
// CHECK: linalg.generic
|
// CHECK: addf
|
||||||
// CHECK: mulf
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: mulf
|
||||||
|
|
||||||
|
// TILED-LABEL: func @fusion
|
||||||
|
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||||
|
// TILED-DAG: %[[C3:.*]] = constant 3
|
||||||
|
// TILED-NOT: linalg.generic
|
||||||
|
// TILED: loop.for {{.*}} step %[[C2]]
|
||||||
|
// TILED: loop.for {{.*}} step %[[C3]]
|
||||||
|
// TILED-NOT: loop.for
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: addf
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: mulf
|
||||||
|
|
||||||
|
// PLOOP-LABEL: func @fusion
|
||||||
|
// PLOOP-NOT: linalg.generic
|
||||||
|
// PLOOP: loop.parallel
|
||||||
|
// PLOOP-NOT: loop.parallel
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: addf
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: mulf
|
||||||
|
|
||||||
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
%arg1: memref<100xf32>,
|
%arg1: memref<100xf32>,
|
||||||
@ -67,12 +92,36 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @fusion
|
// CHECK-LABEL: func @fusion
|
||||||
// CHECK-NOT: linalg.generic
|
// CHECK: %[[C1:.*]] = constant 1
|
||||||
// CHECK: loop.for
|
// CHECK-NOT: linalg.generic
|
||||||
// CHECK: loop.for
|
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||||
// CHECK-NOT: loop.for
|
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||||
// CHECK: linalg.generic
|
// CHECK-NOT: loop.for
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK: subf
|
// CHECK: linalg.generic
|
||||||
// CHECK: linalg.generic
|
// CHECK: subf
|
||||||
// CHECK: exp
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: exp
|
||||||
|
|
||||||
|
// TILED-LABEL: func @fusion_of_three
|
||||||
|
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||||
|
// TILED-DAG: %[[C3:.*]] = constant 3
|
||||||
|
// TILED-NOT: linalg.generic
|
||||||
|
// TILED: loop.for {{.*}} step %[[C2]]
|
||||||
|
// TILED: loop.for {{.*}} step %[[C3]]
|
||||||
|
// TILED-NOT: loop.for
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: subf
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: exp
|
||||||
|
|
||||||
|
// PLOOP-LABEL: func @fusion_of_three
|
||||||
|
// PLOOP-NOT: linalg.generic
|
||||||
|
// PLOOP: loop.parallel
|
||||||
|
// PLOOP-NOT: loop.parallel
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: subf
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: exp
|
||||||
|
@ -22,6 +22,20 @@ limitations under the License.
|
|||||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||||
#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project
|
#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static llvm::cl::opt<bool> tile_to_parallel_loops_for_linalg_fusion(
|
||||||
|
"tile-to-parallel-loops-for-linalg-fusion",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
|
||||||
|
llvm::cl::init(false));
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static llvm::cl::list<unsigned> tile_sizes_for_linalg_fusion(
|
||||||
|
"tile-sizes-for-linalg-fusion",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Tile sizes by which to tile linalg generic before linalg fusion"),
|
||||||
|
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace xla_lhlo {
|
namespace xla_lhlo {
|
||||||
namespace {
|
namespace {
|
||||||
@ -50,13 +64,16 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
|
|||||||
OpBuilder b(func);
|
OpBuilder b(func);
|
||||||
OperationFolder folder(func.getContext());
|
OperationFolder folder(func.getContext());
|
||||||
func.walk([&](linalg::GenericOp generic_op) {
|
func.walk([&](linalg::GenericOp generic_op) {
|
||||||
const SmallVector<int64_t, 2> tile_sizes(
|
SmallVector<int64_t, 2> tile_sizes(tile_sizes_for_linalg_fusion.begin(),
|
||||||
generic_op.getNumInputsAndOutputs(), 1);
|
tile_sizes_for_linalg_fusion.end());
|
||||||
|
if (tile_sizes.empty()) {
|
||||||
|
tile_sizes =
|
||||||
|
SmallVector<int64_t, 2>(generic_op.getNumInputsAndOutputs(), 1);
|
||||||
|
}
|
||||||
auto op = cast<LinalgOp>(generic_op.getOperation());
|
auto op = cast<LinalgOp>(generic_op.getOperation());
|
||||||
for (const Value result : op.getOutputBuffers()) {
|
for (const Value result : op.getOutputBuffers()) {
|
||||||
if (!func_args.count(result)) continue;
|
if (!func_args.count(result)) continue;
|
||||||
if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{},
|
if (tileGenericOp(op, tile_sizes, &b, &folder)) {
|
||||||
&folder)) {
|
|
||||||
generic_op.erase();
|
generic_op.erase();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -83,6 +100,18 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
|
|||||||
}
|
}
|
||||||
for (auto* e : erase_set) e->erase();
|
for (auto* e : erase_set) e->erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b,
|
||||||
|
OperationFolder* folder) {
|
||||||
|
auto tiled_generic_op =
|
||||||
|
tile_to_parallel_loops_for_linalg_fusion
|
||||||
|
? linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes,
|
||||||
|
/*permutation=*/{}, folder)
|
||||||
|
: linalg::tileLinalgOp(*b, op, tile_sizes,
|
||||||
|
/*permutation=*/{}, folder);
|
||||||
|
return tiled_generic_op.hasValue();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user