From 72c14949101c3ca53cab5389db779e61d86ba459 Mon Sep 17 00:00:00 2001
From: wyzhao
Date: Thu, 12 Mar 2020 01:07:39 +0800
Subject: [PATCH] [MLIR/XLA] fix wrong tilingSize dimension in lhlo-fuse-linalg
 pass

---
 .../mlir/xla/tests/lhlo-fuse-linalg.mlir      | 52 +++++++++++++++++++
 .../mlir/xla/transforms/lhlo_fuse_linalg.cc   |  2 +-
 2 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
index 7f7e37ebe66..0a48cbd372f 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
@@ -125,3 +125,55 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 // PLOOP: subf
 // PLOOP: linalg.generic
 // PLOOP: exp
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
+                %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
+  %temp_result = alloc() {temp = true} : memref<6x6x6x6xf32>
+  linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
+  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
+    %out = addf %summand_1_in, %summand_2_in : f32
+    linalg.yield %out : f32
+  } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
+  linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result {
+  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
+    %out = mulf %temp_result_in, %multiplier_in : f32
+    linalg.yield %out : f32
+  } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
+  dealloc %temp_result : memref<6x6x6x6xf32>
+  "xla_lhlo.terminator"() : () -> ()
+}
+// CHECK-LABEL: func @fusion_4d
+// CHECK: %[[C1:.*]] = constant 1
+// CHECK-NOT: linalg.generic
+// CHECK: loop.for {{.*}} step %[[C1]]
+// CHECK: loop.for {{.*}} step %[[C1]]
+// CHECK: loop.for {{.*}} step %[[C1]]
+// CHECK: loop.for {{.*}} step %[[C1]]
+// CHECK-NOT: loop.for
+// CHECK: linalg.generic
+// CHECK: addf
+// CHECK: linalg.generic
+// CHECK: mulf
+
+// TILED-LABEL: func @fusion_4d
+// TILED-DAG: %[[C2:.*]] = constant 2
+// TILED-DAG: %[[C3:.*]] = constant 3
+// TILED-NOT: linalg.generic
+// TILED: loop.for {{.*}} step %[[C2]]
+// TILED: loop.for {{.*}} step %[[C3]]
+// TILED-NOT: loop.for
+// TILED: linalg.generic
+// TILED: addf
+// TILED: linalg.generic
+// TILED: mulf
+
+// PLOOP-LABEL: func @fusion_4d
+// PLOOP-NOT: linalg.generic
+// PLOOP: loop.parallel
+// PLOOP-NOT: loop.parallel
+// PLOOP: linalg.generic
+// PLOOP: addf
+// PLOOP: linalg.generic
+// PLOOP: mulf
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
index 8f34034d6d3..6253670d464 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
@@ -64,7 +64,7 @@ class LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
                                          tile_sizes_.end());
       if (tile_sizes.empty()) {
         tile_sizes =
-            SmallVector<int64_t, 2>(generic_op.getNumInputsAndOutputs(), 1);
+            SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
       }
       auto op = cast<LinalgOp>(generic_op.getOperation());
      for (const Value result : op.getOutputBuffers()) {
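
For context on why the default count changes from getNumInputsAndOutputs() to getNumLoops(): Linalg tiling takes one tile size per loop dimension, while getNumInputsAndOutputs() counts operands. The @fusion_4d op added by the test above has 3 operands (2 inputs, 1 output) but 4 parallel loops, so the old default built a tile-size vector one entry short and left a loop dimension untiled; the new CHECK block expects all four loop.for nests with step %[[C1]]. Below is a minimal standalone C++ sketch of that counting; FakeGenericOp is a made-up stand-in for illustration, not the real MLIR LinalgOp API or the pass's own code.

// Illustration only: why the default tile-size count should use getNumLoops().
#include <cstdint>
#include <iostream>
#include <vector>

// Made-up stand-in for a pointwise linalg.generic like the one in @fusion_4d.
struct FakeGenericOp {
  int64_t num_inputs;   // 2 summands
  int64_t num_outputs;  // 1 result buffer
  int64_t num_loops;    // 4 for memref<6x6x6x6xf32> with 4 parallel iterators

  int64_t getNumInputsAndOutputs() const { return num_inputs + num_outputs; }
  int64_t getNumLoops() const { return num_loops; }
};

int main() {
  FakeGenericOp op{/*num_inputs=*/2, /*num_outputs=*/1, /*num_loops=*/4};

  // Old default: one tile size per operand -> 3 entries for a 4-D op, so one
  // loop dimension gets no tile size.
  std::vector<int64_t> per_operand(op.getNumInputsAndOutputs(), 1);

  // Fixed default: one tile size per loop dimension -> 4 entries, lining up
  // with the d0..d3 iteration space.
  std::vector<int64_t> per_loop(op.getNumLoops(), 1);

  std::cout << "per-operand default: " << per_operand.size() << " tile sizes\n"
            << "per-loop default:    " << per_loop.size() << " tile sizes\n";
  return 0;
}

Running the sketch prints 3 versus 4 tile sizes, matching the four "parallel" iterator_types in #pointwise_4d_trait.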