diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 90c60a85ba2..22e665cc8ce 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -74,6 +74,7 @@ cc_library( "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/mlir/xla:lhlo_copy_removal", "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu", diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 7352024fb81..5a05a0f60f3 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -237,6 +237,19 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "lhlo_copy_removal", + srcs = ["transforms/lhlo_copy_removal.cc"], + deps = [ + ":lhlo", + "@com_google_absl//absl/memory", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + ], + alwayslink = 1, +) + cc_library( name = "hlo_legalize_to_lhlo", srcs = ["transforms/hlo_legalize_to_lhlo.cc"], diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 31833b11de6..605e806831c 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal -split-input-file %s -o - | FileCheck %s -dump-input-on-failure +// RUN: tf-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -6,69 +6,48 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_result = "xla_hlo.exp"(%tensor_operand) {some_attr_1 = "exp.1", 
some_attr_2 = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + // CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// CHECK-LABEL: func @func_op -func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) - %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) - return %0 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () -} - -// ----- - // CHECK-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) + // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - %1 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32> + %1 = xla_hlo.max %arg0, %arg1 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) - %2 = xla_hlo.add %arg0, %1 {name = "maximum.47"} : tensor<4xf32> + %2 = xla_hlo.add %arg0, %1 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) - %3 = xla_hlo.min %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32> + %3 = xla_hlo.min %arg0, %arg1 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.min"(%[[NEW_ARG0]], %[[NEW_ARG1]], 
%[[MIN_RESULT]]) - %4 = xla_hlo.sub %arg1, %3 {name = "maximum.47"} : tensor<4xf32> + %4 = xla_hlo.sub %arg1, %3 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.sub"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) - %5 = xla_hlo.mul %2, %4 {name = "maximum.47"} : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[RESULT]]) + %5 = xla_hlo.mul %2, %4 : tensor<4xf32> + // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> + // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () + // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> return %5 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () } // ----- -// CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store -func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor, %arg1: tensor, %arg2: memref) { - %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor, tensor) -> tensor - tensor_store %0, %arg2 : memref - return -} -// CHECK: (%[[NEW_ARG0:.*]]: memref, %[[NEW_ARG1:.*]]: memref, %[[RESULT:.*]]: memref) -// CHECK-NOT: %[[ALLOC_OPERAND:.*]] = alloc() {temp = true} : memref -// CHECK: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) : (memref, memref, memref) -> () -// CHECK-NOT: "xla_lhlo.copy"(%[[ALLOC_OPERAND]], %[[RESULT]]) : (memref, memref) -> () -// CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref -// CHECK: "xla_lhlo.terminator"() : () -> () - -// ----- - // CHECK-LABEL: func @fusion func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { + // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}}) + // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : 
memref<2x2xf32> // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> @@ -78,9 +57,11 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_result = "xla_hlo.mul"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %{{.*}}) + // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> + // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () "xla_lhlo.terminator"() : () -> () } @@ -92,7 +73,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.copy"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -104,7 +85,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.exp"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -116,7 +97,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.log"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.log"(%{{.*}}, %{{.*}}) 
+ // CHECK: "xla_lhlo.log"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -131,7 +112,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -145,7 +126,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x %tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> - // CHECK-NEXT: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} + // CHECK: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} tensor_store %tensor_result, %result : memref<2x2xi1> return } @@ -158,7 +139,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<10x5xf32> - // CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<10x5xf32> return } @@ -195,7 +176,7 @@ func @dyn_broadcast(%operand: memref) { func @iota(%result: memref<10xi32>) { %tensor_result = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10xi32> - // CHECK-NEXT: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} + // CHECK: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} tensor_store %tensor_result, %result : memref<10xi32> return } @@ -207,7 +188,7 @@ 
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.abs"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -219,7 +200,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.ceil"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -243,7 +224,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.cos"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.cos"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.cos"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -255,7 +236,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.neg"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.neg"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.neg"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -267,7 +248,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.sign"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -279,7 +260,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: 
memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.sqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -291,7 +272,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.tanh"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -304,7 +285,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir new file mode 100644 index 00000000000..35546594ccb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir @@ -0,0 +1,93 @@ +// RUN: tf-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: func @remove_simple +func @remove_simple(%arg0: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @remove_without_dealloc +func @remove_without_dealloc(%arg0: memref<2x2xf32>) { + %0 = alloc() {temp 
= true} : memref<2x2xf32> + "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @replace_dependency +func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @keep_copies +func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + // CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_not_be_removed +func @must_not_be_removed(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func 
@must_be_removed_first +func @must_be_removed_first(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_be_removed_second +func @must_be_removed_second(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 502b79399a9..f5ebc91d78f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -273,14 +273,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // "xla_lhlo.fusion"() ({ // %0 = tensor_load %arg1 : memref<2x2xf32> // %1 = tensor_load %arg2 : memref<2x2xf32> -// %2 = "xla_hlo.add"(%0, %1) {name = "add"} : +// %2 = "xla_hlo.add"(%0, %1) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // %3 = 
tensor_load %arg0 : memref<2x2xf32> -// %4 = "xla_hlo.mul"(%2, %3) {name = "multiply"} : +// %4 = "xla_hlo.mul"(%2, %3) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () -// }) {name = "fusion"} : () -> () +// }) : () -> () // return // } // @@ -290,14 +290,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // %arg2: memref<2x2xf32>, // %arg3: memref<2x2xf32>) { // "xla_lhlo.fusion"() ( { -// %0 = alloc() {temp = true} : memref<2x2xf32> +// %0 = alloc() : memref<2x2xf32> // "xla_lhlo.add"(%arg1, %arg2, %0) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // "xla_lhlo.mul"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // dealloc %0 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () -// }) {name = "fusion"} : () -> () +// }) : () -> () // return // } // } @@ -305,9 +305,9 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // FuncOp signature conversion example: // // func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -// %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32> -// %1 = xla_hlo.add %arg0, %0 {name = "maximum.47"} : tensor<4xf32> -// return %1 : tensor<4xf32> +// %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> +// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>, +// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> // } // // Transformed function with an extra argument for the result. 
The types have @@ -316,11 +316,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // func @func_op(%arg0: memref<4xf32>, // %arg1: memref<4xf32>, // %arg2: memref<4xf32>) { -// %0 = alloc() {temp = true} : memref<4xf32> -// "xla_lhlo.max"(%arg0, %arg1, %0) {name = "maximum.47"} : +// %0 = alloc() : memref<4xf32> +// %1 = alloc() : memref<4xf32> +// "xla_lhlo.max"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () -// "xla_lhlo.add"(%arg0, %0, %arg2) {name = "maximum.47"} : +// "xla_lhlo.add"(%arg0, %0, %1) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () +// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () +// dealloc %0 : memref<4xf32> // dealloc %1 : memref<4xf32> // "xla_lhlo.terminator"() : () -> () // } @@ -473,57 +476,12 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, // clang-format on } -/// Removes Lhlo.CopyOp that copies from an allocated buffer to the block -/// argument. All uses of the buffer are replaced with the block argument. 
-struct RedundantCopiesRemoval : mlir::FunctionPass<RedundantCopiesRemoval> { - void runOnFunction() override { - llvm::SmallVector<mlir::Operation*, 2> eraseList; - getFunction().walk([&](mlir::xla_lhlo::CopyOp copyOp) { - auto arguments = copyOp.getOperation()->getBlock()->getArguments(); - if (std::any_of(arguments.begin(), arguments.end(), - [&](mlir::BlockArgument arg) { - return copyOp.output() == arg; - }) && - std::none_of(arguments.begin(), arguments.end(), - [&](mlir::BlockArgument arg) { - return copyOp.operand() == arg; - })) { - mlir::Value operand = copyOp.operand(); - mlir::Value output = copyOp.output(); - copyOp.erase(); - for (auto op : operand.getUsers()) { - if (!mlir::isa<mlir::DeallocOp>(op)) { - op->replaceUsesOfWith(operand, output); - } - } - auto allocOp = operand.getDefiningOp(); - if (auto deallocOp = - mlir::dyn_cast<mlir::DeallocOp>(*allocOp->getUsers().begin())) { - eraseList.push_back(deallocOp); - eraseList.push_back(allocOp); - } - } - }); - for (auto op : eraseList) { - op->erase(); - } - }; -}; - std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToLhloPass() { return absl::make_unique<HloLegalizeToLhlo>(); } -std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass() { - return absl::make_unique<RedundantCopiesRemoval>(); -} - static PassRegistration<HloLegalizeToLhlo> legalize_pass( "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect"); -static PassRegistration<RedundantCopiesRemoval> copies_removal_pass( - "lhlo-redundant-copies-removal", - "Legalize from HLO dialect to LHLO dialect"); - } // namespace xla_hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc new file mode 100644 index 00000000000..86125126390 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements a pass to remove redundant LHLO copy operations. + +#include "absl/memory/memory.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +// Removes LHLO copy operations that copy from allocated buffers to block +// arguments. All uses of each buffer are replaced with the corresponding block +// argument and the buffer is freed. Note that this pass only works in regions +// with a single block. +struct LhloCopyRemoval : mlir::OperationPass<LhloCopyRemoval> { + void runOnOperation() override { + llvm::SmallVector<mlir::Operation*, 2> eraseList; + auto operation = getOperation(); + operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) { + // If this region contains more than one block, then ignore this copy + // operation. + if (copyOp.getParentRegion()->getBlocks().size() > 1) { + return; + } + + mlir::Value fromOperand = copyOp.operand(); + mlir::Value toOperand = copyOp.output(); + + // If the fromOperand value is a block argument or the toOperand + // value is not a block argument, then ignore this copy operation.
+ if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) { + return; + } + + // The copy operation removal is illegal if there is at least a single use + // of toOperand value that lies between the first use of fromOperand value + // and the copy operation. + auto fromOperandUsers = fromOperand.getUsers(); + auto firstUser = *fromOperandUsers.begin(); + for (auto op : fromOperandUsers) { + if (op->isBeforeInBlock(firstUser)) firstUser = op; + } + for (auto op : toOperand.getUsers()) { + if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) { + return; + } + } + + // TODO(DFKI): Use live variable analysis to solve aliasing issues among + // block arguments. + + // Remove the associated alloc operation. + auto allocOp = fromOperand.getDefiningOp(); + eraseList.push_back(allocOp); + + // Iterate over all uses of the fromOperand to find the associated + // deallocOp (if any). + for (auto op : fromOperandUsers) { + if (isa<mlir::DeallocOp>(op)) { + eraseList.push_back(op); + break; + } + } + + // Replace all uses of the fromOperand with the toOperand. This rewires + // all references pointing to the original alloc operation to the new + // target operation in order to safely remove the copy op.
+ fromOperand.replaceAllUsesWith(toOperand); + copyOp.erase(); + }); + for (auto op : eraseList) { + op->erase(); + } + }; +}; + +} // namespace + +std::unique_ptr<Pass> createLhloCopyRemovalPass() { + return absl::make_unique<LhloCopyRemoval>(); +} + +static PassRegistration<LhloCopyRemoval> copy_removal_pass( + "lhlo-copy-removal", "Removes redundant LHLO copy operations"); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 8c0ed08fb66..1fb5f6a1328 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -29,6 +29,7 @@ class ModuleOp; class Operation; template <typename T> class OpPassBase; +class Pass; namespace xla_hlo { @@ -59,11 +60,6 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToLhloPass(); // Lowers from HLO dialect to Linalg dialect. std::unique_ptr<OpPassBase<FuncOp>> createLegalizeHloToLinalgPass(); -// Removes unnecessary LHLO copies which copy from the allocated buffers to the -// block arguments. These copies have been created by replacing TensorStoreOp -// with LHLO.CopyOp in HLO to LHLO lowering. -std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass(); - } // namespace xla_hlo namespace xla_lhlo { @@ -89,6 +85,12 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToGpuPass(); std::unique_ptr<OpPassBase<FuncOp>> createLhloFuseLinalg( bool use_parallel_loops = false, ArrayRef<unsigned> tile_sizes = {}); +// Removes unnecessary LHLO copies which copy from the allocated buffers to the +// block arguments. The block arguments are used instead of all uses of these +// buffers. The buffers are freed. This pass only works in regions that contain +// a single block.
+std::unique_ptr<Pass> createLhloCopyRemovalPass(); + } // namespace xla_lhlo } // namespace mlir diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index afceefdeae6..02311c72ee1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -148,6 +148,7 @@ cc_library( "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/mlir/xla:lhlo_copy_removal", "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index e922fe64958..151d82fd2a1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -277,7 +277,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) { // Next, we can strip the outer fusion operation. pm.addPass(absl::make_unique<FusionOpRemover>()); // Remove unnecessary Lhlo copies. - pm.addPass(::mlir::xla_hlo::createLhloCopyRemovalPass()); + pm.addPass(::mlir::xla_lhlo::createLhloCopyRemovalPass()); // Transform lhlo operations to LinAlg. pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass()); // Fuse linalg operations. This will yield a single tiled loop nest where