Add shape constraints to CHLO->HLO lowering.
PiperOrigin-RevId: 316459663 Change-Id: Ifff45b67a039c5a8e7cf8fa1bedf187c33900091
This commit is contained in:
parent
eabae7b8e9
commit
83b4360a3f
@ -1,4 +1,4 @@
|
||||
// RUN: xla-opt -test-xla-chlo-legalize-to-hlo -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
// RUN: xla-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
// Check the non-broadcast case for each registered op, then just check a
|
||||
// representative op for detailed broadcast semantics.
|
||||
@ -14,14 +14,18 @@ func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
|
||||
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
|
||||
func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
|
||||
// CHECK: return %[[RESULT]] : tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
|
||||
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
|
||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
@ -31,14 +35,18 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
|
||||
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
|
||||
func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
|
||||
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||
// CHECK: return %[[RESULT]] : tensor<?x?xcomplex<f32>>
|
||||
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
|
||||
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||
return %0 : tensor<?x?xcomplex<f32>>
|
||||
}
|
||||
@ -50,12 +58,16 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
||||
func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
|
||||
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: return %[[RESULT]] : tensor<?x?xi1>
|
||||
// CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: shape.assuming_yield %[[RESULT]]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
|
||||
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
return %0 : tensor<?x?xi1>
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Note that binary elementwise tests are run with chlo legalization enabled
|
||||
// (unlike the rest), since this is the primary use case for such ops and
|
||||
// verification of shapes and broadcasts is desired.
|
||||
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" %s | FileCheck %s
|
||||
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -canonicalize %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary op legalizations.
|
||||
@ -24,13 +24,8 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// patterns unambiguous and more interesting (once broadcastable trait is
|
||||
// fixed upstream).
|
||||
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1]
|
||||
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2]
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
|
||||
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0: tensor<1x2xi32>
|
||||
}
|
||||
@ -39,26 +34,26 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
|
||||
// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream
|
||||
// broadcastable bug is fixed (helps make the CHECK matching unambiguous)
|
||||
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
|
||||
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1]
|
||||
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}
|
||||
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
|
||||
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
|
||||
// CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
|
||||
return %0: tensor<4x4x4x4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @add_dynamic
|
||||
func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
|
||||
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32>
|
||||
// CHECK-DAG: %[[CSTR_LHS_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[CSTR_RHS_SHAPE:.+]] = shape.shape_of %arg1
|
||||
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[CSTR_LHS_SHAPE]], %[[CSTR_RHS_SHAPE]]
|
||||
// CHECK-NEXT: shape.assuming %[[WITNESS:.+]]
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
|
||||
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-NEXT: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32>
|
||||
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %0: tensor<?x?xi32>
|
||||
}
|
||||
@ -80,21 +75,21 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
|
||||
// CHECK-LABEL: func @div_unranked
|
||||
func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// CHECK: tf.Div
|
||||
// CHECK-NEXT: tf.Div
|
||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %0: tensor<?x?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maximum
|
||||
func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
|
||||
// CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
|
||||
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @minimum
|
||||
func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
|
||||
// CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
|
||||
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
@ -200,26 +195,25 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
|
||||
// CHECK-LABEL: func @equal_dynamic
|
||||
func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1]
|
||||
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1]
|
||||
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[LHS_SHAPE]], %[[RHS_SHAPE]]
|
||||
// CHECK-NEXT: shape.assuming %[[WITNESS]] -> (tensor<?xi1>) {
|
||||
// CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE]])
|
||||
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
|
||||
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
return %0: tensor<?xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @equal_broadcast
|
||||
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
|
||||
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "EQ"}
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0: tensor<1x2xi1>
|
||||
}
|
||||
@ -281,26 +275,25 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
|
||||
// CHECK-LABEL: func @broadcast_greater
|
||||
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
|
||||
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
|
||||
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "GT"}
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0: tensor<1x2xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @greater_dynamic
|
||||
func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1> {
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
|
||||
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
|
||||
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[LHS_SHAPE]], %[[RHS_SHAPE]]
|
||||
// CHECK-NEXT: shape.assuming %[[WITNESS]]
|
||||
// CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[RHS_SHAPE1:.+]] = shape.shape_of %arg1
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE1]])
|
||||
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
|
||||
return %0: tensor<?xi1>
|
||||
}
|
||||
|
@ -112,6 +112,19 @@ struct ConvertRankedDynamicBroadcastBinaryOp
|
||||
|
||||
// Compute result shape.
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Insert a constraint on the shapes being broadcastable and insert all
|
||||
// future code into an assuming block reliant on the constraint.
|
||||
Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
|
||||
auto broadcastable_cstr =
|
||||
rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
|
||||
auto assuming_op = rewriter.create<shape::AssumingOp>(
|
||||
loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.createBlock(&assuming_op.doRegion());
|
||||
|
||||
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
|
||||
Value result_extents =
|
||||
xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
|
||||
@ -140,8 +153,10 @@ struct ConvertRankedDynamicBroadcastBinaryOp
|
||||
rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
|
||||
|
||||
// And generate the final non-broadcasted binary op.
|
||||
rewriter.replaceOp(op, {Adaptor::CreateOp(op, result_type, broadcasted_lhs,
|
||||
broadcasted_rhs, rewriter)});
|
||||
Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
|
||||
broadcasted_rhs, rewriter);
|
||||
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
|
||||
rewriter.replaceOp(op, {assuming_op.getResult(0)});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user