diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir index 7f9e8c19780..a9ffc116392 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir @@ -1,32 +1,57 @@ -// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s +// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always +// RUN: tf-opt -lhlo-fuse-linalg -tile-sizes-for-linalg-fusion=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure +// RUN: tf-opt -lhlo-fuse-linalg -tile-to-parallel-loops-for-linalg-fusion %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure + #map0 = affine_map<(d0, d1) -> (d0, d1)> #pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} -func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, - %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { - %temp_result = alloc() {temp = true} : memref<2x2xf32> +func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, + %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { + %temp_result = alloc() {temp = true} : memref<6x6xf32> linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): %out = addf %summand_1_in, %summand_2_in : f32 linalg.yield %out : f32 - } : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32> + } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): %out = mulf %temp_result_in, %multiplier_in : f32 linalg.yield %out : f32 - } : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32> - dealloc %temp_result : memref<2x2xf32> + } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + dealloc %temp_result : memref<6x6xf32> "xla_lhlo.terminator"() : () -> () } // CHECK-LABEL: func @fusion -// CHECK-NOT: linalg.generic -// CHECK: loop.for -// CHECK: loop.for -// CHECK-NOT: loop.for -// CHECK: linalg.generic -// CHECK: addf -// CHECK: linalg.generic -// CHECK: mulf +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK-NOT: loop.for +// CHECK: linalg.generic +// CHECK: addf +// CHECK: linalg.generic +// CHECK: mulf + +// TILED-LABEL: func @fusion +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: loop.for {{.*}} step %[[C2]] +// TILED: loop.for {{.*}} step %[[C3]] +// TILED-NOT: loop.for +// TILED: linalg.generic +// TILED: addf +// TILED: linalg.generic +// TILED: mulf + +// PLOOP-LABEL: func @fusion +// PLOOP-NOT: linalg.generic +// PLOOP: loop.parallel +// PLOOP-NOT: loop.parallel +// PLOOP: linalg.generic +// PLOOP: addf +// PLOOP: linalg.generic +// PLOOP: mulf func @fusion_of_three(%arg0: memref<100x10xf32>, %arg1: memref<100xf32>, @@ -67,12 +92,36 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, return } // CHECK-LABEL: func @fusion -// CHECK-NOT: linalg.generic -// CHECK: loop.for -// CHECK: loop.for -// CHECK-NOT: loop.for -// CHECK: linalg.generic -// CHECK: linalg.generic -// CHECK: subf -// CHECK: linalg.generic -// CHECK: exp +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK-NOT: loop.for +// CHECK: linalg.generic +// CHECK: linalg.generic +// CHECK: subf +// CHECK: linalg.generic +// CHECK: exp + +// TILED-LABEL: func @fusion_of_three +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: loop.for {{.*}} step %[[C2]] +// TILED: loop.for {{.*}} step %[[C3]] +// TILED-NOT: loop.for +// TILED: linalg.generic +// TILED: linalg.generic +// TILED: subf +// TILED: linalg.generic +// TILED: exp + +// PLOOP-LABEL: func @fusion_of_three +// PLOOP-NOT: linalg.generic +// PLOOP: loop.parallel +// PLOOP-NOT: loop.parallel +// PLOOP: linalg.generic +// PLOOP: linalg.generic +// PLOOP: subf +// PLOOP: linalg.generic +// PLOOP: exp diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index b5e33fb0663..6b2b548550a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -22,6 +22,20 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Transforms/FoldUtils.h" // TF:llvm-project +// NOLINTNEXTLINE +static llvm::cl::opt tile_to_parallel_loops_for_linalg_fusion( + "tile-to-parallel-loops-for-linalg-fusion", + llvm::cl::desc( + "Tiles GenericOp consumer to parallel loops before linalg fusion"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::list tile_sizes_for_linalg_fusion( + "tile-sizes-for-linalg-fusion", + llvm::cl::desc( + "Tile sizes by which to tile linalg generic before linalg fusion"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); + namespace mlir { namespace xla_lhlo { namespace { @@ -50,13 +64,16 @@ struct LhloFuseLinalg : public FunctionPass { OpBuilder b(func); OperationFolder folder(func.getContext()); func.walk([&](linalg::GenericOp generic_op) { - const SmallVector tile_sizes( - generic_op.getNumInputsAndOutputs(), 1); + SmallVector tile_sizes(tile_sizes_for_linalg_fusion.begin(), + tile_sizes_for_linalg_fusion.end()); + if (tile_sizes.empty()) { + tile_sizes = + SmallVector(generic_op.getNumInputsAndOutputs(), 1); + } auto op = cast(generic_op.getOperation()); for (const Value result : op.getOutputBuffers()) { if (!func_args.count(result)) continue; - if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{}, - &folder)) { + if (tileGenericOp(op, tile_sizes, &b, &folder)) { generic_op.erase(); return; } @@ -83,6 +100,18 @@ struct LhloFuseLinalg : public FunctionPass { } for (auto* e : erase_set) e->erase(); } + + private: + bool tileGenericOp(LinalgOp op, ArrayRef tile_sizes, OpBuilder* b, + OperationFolder* folder) { + auto tiled_generic_op = + tile_to_parallel_loops_for_linalg_fusion + ? linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes, + /*permutation=*/{}, folder) + : linalg::tileLinalgOp(*b, op, tile_sizes, + /*permutation=*/{}, folder); + return tiled_generic_op.hasValue(); + } }; } // namespace