[MLIR][XLA] Polish LHLO copy operation removal

In this PR, we
- Created a separate file for the lhlo-copy-removal pass.
- Created a separate test file with dedicated test cases for the lhlo-copy-removal pass.
- Adapted the HLO-to-LHLO legalization tests.
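
For reference, a minimal before/after sketch of the rewrite the new pass
performs, mirroring the @replace_dependency case in the new test file:

// Before -lhlo-copy-removal: the result lands in a temporary buffer that is
// then copied into the target block argument.
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
  %0 = alloc() {temp = true} : memref<2x2xf32>
  "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
  "xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
  dealloc %0 : memref<2x2xf32>
  "xla_lhlo.terminator"() : () -> ()
}

// After -lhlo-copy-removal: the temporary buffer, the copy, and the dealloc
// are gone; the op writes directly into the block argument.
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
  "xla_lhlo.exp"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
  "xla_lhlo.terminator"() : () -> ()
}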

PiperOrigin-RevId: 299315930
Change-Id: Ief1428dc05746aeac4890161efda6173e49ad765
Alexander Belyaev 2020-03-06 03:04:30 -08:00 committed by TensorFlower Gardener
parent 47df0000dc
commit efff893c70
9 changed files with 266 additions and 112 deletions


@@ -74,6 +74,7 @@ cc_library(
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",


@@ -237,6 +237,19 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "lhlo_copy_removal",
srcs = ["transforms/lhlo_copy_removal.cc"],
deps = [
":lhlo",
"@com_google_absl//absl/memory",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,
)
cc_library(
name = "hlo_legalize_to_lhlo",
srcs = ["transforms/hlo_legalize_to_lhlo.cc"],


@@ -1,4 +1,4 @@
// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal -split-input-file %s -o - | FileCheck %s -dump-input-on-failure
// RUN: tf-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -6,69 +6,48 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_result = "xla_hlo.exp"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
// CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @func_op
func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
%0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]])
return %0 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
%1 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
%1 = xla_hlo.max %arg0, %arg1 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
%2 = xla_hlo.add %arg0, %1 {name = "maximum.47"} : tensor<4xf32>
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
%3 = xla_hlo.min %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
%3 = xla_hlo.min %arg0, %arg1 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.min"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
%4 = xla_hlo.sub %arg1, %3 {name = "maximum.47"} : tensor<4xf32>
%4 = xla_hlo.sub %arg1, %3 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.sub"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
%5 = xla_hlo.mul %2, %4 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[RESULT]])
%5 = xla_hlo.mul %2, %4 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
return %5 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store
func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: memref<f32>) {
%0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
tensor_store %0, %arg2 : memref<f32>
return
}
// CHECK: (%[[NEW_ARG0:.*]]: memref<f32>, %[[NEW_ARG1:.*]]: memref<f32>, %[[RESULT:.*]]: memref<f32>)
// CHECK-NOT: %[[ALLOC_OPERAND:.*]] = alloc() {temp = true} : memref<f32>
// CHECK: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) : (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK-NOT: "xla_lhlo.copy"(%[[ALLOC_OPERAND]], %[[RESULT]]) : (memref<f32>, memref<f32>) -> ()
// CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref<f32>
// CHECK: "xla_lhlo.terminator"() : () -> ()
// -----
// CHECK-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
@@ -78,9 +57,11 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "xla_hlo.mul"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %{{.*}})
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
@@ -92,7 +73,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -104,7 +85,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exp"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exp"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -116,7 +97,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.log"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -131,7 +112,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -145,7 +126,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// CHECK-NEXT: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
// CHECK: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
tensor_store %tensor_result, %result : memref<2x2xi1>
return
}
@@ -158,7 +139,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<10x5xf32>
return
}
@@ -195,7 +176,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
func @iota(%result: memref<10xi32>) {
%tensor_result = "xla_hlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// CHECK-NEXT: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
// CHECK: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
return
}
@@ -207,7 +188,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -219,7 +200,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -243,7 +224,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.cos"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.cos"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.cos"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -255,7 +236,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.neg"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.neg"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.neg"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -267,7 +248,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -279,7 +260,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -291,7 +272,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -304,7 +285,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
// CHECK: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}


@@ -0,0 +1,93 @@
// RUN: tf-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @remove_simple
func @remove_simple(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @remove_without_dealloc
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @replace_dependency
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @keep_copies
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
// CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @must_not_be_removed
func @must_not_be_removed(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exp"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @must_be_removed_first
func @must_be_removed_first(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @must_be_removed_second
func @must_be_removed_second(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
}


@@ -273,14 +273,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
// "xla_lhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "xla_hlo.add"(%0, %1) {name = "add"} :
// %2 = "xla_hlo.add"(%0, %1) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// %3 = tensor_load %arg0 : memref<2x2xf32>
// %4 = "xla_hlo.mul"(%2, %3) {name = "multiply"} :
// %4 = "xla_hlo.mul"(%2, %3) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// }) {name = "fusion"} : () -> ()
// }) : () -> ()
// return
// }
//
@@ -290,14 +290,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ( {
// %0 = alloc() {temp = true} : memref<2x2xf32>
// %0 = alloc() : memref<2x2xf32>
// "xla_lhlo.add"(%arg1, %arg2, %0) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.mul"(%0, %arg0, %arg3) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// dealloc %0 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// }) {name = "fusion"} : () -> ()
// }) : () -> ()
// return
// }
// }
@@ -305,9 +305,9 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
// %1 = xla_hlo.add %arg0, %0 {name = "maximum.47"} : tensor<4xf32>
// return %1 : tensor<4xf32>
// %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
// }
//
// Transformed function with an extra argument for the result. The types have
@@ -316,11 +316,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
// func @func_op(%arg0: memref<4xf32>,
// %arg1: memref<4xf32>,
// %arg2: memref<4xf32>) {
// %0 = alloc() {temp = true} : memref<4xf32>
// "xla_lhlo.max"(%arg0, %arg1, %0) {name = "maximum.47"} :
// %0 = alloc() : memref<4xf32>
// %1 = alloc() : memref<4xf32>
// "xla_lhlo.max"(%arg0, %arg1, %0) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.add"(%arg0, %0, %arg2) {name = "maximum.47"} :
// "xla_lhlo.add"(%arg0, %0, %1) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// dealloc %0 : memref<4xf32>
// dealloc %1 : memref<4xf32>
// "xla_lhlo.terminator"() : () -> ()
// }
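//
// A sketch (assumed, not verified output) of what the new standalone
// -lhlo-copy-removal pass then does to this function: the trailing
// "xla_lhlo.copy" is folded away by writing the last result directly into
// the result argument, and the now-unused buffer %1 and its dealloc are
// erased:
//
// func @func_op(%arg0: memref<4xf32>,
//               %arg1: memref<4xf32>,
//               %arg2: memref<4xf32>) {
//   %0 = alloc() : memref<4xf32>
//   "xla_lhlo.max"(%arg0, %arg1, %0) :
//       (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   "xla_lhlo.add"(%arg0, %0, %arg2) :
//       (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   dealloc %0 : memref<4xf32>
//   "xla_lhlo.terminator"() : () -> ()
// }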
@@ -473,57 +476,12 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
// clang-format on
}
/// Removes Lhlo.CopyOp that copies from an allocated buffer to the block
/// argument. All uses of the buffer are replaced with the block argument.
struct RedundantCopiesRemoval : mlir::FunctionPass<RedundantCopiesRemoval> {
void runOnFunction() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList;
getFunction().walk([&](mlir::xla_lhlo::CopyOp copyOp) {
auto arguments = copyOp.getOperation()->getBlock()->getArguments();
if (std::any_of(arguments.begin(), arguments.end(),
[&](mlir::BlockArgument arg) {
return copyOp.output() == arg;
}) &&
std::none_of(arguments.begin(), arguments.end(),
[&](mlir::BlockArgument arg) {
return copyOp.operand() == arg;
})) {
mlir::Value operand = copyOp.operand();
mlir::Value output = copyOp.output();
copyOp.erase();
for (auto op : operand.getUsers()) {
if (!mlir::isa<mlir::DeallocOp>(op)) {
op->replaceUsesOfWith(operand, output);
}
}
auto allocOp = operand.getDefiningOp();
if (auto deallocOp =
mlir::dyn_cast<mlir::DeallocOp>(*allocOp->getUsers().begin())) {
eraseList.push_back(deallocOp);
eraseList.push_back(allocOp);
}
}
});
for (auto op : eraseList) {
op->erase();
}
};
};
std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
return absl::make_unique<HloLegalizeToLhlo>();
}
std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass() {
return absl::make_unique<RedundantCopiesRemoval>();
}
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
static PassRegistration<RedundantCopiesRemoval> copies_removal_pass(
"lhlo-redundant-copies-removal",
"Legalize from HLO dialect to LHLO dialect");
} // namespace xla_hlo
} // namespace mlir


@@ -0,0 +1,105 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements a pass to remove redundant LHLO copy operations.
#include "absl/memory/memory.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
namespace mlir {
namespace xla_lhlo {
namespace {
// Removes LHLO copy operations that copy from allocated buffers to block
// arguments. All uses of each buffer are replaced with the corresponding block
// argument and the buffer is freed. Note that this pass only works in regions
// with a single block.
struct LhloCopyRemoval : mlir::OperationPass<LhloCopyRemoval> {
void runOnOperation() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList;
auto operation = getOperation();
operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
// If this region contains more than one block, then ignore this copy
// operation.
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
return;
}
mlir::Value fromOperand = copyOp.operand();
mlir::Value toOperand = copyOp.output();
// If the fromOperand value is a block argument or the toOperand
// value is not a block argument, then ignore this copy operation.
if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) {
return;
}
// The copy operation removal is illegal if there is at least a single use
// of toOperand value that lies between the first use of fromOperand value
// and the copy operation.
auto fromOperandUsers = fromOperand.getUsers();
auto firstUser = *fromOperandUsers.begin();
for (auto op : fromOperandUsers) {
if (op->isBeforeInBlock(firstUser)) firstUser = op;
}
for (auto op : toOperand.getUsers()) {
if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) {
return;
}
}
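// As a worked example (assumed, mirroring the @must_not_be_removed test):
//
//   %0 = alloc() {temp = true} : memref<2x2xf32>
//   "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
//   "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
//   "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// The second exp writes to %arg2 after the first use of %0 and before the
// copy, so removing the copy would clobber that write; the copy is kept.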
// TODO(DFKI): Use live variable analysis to solve aliasing issues among
// block arguments.
// Remove the associated alloc operation.
auto allocOp = fromOperand.getDefiningOp();
eraseList.push_back(allocOp);
// Iterate over all uses of the fromOperand to find the associated
// deallocOp (if any).
for (auto op : fromOperandUsers) {
if (isa<mlir::DeallocOp>(op)) {
eraseList.push_back(op);
break;
}
}
// Replace all uses of the fromOperand with the toOperand. This rewires
// all references pointing to the original alloc operation to the new
// target operation in order to safely remove the copy op.
fromOperand.replaceAllUsesWith(toOperand);
copyOp.erase();
});
for (auto op : eraseList) {
op->erase();
}
};
};
} // namespace
std::unique_ptr<Pass> createLhloCopyRemovalPass() {
return absl::make_unique<LhloCopyRemoval>();
}
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
} // namespace xla_lhlo
} // namespace mlir


@@ -29,6 +29,7 @@ class ModuleOp;
class Operation;
template <typename T>
class OpPassBase;
class Pass;
namespace xla_hlo {
@@ -59,11 +60,6 @@ std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass();
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeHloToLinalgPass();
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
// block arguments. These copies have been created by replacing TensorStoreOp
// with LHLO.CopyOp in HLO to LHLO lowering.
std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass();
} // namespace xla_hlo
namespace xla_lhlo {
@@ -89,6 +85,12 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToGpuPass();
std::unique_ptr<OpPassBase<FuncOp>> createLhloFuseLinalg(
bool use_parallel_loops = false, ArrayRef<unsigned> tile_sizes = {});
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
// block arguments. The block arguments are used instead of all uses of these
// buffers. The buffers are freed. This pass only works in regions that contain
// a single block.
std::unique_ptr<Pass> createLhloCopyRemovalPass();
} // namespace xla_lhlo
} // namespace mlir


@@ -148,6 +148,7 @@ cc_library(
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",


@@ -277,7 +277,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) {
// Next, we can strip the outer fusion operation.
pm.addPass(absl::make_unique<FusionOpRemover>());
// Remove unnecessary Lhlo copies.
pm.addPass(::mlir::xla_hlo::createLhloCopyRemovalPass());
pm.addPass(::mlir::xla_lhlo::createLhloCopyRemovalPass());
// Transform lhlo operations to LinAlg.
pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass());
// Fuse linalg operations. This will yield a single tiled loop nest where