[MLIR][XLA] Polish LHLO copy operation removal
In this PR, we - Created a separate file for lhlo-copy-removal pass. - Created a separate test file with dedicated test cases for lhlo-copy-removal pass. - Adapted Hlo-To-Lhlo-Legalization tests. PiperOrigin-RevId: 299315930 Change-Id: Ief1428dc05746aeac4890161efda6173e49ad765
This commit is contained in:
parent
47df0000dc
commit
efff893c70
@ -74,6 +74,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
|
||||
|
@ -237,6 +237,19 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lhlo_copy_removal",
|
||||
srcs = ["transforms/lhlo_copy_removal.cc"],
|
||||
deps = [
|
||||
":lhlo",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_legalize_to_lhlo",
|
||||
srcs = ["transforms/hlo_legalize_to_lhlo.cc"],
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal -split-input-file %s -o - | FileCheck %s -dump-input-on-failure
|
||||
// RUN: tf-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: func @attrs
|
||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
@ -6,69 +6,48 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_result = "xla_hlo.exp"(%tensor_operand)
|
||||
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
// CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_op
|
||||
func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
|
||||
%0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]])
|
||||
return %0 : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_op_long
|
||||
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
|
||||
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
|
||||
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
|
||||
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
|
||||
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
|
||||
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
|
||||
%1 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
|
||||
%1 = xla_hlo.max %arg0, %arg1 : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
%2 = xla_hlo.add %arg0, %1 {name = "maximum.47"} : tensor<4xf32>
|
||||
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
%3 = xla_hlo.min %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
|
||||
%3 = xla_hlo.min %arg0, %arg1 : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.min"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
%4 = xla_hlo.sub %arg1, %3 {name = "maximum.47"} : tensor<4xf32>
|
||||
%4 = xla_hlo.sub %arg1, %3 : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.sub"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
%5 = xla_hlo.mul %2, %4 {name = "maximum.47"} : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[RESULT]])
|
||||
%5 = xla_hlo.mul %2, %4 : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
|
||||
return %5 : tensor<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store
|
||||
func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: memref<f32>) {
|
||||
%0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
tensor_store %0, %arg2 : memref<f32>
|
||||
return
|
||||
}
|
||||
// CHECK: (%[[NEW_ARG0:.*]]: memref<f32>, %[[NEW_ARG1:.*]]: memref<f32>, %[[RESULT:.*]]: memref<f32>)
|
||||
// CHECK-NOT: %[[ALLOC_OPERAND:.*]] = alloc() {temp = true} : memref<f32>
|
||||
// CHECK: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
// CHECK-NOT: "xla_lhlo.copy"(%[[ALLOC_OPERAND]], %[[RESULT]]) : (memref<f32>, memref<f32>) -> ()
|
||||
// CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref<f32>
|
||||
// CHECK: "xla_lhlo.terminator"() : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fusion
|
||||
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
||||
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
|
||||
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
|
||||
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
|
||||
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
|
||||
@ -78,9 +57,11 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
||||
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.mul"(%sum, %tensor_multiplier)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %{{.*}})
|
||||
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
@ -92,7 +73,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.copy"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -104,7 +85,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.exp"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -116,7 +97,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.log"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.log"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.log"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -131,7 +112,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -145,7 +126,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
|
||||
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
|
||||
{comparison_direction = "EQ"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
// CHECK-NEXT: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
// CHECK: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
tensor_store %tensor_result, %result : memref<2x2xi1>
|
||||
return
|
||||
}
|
||||
@ -158,7 +139,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
|
||||
{broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
: (tensor<5xf32>) -> tensor<10x5xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<10x5xf32>
|
||||
return
|
||||
}
|
||||
@ -195,7 +176,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
func @iota(%result: memref<10xi32>) {
|
||||
%tensor_result = "xla_hlo.iota"()
|
||||
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
|
||||
// CHECK-NEXT: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
// CHECK: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
tensor_store %tensor_result, %result : memref<10xi32>
|
||||
return
|
||||
}
|
||||
@ -207,7 +188,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.abs"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -219,7 +200,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.ceil"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -243,7 +224,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.cos"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.cos"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.cos"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -255,7 +236,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.neg"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.neg"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.neg"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -267,7 +248,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.sign"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -279,7 +260,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.sqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -291,7 +272,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.tanh"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -304,7 +285,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
// CHECK: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
93
tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir
Normal file
93
tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir
Normal file
@ -0,0 +1,93 @@
|
||||
// RUN: tf-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: func @remove_simple
|
||||
func @remove_simple(%arg0: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @remove_without_dealloc
|
||||
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @replace_dependency
|
||||
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @keep_copies
|
||||
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @must_not_be_removed
|
||||
func @must_not_be_removed(%arg0: memref<2x2xf32>,
|
||||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @must_be_removed_first
|
||||
func @must_be_removed_first(%arg0: memref<2x2xf32>,
|
||||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @must_be_removed_second
|
||||
func @must_be_removed_second(%arg0: memref<2x2xf32>,
|
||||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
@ -273,14 +273,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
|
||||
// "xla_lhlo.fusion"() ({
|
||||
// %0 = tensor_load %arg1 : memref<2x2xf32>
|
||||
// %1 = tensor_load %arg2 : memref<2x2xf32>
|
||||
// %2 = "xla_hlo.add"(%0, %1) {name = "add"} :
|
||||
// %2 = "xla_hlo.add"(%0, %1) :
|
||||
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// %3 = tensor_load %arg0 : memref<2x2xf32>
|
||||
// %4 = "xla_hlo.mul"(%2, %3) {name = "multiply"} :
|
||||
// %4 = "xla_hlo.mul"(%2, %3) :
|
||||
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// tensor_store %4, %arg3 : memref<2x2xf32>
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// }) {name = "fusion"} : () -> ()
|
||||
// }) : () -> ()
|
||||
// return
|
||||
// }
|
||||
//
|
||||
@ -290,14 +290,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
|
||||
// %arg2: memref<2x2xf32>,
|
||||
// %arg3: memref<2x2xf32>) {
|
||||
// "xla_lhlo.fusion"() ( {
|
||||
// %0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// %0 = alloc() : memref<2x2xf32>
|
||||
// "xla_lhlo.add"(%arg1, %arg2, %0) :
|
||||
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// "xla_lhlo.mul"(%0, %arg0, %arg3) :
|
||||
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// dealloc %0 : memref<2x2xf32>
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// }) {name = "fusion"} : () -> ()
|
||||
// }) : () -> ()
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
@ -305,9 +305,9 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
|
||||
// FuncOp signature conversion example:
|
||||
//
|
||||
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
|
||||
// %1 = xla_hlo.add %arg0, %0 {name = "maximum.47"} : tensor<4xf32>
|
||||
// return %1 : tensor<4xf32>
|
||||
// %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
|
||||
// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
|
||||
// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
|
||||
// }
|
||||
//
|
||||
// Transformed function with an extra argument for the result. The types have
|
||||
@ -316,11 +316,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
|
||||
// func @func_op(%arg0: memref<4xf32>,
|
||||
// %arg1: memref<4xf32>,
|
||||
// %arg2: memref<4xf32>) {
|
||||
// %0 = alloc() {temp = true} : memref<4xf32>
|
||||
// "xla_lhlo.max"(%arg0, %arg1, %0) {name = "maximum.47"} :
|
||||
// %0 = alloc() : memref<4xf32>
|
||||
// %1 = alloc() : memref<4xf32>
|
||||
// "xla_lhlo.max"(%arg0, %arg1, %0) :
|
||||
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
|
||||
// "xla_lhlo.add"(%arg0, %0, %arg2) {name = "maximum.47"} :
|
||||
// "xla_lhlo.add"(%arg0, %0, %1) :
|
||||
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
|
||||
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// dealloc %0 : memref<4xf32>
|
||||
// dealloc %1 : memref<4xf32>
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// }
|
||||
@ -473,57 +476,12 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/// Removes Lhlo.CopyOp that copies from an allocated buffer to the block
|
||||
/// argument. All uses of the buffer are replaced with the block argument.
|
||||
struct RedundantCopiesRemoval : mlir::FunctionPass<RedundantCopiesRemoval> {
|
||||
void runOnFunction() override {
|
||||
llvm::SmallVector<mlir::Operation*, 2> eraseList;
|
||||
getFunction().walk([&](mlir::xla_lhlo::CopyOp copyOp) {
|
||||
auto arguments = copyOp.getOperation()->getBlock()->getArguments();
|
||||
if (std::any_of(arguments.begin(), arguments.end(),
|
||||
[&](mlir::BlockArgument arg) {
|
||||
return copyOp.output() == arg;
|
||||
}) &&
|
||||
std::none_of(arguments.begin(), arguments.end(),
|
||||
[&](mlir::BlockArgument arg) {
|
||||
return copyOp.operand() == arg;
|
||||
})) {
|
||||
mlir::Value operand = copyOp.operand();
|
||||
mlir::Value output = copyOp.output();
|
||||
copyOp.erase();
|
||||
for (auto op : operand.getUsers()) {
|
||||
if (!mlir::isa<mlir::DeallocOp>(op)) {
|
||||
op->replaceUsesOfWith(operand, output);
|
||||
}
|
||||
}
|
||||
auto allocOp = operand.getDefiningOp();
|
||||
if (auto deallocOp =
|
||||
mlir::dyn_cast<mlir::DeallocOp>(*allocOp->getUsers().begin())) {
|
||||
eraseList.push_back(deallocOp);
|
||||
eraseList.push_back(allocOp);
|
||||
}
|
||||
}
|
||||
});
|
||||
for (auto op : eraseList) {
|
||||
op->erase();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
|
||||
return absl::make_unique<HloLegalizeToLhlo>();
|
||||
}
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass() {
|
||||
return absl::make_unique<RedundantCopiesRemoval>();
|
||||
}
|
||||
|
||||
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
|
||||
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
|
||||
|
||||
static PassRegistration<RedundantCopiesRemoval> copies_removal_pass(
|
||||
"lhlo-redundant-copies-removal",
|
||||
"Legalize from HLO dialect to LHLO dialect");
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
|
105
tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc
Normal file
105
tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements a pass to remove redundant LHLO copy operations.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace {
|
||||
|
||||
// Removes LHLO copy operations that copy from allocated buffers to block
|
||||
// arguments. All uses of each buffer are replaced with the corresponding block
|
||||
// argument and the buffer is freed. Note that this pass only works in regions
|
||||
// with a single block.
|
||||
struct LhloCopyRemoval : mlir::OperationPass<LhloCopyRemoval> {
|
||||
void runOnOperation() override {
|
||||
llvm::SmallVector<mlir::Operation*, 2> eraseList;
|
||||
auto operation = getOperation();
|
||||
operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
|
||||
// If this region contains more than one block, then ignore this copy
|
||||
// operation.
|
||||
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
mlir::Value fromOperand = copyOp.operand();
|
||||
mlir::Value toOperand = copyOp.output();
|
||||
|
||||
// If the fromOperand value is a block argument or the toOperand
|
||||
// value is not a block argument, then ignore this copy operation.
|
||||
if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The copy operation removal is illegal if there is at least a single use
|
||||
// of toOperand value that lies between the first use of fromOperand value
|
||||
// and the copy operation.
|
||||
auto fromOperandUsers = fromOperand.getUsers();
|
||||
auto firstUser = *fromOperandUsers.begin();
|
||||
for (auto op : fromOperandUsers) {
|
||||
if (op->isBeforeInBlock(firstUser)) firstUser = op;
|
||||
}
|
||||
for (auto op : toOperand.getUsers()) {
|
||||
if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(DFKI): Use live variable analysis to solve aliasing issues among
|
||||
// block arguments.
|
||||
|
||||
// Remove the associated alloc operation.
|
||||
auto allocOp = fromOperand.getDefiningOp();
|
||||
eraseList.push_back(allocOp);
|
||||
|
||||
// Iterate over all uses of the fromOperand to find the associated
|
||||
// deallocOp (if any).
|
||||
for (auto op : fromOperandUsers) {
|
||||
if (isa<mlir::DeallocOp>(op)) {
|
||||
eraseList.push_back(op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Replace all uses of the fromOperand with the toOperand. This rewires
|
||||
// all references pointing to the original alloc operation to the new
|
||||
// target operation in order to safely remove the copy op.
|
||||
fromOperand.replaceAllUsesWith(toOperand);
|
||||
copyOp.erase();
|
||||
});
|
||||
for (auto op : eraseList) {
|
||||
op->erase();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createLhloCopyRemovalPass() {
|
||||
return absl::make_unique<LhloCopyRemoval>();
|
||||
}
|
||||
|
||||
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
|
||||
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace mlir
|
@ -29,6 +29,7 @@ class ModuleOp;
|
||||
class Operation;
|
||||
template <typename T>
|
||||
class OpPassBase;
|
||||
class Pass;
|
||||
|
||||
namespace xla_hlo {
|
||||
|
||||
@ -59,11 +60,6 @@ std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass();
|
||||
// Lowers from HLO dialect to Linalg dialect.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeHloToLinalgPass();
|
||||
|
||||
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
|
||||
// block arguments. These copies have been created by replacing TensorStoreOp
|
||||
// with LHLO.CopyOp in HLO to LHLO lowering.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass();
|
||||
|
||||
} // namespace xla_hlo
|
||||
|
||||
namespace xla_lhlo {
|
||||
@ -89,6 +85,12 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToGpuPass();
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createLhloFuseLinalg(
|
||||
bool use_parallel_loops = false, ArrayRef<unsigned> tile_sizes = {});
|
||||
|
||||
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
|
||||
// block arguments. The block arguments are used instead of all uses of these
|
||||
// buffers. The buffers are freed. This pass only works in regions that contain
|
||||
// a single block.
|
||||
std::unique_ptr<Pass> createLhloCopyRemovalPass();
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -148,6 +148,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
|
||||
|
@ -277,7 +277,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) {
|
||||
// Next, we can strip the outer fusion operation.
|
||||
pm.addPass(absl::make_unique<FusionOpRemover>());
|
||||
// Remove unnecessary Lhlo copies.
|
||||
pm.addPass(::mlir::xla_hlo::createLhloCopyRemovalPass());
|
||||
pm.addPass(::mlir::xla_lhlo::createLhloCopyRemovalPass());
|
||||
// Transform lhlo operations to LinAlg.
|
||||
pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass());
|
||||
// Fuse linalg operations. This will yield a single tiled loop nest where
|
||||
|
Loading…
Reference in New Issue
Block a user