Extend lowering from HLO to LHLO to also support buffer allocation with escaping result buffers. The behavior is controlled by a new pass flag that defaults to the current preallocation behavior.

PiperOrigin-RevId: 316660810
Change-Id: I89e46b494d09acf2dbe14b300ee5b9df431ab09c
Author: Stephan Herhut, 2020-06-16 05:20:07 -07:00 (committed by TensorFlower Gardener)
parent a5ebf37c1d
commit fb7fe20ec7
3 changed files with 203 additions and 166 deletions
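To make the two modes concrete, here is a minimal hand-assembled sketch of how the same HLO function lowers under each setting. It mirrors the @return_func test case in the diff below rather than being verbatim pass output:

// HLO input, on tensors:
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  return %arg0 : tensor<4xf32>
}

// Default (preallocated results): the tensor result becomes an extra memref
// argument, and the lowered code copies the value into that output buffer.
func @return_func(%arg0: memref<4xf32>, %result: memref<4xf32>) {
  "xla_lhlo.copy"(%arg0, %result) : (memref<4xf32>, memref<4xf32>) -> ()
  return
}

// With results-escape-function=true: the function keeps a memref result and
// returns the buffer directly; no copy is inserted.
func @return_func(%arg0: memref<4xf32>) -> memref<4xf32> {
  return %arg0 : memref<4xf32>
}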

[File 1/3: MLIR test for the hlo-legalize-to-lhlo pass]

@@ -1,12 +1,13 @@
-// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck %s
+// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s
+// RUN: xla-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s

-// CHECK-LABEL: func @attrs
+// BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.exponential"(%tensor_operand)
      {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
+  // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}
@@ -16,13 +17,16 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  return %arg0 : tensor<4xf32>
}
-// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
-// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
-// CHECK-NEXT: return
+// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
+// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
+// PRE-NEXT: return
+// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
+// ESC-NOT: "xla_lhlo.copy"
+// ESC-NEXT: return %[[ARG0]]

// -----

-// CHECK-LABEL: func @func_op_long
+// BOTH-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
  %2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
@@ -31,89 +35,91 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
  %5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
  return %5 : tensor<4xf32>
}
-// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
-// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
-// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
-// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
-// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
-// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
-// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
-// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
-// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
-// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
-// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
-// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
-// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
-// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
-// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
-// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
-// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
-// CHECK-NEXT: return
+// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
+// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
+// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
+// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
+// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
+// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
+// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
+// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
+// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
+// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
+// BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
+// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
+// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
+// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
+// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
+// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
+// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
+// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
+// PRE-NEXT: return
+// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>

// -----

-// CHECK-LABEL: func @fusion
+// BOTH-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
             %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
-  // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
-  // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
+  // BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
+  // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
  %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
  %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
  %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
-  // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
+  // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
+  // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
  %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
  %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
-  // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
-  // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
+  // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
+  // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
+  // BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
  tensor_store %tensor_result, %result : memref<2x2xf32>
-  // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
-  // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
-  "xla_lhlo.terminator"() : () -> ()
+  // BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
+  // BOTH-NEXT: return
+  return
}

// -----

-// CHECK-LABEL: func @copy
+// BOTH-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.copy"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @exp
+// BOTH-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.exponential"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @log
+// BOTH-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.log"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.log"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @select
+// BOTH-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
             %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_pred = tensor_load %pred : memref<2x2xi1>
@@ -121,34 +127,34 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
      : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @compare
+// BOTH-LABEL: func @compare
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
      {comparison_direction = "EQ"}
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
-  // CHECK: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
+  // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
  tensor_store %tensor_result, %result : memref<2x2xi1>
  return
}

// -----

-// CHECK-LABEL: func @broadcast
+// BOTH-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
  %tensor_operand = tensor_load %operand : memref<5xf32>
  %tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
      : (tensor<5xf32>) -> tensor<10x5xf32>
-  // CHECK: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
  tensor_store %tensor_result, %result : memref<10x5xf32>
  return
}
@@ -157,55 +163,55 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
func @external_func() -> tensor<3xi64>

-// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>

-// CHECK-LABEL: func @dyn_broadcast
+// BOTH-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
-  // CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
+  // BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
  %tensor_operand = tensor_load %operand : memref<?x?xf32>
  %shape = call @external_func() : () -> tensor<3xi64>
  %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
  } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = call @external_func()
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
-  // CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
-  // CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
-  // CHECK: %[[C2:.*]] = constant 2 : index
-  // CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
-  // CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
-  // CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
-  // CHECK: %[[C0_:.*]] = constant 0 : index
-  // CHECK: %[[C1_:.*]] = constant 1 : index
-  // CHECK: %[[C1__:.*]] = constant 1 : index
-  // CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
-  // CHECK: %[[C0___:.*]] = constant 0 : index
-  // CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
-  // CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
-  // CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
-  // CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
-  // CHECK: %[[C2_:.*]] = constant 2 : index
-  // CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
-  // CHECK: %[[C1___:.*]] = constant 1 : index
-  // CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
-  // CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
-  // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
-  // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
-  // CHECK: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
-  // CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
-  // CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
-  // CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
-  // CHECK: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
-  // CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
-  // CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
+  // BOTH: %[[SHAPE:.*]] = call @external_func()
+  // BOTH: %[[C0:.*]] = constant 0 : index
+  // BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
+  // BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
+  // BOTH: %[[C1:.*]] = constant 1 : index
+  // BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
+  // BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
+  // BOTH: %[[C2:.*]] = constant 2 : index
+  // BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
+  // BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
+  // BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
+  // BOTH: %[[C0_:.*]] = constant 0 : index
+  // BOTH: %[[C1_:.*]] = constant 1 : index
+  // BOTH: %[[C1__:.*]] = constant 1 : index
+  // BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
+  // BOTH: %[[C0___:.*]] = constant 0 : index
+  // BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
+  // BOTH: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
+  // BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
+  // BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
+  // BOTH: %[[C2_:.*]] = constant 2 : index
+  // BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
+  // BOTH: %[[C1___:.*]] = constant 1 : index
+  // BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
+  // BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
+  // BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
+  // BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
+  // BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
+  // BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
+  // BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
+  // BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
+  // BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
+  // BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
+  // BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
  // Do not store the value back to avoid the tensor-store being rewritten to
  // a copy into the pre-allocated argument.
@@ -214,7 +220,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// -----

-// CHECK-LABEL: func @complex
+// BOTH-LABEL: func @complex
func @complex(%real: memref<2x2xf32>,
              %imag: memref<2x2xf32>,
              %result: memref<2x2xcomplex<f32>>) {
@@ -222,164 +228,164 @@ func @complex(%real: memref<2x2xf32>,
  %tensor_imag = tensor_load %imag : memref<2x2xf32>
  %tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
-  // CHECK: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
  return
}

// -----

-// CHECK-LABEL: func @real
+// BOTH-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
  %tensor_result = "xla_hlo.real"(%tensor_operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.real"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @imag
+// BOTH-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
  %tensor_result = "xla_hlo.imag"(%tensor_operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @iota
+// BOTH-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
  %tensor_result = "xla_hlo.iota"()
      {iota_dimension = 0 : i64} : () -> tensor<10xi32>
-  // CHECK: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
+  // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
  tensor_store %tensor_result, %result : memref<10xi32>
  return
}

// -----

-// CHECK-LABEL: func @abs
+// BOTH-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.abs"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @ceil
+// BOTH-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.ceil"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @convert
+// BOTH-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.convert"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
-  // CHECK-NOT: tensor_store
+  // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
+  // BOTH-NOT: tensor_store
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @cos
+// BOTH-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.cosine"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @neg
+// BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.negate"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @rsqrt
+// BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @sign
+// BOTH-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.sign"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @sqrt
+// BOTH-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.sqrt"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @tanh
+// BOTH-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.tanh"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}

// -----

-// CHECK-LABEL: func @remainder
+// BOTH-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
+  // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
  return
}
@@ -387,76 +393,79 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// -----

// Dynamic shape binary element-wise operation.
-// CHECK-LABEL: func @add_dyn
+// BOTH-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
  %result = "xla_hlo.add"(%lhs, %rhs)
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
-  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
-  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
-  // CHECK: %[[C0_:.*]] = constant 0 : index
-  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
-  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-  // CHECK: %[[C1_:.*]] = constant 1 : index
-  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
-  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
-  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
-  // CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
+  // BOTH: %[[C0:.*]] = constant 0 : index
+  // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
+  // BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+  // BOTH: %[[C1:.*]] = constant 1 : index
+  // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
+  // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+  // BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
+  // BOTH: %[[C0_:.*]] = constant 0 : index
+  // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+  // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+  // BOTH: %[[C1_:.*]] = constant 1 : index
+  // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+  // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+  // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return
}

// -----

// Dynamic shape unary element-wise operation.
-// CHECK-LABEL: func @tanh_dyn
+// BOTH-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
  %result = "xla_hlo.tanh"(%arg0)
      : (tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
-  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
-  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
-  // CHECK: %[[C0_:.*]] = constant 0 : index
-  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
-  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-  // CHECK: %[[C1_:.*]] = constant 1 : index
-  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
-  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
-  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
-  // CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
+  // BOTH: %[[C0:.*]] = constant 0 : index
+  // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
+  // BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+  // BOTH: %[[C1:.*]] = constant 1 : index
+  // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
+  // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+  // BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
+  // BOTH: %[[C0_:.*]] = constant 0 : index
+  // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+  // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+  // BOTH: %[[C1_:.*]] = constant 1 : index
+  // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+  // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+  // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  return
}

// -----

-// CHECK-LABEL: func @dot
+// BOTH-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
-// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]],
-// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]])
-// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
+// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
+// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
+// BOTH-NEXT: %[[ALLOC:.*]] = alloc
+// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
  %dot = "xla_hlo.dot"(%arg0, %arg0)
      : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
+// ESC: return %[[ALLOC]]
  return %dot : tensor<1024x1024xf32>
}

// -----

-// CHECK-LABEL: func @conv
+// BOTH-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
  %c0 = constant 0 : index
-  // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
-  // CHECK: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
-  // CHECK-SAME: padding = dense<[
-  // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
-  // CHECK-SAME: rhs_dilation = dense<[1, 2]>
-  // CHECK-SAME: window_strides = dense<[2, 1]>
+  // BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
+  // BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
+  // BOTH-SAME: padding = dense<[
+  // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
+  // BOTH-SAME: rhs_dilation = dense<[1, 2]>
+  // BOTH-SAME: window_strides = dense<[2, 1]>
  %out = "xla_hlo.convolution"(%filter, %input) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
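The ownership difference matters most for computed results such as @dot above: by default the pass allocates a temporary, copies it into the caller-provided output buffer, and frees the temporary, while in escaping mode the allocation itself is returned and deallocation becomes the caller's responsibility. Below is a minimal sketch of the two lowered forms, assembled from the PRE/ESC check lines (the trailing dealloc follows the pattern checked in @func_op_long) rather than copied from actual pass output:

// Default (PRE): write into a temporary, copy it to %result, free the temporary.
func @dot(%arg0: memref<1024x1024xf32>, %result: memref<1024x1024xf32>) {
  %0 = alloc() : memref<1024x1024xf32>
  "xla_lhlo.dot"(%arg0, %arg0, %0)
      : (memref<1024x1024xf32>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> ()
  "xla_lhlo.copy"(%0, %result) : (memref<1024x1024xf32>, memref<1024x1024xf32>) -> ()
  dealloc %0 : memref<1024x1024xf32>
  return
}

// Escaping (ESC): return the allocated buffer; no copy and no dealloc here.
func @dot(%arg0: memref<1024x1024xf32>) -> memref<1024x1024xf32> {
  %0 = alloc() : memref<1024x1024xf32>
  "xla_lhlo.dot"(%arg0, %arg0, %0)
      : (memref<1024x1024xf32>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> ()
  return %0 : memref<1024x1024xf32>
}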

[File 2/3: C++ implementation of the hlo-legalize-to-lhlo pass]

@@ -368,6 +368,15 @@ class HloToLhloTensorStoreOpConverter
struct HloLegalizeToLhlo
    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
+ public:
+  HloLegalizeToLhlo() = default;
+  HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
+    this->results_escape_function = o.results_escape_function.getValue();
+  }
+  explicit HloLegalizeToLhlo(bool results_escape_function) {
+    this->results_escape_function.setValue(results_escape_function);
+  }
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    auto& context = getContext();
@@ -398,10 +407,28 @@ struct HloLegalizeToLhlo
      OwningRewritePatternList patterns;
      populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
                                         &converter, &patterns);
+      if (results_escape_function) {
+        populateWithBufferAssignmentOpConversionPatterns<
+            mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
+            /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
+                                                 &converter, &patterns);
+      } else {
+        populateWithBufferAssignmentOpConversionPatterns<
+            mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
+            /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
+                                                  &converter, &patterns);
+      }
      return WalkResult(
          applyPartialConversion(func, target, patterns, &converter));
    });
  }

+ private:
+  Option<bool> results_escape_function{
+      *this, "results-escape-function",
+      llvm::cl::desc(
+          "Allocate the results of functions within the functions body"),
+      llvm::cl::init(false)};
};

} // namespace
@@ -446,14 +473,11 @@
      HloToLhloTensorStoreOpConverter
  >(context, bufferAssignment, converter);
  // clang-format on
-  populateWithBufferAssignmentOpConversionPatterns<
-      mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
-      /*allowMemrefFunctionResults=*/false>(context, bufferAssignment,
-                                            converter, patterns);
}

-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
-  return absl::make_unique<HloLegalizeToLhlo>();
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
+    bool results_escape_function) {
+  return absl::make_unique<HloLegalizeToLhlo>(results_escape_function);
}

static PassRegistration<HloLegalizeToLhlo> legalize_pass(

[File 3/3: pass declarations header]

@@ -59,9 +59,13 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();

-// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
-// buffers if necessary.
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
+/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
+/// buffers if necessary. If `results_escape_functions` is set to true,
+/// allocated buffers for function results will be returned and escape the
+/// function. Otherwise, the signature is rewritten with extra arguments for the
+/// buffers that are to be used for results.
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
+    bool results_escape_functions = false);

// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();