Merge pull request #37514 from wyzero:fix-wrong-tiling-size-dim
PiperOrigin-RevId: 300755212 Change-Id: I5638d80a04d1d5bc937ae9723b1a3075cbc4ced5
This commit is contained in:
commit
8b6cad7adc
@@ -125,3 +125,55 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||||||
// PLOOP: subf
|
// PLOOP: subf
|
||||||
// PLOOP: linalg.generic
|
// PLOOP: linalg.generic
|
||||||
// PLOOP: exp
|
// PLOOP: exp
|
||||||
|
|
||||||
|
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||||
|
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||||
|
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||||
|
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||||
|
%temp_result = alloc() {temp = true} : memref<6x6x6x6xf32>
|
||||||
|
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
|
||||||
|
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||||
|
%out = addf %summand_1_in, %summand_2_in : f32
|
||||||
|
linalg.yield %out : f32
|
||||||
|
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
|
||||||
|
linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result {
|
||||||
|
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||||
|
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||||
|
linalg.yield %out : f32
|
||||||
|
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
|
||||||
|
dealloc %temp_result : memref<6x6x6x6xf32>
|
||||||
|
"xla_lhlo.terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @fusion_4d
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1
|
||||||
|
// CHECK-NOT: linalg.generic
|
||||||
|
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||||
|
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||||
|
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||||
|
// CHECK: loop.for {{.*}} step %[[C1]]
|
||||||
|
// CHECK-NOT: loop.for
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: addf
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: mulf
|
||||||
|
|
||||||
|
// TILED-LABEL: func @fusion_4d
|
||||||
|
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||||
|
// TILED-DAG: %[[C3:.*]] = constant 3
|
||||||
|
// TILED-NOT: linalg.generic
|
||||||
|
// TILED: loop.for {{.*}} step %[[C2]]
|
||||||
|
// TILED: loop.for {{.*}} step %[[C3]]
|
||||||
|
// TILED-NOT: loop.for
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: addf
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: mulf
|
||||||
|
|
||||||
|
// PLOOP-LABEL: func @fusion_4d
|
||||||
|
// PLOOP-NOT: linalg.generic
|
||||||
|
// PLOOP: loop.parallel
|
||||||
|
// PLOOP-NOT: loop.parallel
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: addf
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: mulf
|
||||||
|
@@ -63,8 +63,7 @@ class LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
|
|||||||
SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
|
SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
|
||||||
tile_sizes_.end());
|
tile_sizes_.end());
|
||||||
if (tile_sizes.empty()) {
|
if (tile_sizes.empty()) {
|
||||||
tile_sizes =
|
tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
|
||||||
SmallVector<int64_t, 2>(generic_op.getNumInputsAndOutputs(), 1);
|
|
||||||
}
|
}
|
||||||
auto op = cast<LinalgOp>(generic_op.getOperation());
|
auto op = cast<LinalgOp>(generic_op.getOperation());
|
||||||
for (const Value result : op.getOutputBuffers()) {
|
for (const Value result : op.getOutputBuffers()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user