From 770251a7008c1d89e8141b37191e74008e01253d Mon Sep 17 00:00:00 2001
From: Stephan Herhut
Date: Tue, 16 Jun 2020 07:34:49 -0700
Subject: [PATCH] Support escaping result memrefs in lhlo_fuse_linalg.

So far, we have identified the root computation to fuse into by the fact
that it writes into a function argument. Now writing into a buffer that is
returned also qualifies.

PiperOrigin-RevId: 316677942
Change-Id: I7c3912419606555946c9111d12c4086d086d9456
---
 tensorflow/compiler/mlir/xla/BUILD             |  1 +
 .../mlir/xla/tests/lhlo-fuse-linalg.mlir       | 81 ++++++++++++++++---
 .../mlir/xla/transforms/lhlo_fuse_linalg.cc    | 16 +++-
 3 files changed, 82 insertions(+), 16 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 8f0f000b26a..43458aab2d3 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -377,6 +377,7 @@ cc_library(
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:TransformUtils",
     ],
     alwayslink = 1,
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
index 063487c00d8..b04c97f42d7 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
@@ -1,13 +1,12 @@
-// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
-// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED
-// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP
-
+// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
+// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
+// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
 func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
              %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
-  %temp_result = alloc() {temp = true} : memref<6x6xf32>
+  %temp_result = alloc() : memref<6x6xf32>
   linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
   ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
     %out = addf %summand_1_in, %summand_2_in : f32
@@ -19,7 +18,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
     linalg.yield %out : f32
   } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
   dealloc %temp_result : memref<6x6xf32>
-  "xla_lhlo.terminator"() : () -> ()
+  return
 }
 // CHECK-LABEL: func @fusion
 // CHECK: %[[C1:.*]] = constant 1
@@ -53,10 +52,12 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
 // PLOOP: linalg.generic
 // PLOOP: mulf
 
+// -----
+
 func @fusion_of_three(%arg0: memref<100x10xf32>,
                       %arg1: memref<100xf32>,
                       %arg2: memref<100x10xf32>) {
-  %0 = alloc() {temp = true} : memref<100x10xf32>
+  %0 = alloc() : memref<100x10xf32>
   linalg.generic {
     args_in = 1 : i64,
     args_out = 1 : i64,
@@ -66,7 +67,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
   ^bb0(%arg3: f32, %arg4: f32): // no predecessors
     linalg.yield %arg3 : f32
   }: memref<100xf32>, memref<100x10xf32>
-  %1 = alloc() {temp = true} : memref<100x10xf32>
+  %1 = alloc() : memref<100x10xf32>
   linalg.generic {
     args_in = 2 : i64,
     args_out = 1 : i64,
@@ -126,11 +127,13 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 // PLOOP: linalg.generic
 // PLOOP: exp
 
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
 func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
                 %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
-  %temp_result = alloc() {temp = true} : memref<6x6x6x6xf32>
+  %temp_result = alloc() : memref<6x6x6x6xf32>
   linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
   ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
     %out = addf %summand_1_in, %summand_2_in : f32
@@ -142,7 +145,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
     linalg.yield %out : f32
   } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
   dealloc %temp_result : memref<6x6x6x6xf32>
-  "xla_lhlo.terminator"() : () -> ()
+  return
 }
 // CHECK-LABEL: func @fusion_4d
 // CHECK: %[[C1:.*]] = constant 1
@@ -177,3 +180,57 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
 // PLOOP: addf
 // PLOOP: linalg.generic
 // PLOOP: mulf
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
+             %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
+  %temp_result = alloc() : memref<6x6xf32>
+  linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
+  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
+    %out = addf %summand_1_in, %summand_2_in : f32
+    linalg.yield %out : f32
+  } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
+  %result = alloc() : memref<6x6xf32>
+  linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
+  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
+    %out = mulf %temp_result_in, %multiplier_in : f32
+    linalg.yield %out : f32
+  } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
+  dealloc %temp_result : memref<6x6xf32>
+  return %result : memref<6x6xf32>
+}
+
+// CHECK-LABEL: func @fusion
+// CHECK: %[[C1:.*]] = constant 1
+// CHECK-NOT: linalg.generic
+// CHECK: scf.for {{.*}} step %[[C1]]
+// CHECK: scf.for {{.*}} step %[[C1]]
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
+// CHECK: addf
+// CHECK: linalg.generic
+// CHECK: mulf
+
+// TILED-LABEL: func @fusion
+// TILED-DAG: %[[C2:.*]] = constant 2
+// TILED-DAG: %[[C3:.*]] = constant 3
+// TILED-NOT: linalg.generic
+// TILED: scf.for {{.*}} step %[[C2]]
+// TILED: scf.for {{.*}} step %[[C3]]
+// TILED-NOT: scf.for
+// TILED: linalg.generic
+// TILED: addf
+// TILED: linalg.generic
+// TILED: mulf
+
+// PLOOP-LABEL: func @fusion
+// PLOOP-NOT: linalg.generic
+// PLOOP: scf.parallel
+// PLOOP-NOT: scf.parallel
+// PLOOP: linalg.generic
+// PLOOP: addf
+// PLOOP: linalg.generic
+// PLOOP: mulf
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
index ddbb672c70a..e16ab571b4d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
@@ -52,10 +53,17 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
     // The fusion in Linalg is currently possible only when the consumer op is
     // tiled. In order to greedily fuse the ops, we have to start from the tiled
     // root linalg ops, i.e. linalg ops that write to output buffers of the
-    // function.
-    llvm::SmallDenseSet<Value> func_args;
+    // function or are returned in case of escaping allocations.
+    llvm::SmallDenseSet<Value> result_buffers;
     for (auto func_arg : func.getArguments()) {
-      func_args.insert(func_arg);
+      result_buffers.insert(func_arg);
+    }
+    for (auto& block : func.getBlocks()) {
+      auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
+      if (!returnOp) continue;
+      for (auto operand : returnOp.getOperands()) {
+        result_buffers.insert(operand);
+      }
     }
     MLIRContext* ctx = func.getContext();
     OpBuilder b(func);
@@ -68,7 +76,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
       }
       auto op = cast<LinalgOp>(generic_op.getOperation());
      for (const Value result : op.getOutputBuffers()) {
-        if (!func_args.count(result)) continue;
+        if (!result_buffers.count(result)) continue;
        if (tileGenericOp(op, tile_sizes, &b)) {
          generic_op.erase();
          return;