[MLIR] Clean-up LHLO linalg fusion again.

PiperOrigin-RevId: 275197163
Change-Id: I92644a0f90f84a0ab17c1404477e0bf00de8cbe2
This commit is contained in:
Alexander Belyaev 2019-10-17 00:08:22 -07:00 committed by TensorFlower Gardener
parent cd7458379f
commit 72e9dd3b5d
2 changed files with 17 additions and 14 deletions
tensorflow/compiler/mlir/xla

View File

@ -2,19 +2,19 @@
#map0 = (d0, d1) -> (d0, d1)
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], n_loop_types = [2, 0, 0], n_views = [2, 1]}
func @fusion(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>, %arg3: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
linalg.generic #pointwise_2d_trait %arg1, %arg2, %0 {
^bb0(%block_arg0: f32, %block_arg1: f32, %block_arg2: f32):
%1 = addf %block_arg0, %block_arg1 : f32
func @fusion(%input0: memref<2x2xf32>, %input1: memref<2x2xf32>, %input2: memref<2x2xf32>, %result: memref<2x2xf32>) {
%temp_result = alloc() {temp = true} : memref<2x2xf32>
linalg.generic #pointwise_2d_trait %input1, %input2, %temp_result {
^bb0(%input1_in: f32, %input2_in: f32, %temp_result_in: f32):
%1 = addf %input1_in, %input2_in : f32
linalg.yield %1 : f32
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
linalg.generic #pointwise_2d_trait %0, %arg0, %arg3 {
^bb0(%block_arg0: f32, %block_arg1: f32, %block_arg2: f32):
%1 = mulf %block_arg0, %block_arg1 : f32
linalg.generic #pointwise_2d_trait %temp_result, %input0, %result {
^bb0(%temp_result_in: f32, %input0_in: f32, %result_in: f32):
%1 = mulf %temp_result_in, %input0_in : f32
linalg.yield %1 : f32
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
dealloc %0 : memref<2x2xf32>
dealloc %temp_result : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
}

View File

@ -30,7 +30,13 @@ using linalg::LinalgOp;
struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
void runOnFunction() override {
auto func = getFunction();
OperationFolder state(func.getContext());
// TODO(pifon): Remove assumption that the function has a single block.
if (func.getBlocks().size() != 1) {
emitError(func.getLoc(), "The function needs to have a single block.");
signalPassFailure();
return;
}
// The fusion in Linalg is currently possible only when the consumer op is
// tiled. In order to greedily fuse the ops, we have to start from the tiled
@ -40,6 +46,7 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
for (auto func_arg : func.getArguments()) {
func_args.insert(func_arg);
}
OperationFolder state(func.getContext());
func.walk([&](linalg::GenericOp generic_op) {
const SmallVector<int64_t, 2> tile_sizes(
generic_op.getNumInputsAndOutputs(), 1);
@ -56,10 +63,6 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
// Fuse producers of tiled linalg ops.
llvm::SmallDenseSet<Operation*> erase_set;
SmallVector<Operation*, 8> linalg_ops;
// TODO(pifon): Remove assumption that the function has a single block.
if (func.getBlocks().size() != 1) {
emitError(func.getLoc(), "The function needs to have a single block.");
}
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);