[MLIR] Clean up LHLO linalg fusion again.
PiperOrigin-RevId: 275197163 Change-Id: I92644a0f90f84a0ab17c1404477e0bf00de8cbe2
This commit is contained in:
parent
cd7458379f
commit
72e9dd3b5d
tensorflow/compiler/mlir/xla
@ -2,19 +2,19 @@
|
||||
|
||||
#map0 = (d0, d1) -> (d0, d1)
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], n_loop_types = [2, 0, 0], n_views = [2, 1]}
|
||||
func @fusion(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>, %arg3: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
linalg.generic #pointwise_2d_trait %arg1, %arg2, %0 {
|
||||
^bb0(%block_arg0: f32, %block_arg1: f32, %block_arg2: f32):
|
||||
%1 = addf %block_arg0, %block_arg1 : f32
|
||||
func @fusion(%input0: memref<2x2xf32>, %input1: memref<2x2xf32>, %input2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%temp_result = alloc() {temp = true} : memref<2x2xf32>
|
||||
linalg.generic #pointwise_2d_trait %input1, %input2, %temp_result {
|
||||
^bb0(%input1_in: f32, %input2_in: f32, %temp_result_in: f32):
|
||||
%1 = addf %input1_in, %input2_in : f32
|
||||
linalg.yield %1 : f32
|
||||
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
|
||||
linalg.generic #pointwise_2d_trait %0, %arg0, %arg3 {
|
||||
^bb0(%block_arg0: f32, %block_arg1: f32, %block_arg2: f32):
|
||||
%1 = mulf %block_arg0, %block_arg1 : f32
|
||||
linalg.generic #pointwise_2d_trait %temp_result, %input0, %result {
|
||||
^bb0(%temp_result_in: f32, %input0_in: f32, %result_in: f32):
|
||||
%1 = mulf %temp_result_in, %input0_in : f32
|
||||
linalg.yield %1 : f32
|
||||
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
dealloc %temp_result : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,13 @@ using linalg::LinalgOp;
|
||||
struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
|
||||
void runOnFunction() override {
|
||||
auto func = getFunction();
|
||||
OperationFolder state(func.getContext());
|
||||
|
||||
// TODO(pifon): Remove assumption that the function has a single block.
|
||||
if (func.getBlocks().size() != 1) {
|
||||
emitError(func.getLoc(), "The function needs to have a single block.");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// The fusion in Linalg is currently possible only when the consumer op is
|
||||
// tiled. In order to greedily fuse the ops, we have to start from the tiled
|
||||
@ -40,6 +46,7 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
|
||||
for (auto func_arg : func.getArguments()) {
|
||||
func_args.insert(func_arg);
|
||||
}
|
||||
OperationFolder state(func.getContext());
|
||||
func.walk([&](linalg::GenericOp generic_op) {
|
||||
const SmallVector<int64_t, 2> tile_sizes(
|
||||
generic_op.getNumInputsAndOutputs(), 1);
|
||||
@ -56,10 +63,6 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
|
||||
// Fuse producers of tiled linalg ops.
|
||||
llvm::SmallDenseSet<Operation*> erase_set;
|
||||
SmallVector<Operation*, 8> linalg_ops;
|
||||
// TODO(pifon): Remove assumption that the function has a single block.
|
||||
if (func.getBlocks().size() != 1) {
|
||||
emitError(func.getLoc(), "The function needs to have a single block.");
|
||||
}
|
||||
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
|
||||
linalg::Aliases aliases;
|
||||
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
|
||||
|
Loading…
Reference in New Issue
Block a user