Remove implicit broadcasting from xla_hlo binary elementwise ops.

* Migrates the legalize-tf conversions to either:
  * convert through the chlo.broadcast_* ops (the majority; see the before/after sketch below), or
  * special-case broadcasting for unsupported or non-optimal broadcast modes.
* This was done one op at a time, QC'ing each; many bugs, inefficiencies, and ambiguous broadcasting modes were corrected along the way.
* It looks like a rule for legalizing complex types may still be missing (will check on Monday).
* I considered splitting this up, but making the ops stricter was important for flushing out all cases, and that is best done as an atomic change.
* The stricter conversions fixed a number of cases where shapes were dropping to unranked or unknown when they needn't be.
* With this change, most of the binary ops and many of the resulting tf2xla expansions correctly support dynamic shapes via the shape dialect.
  * I verified this with the small set of IREE test cases that support dynamic shapes and will expand coverage once this lands.
* There is some test fallout outside of the xla directory that I will fix up on Monday.
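
For illustration, the typical shape of the migration, taken from the legalize-tf tests below. A binary op that previously carried an implicit broadcast_dimensions attribute on xla_hlo:

  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>

now converts through the client dialect instead, leaving the xla_hlo op strictly same-shape:

  %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>

When chlo is further legalized (legalize-chlo=true), the broadcast becomes explicit via the shape dialect; roughly, abbreviated from the CHECK patterns in the tests (value names illustrative):

  %extents = "shape.to_extent_tensor"(%result_shape) ...
  %lhs = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %extents) {broadcast_dimensions = dense<1> : tensor<1xi64>} ...
  %rhs = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %extents) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} ...
  %sum = xla_hlo.add %lhs, %rhs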

PiperOrigin-RevId: 312316083
Change-Id: I6d246d80cddb84f2dfd62817c7166f53c1f6cdec
Stella Laurenzo 2020-05-19 11:17:02 -07:00 committed by TensorFlower Gardener
parent 273617ad91
commit b7735095de
23 changed files with 1056 additions and 1884 deletions


@ -2,17 +2,17 @@
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
return %0 : tensor<?x?x?x?xi32>
}
@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
return %0 : tensor<4x4x4x4xi32>
}
@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> {
%0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
%0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
}
func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
%0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
}
func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
return %0 : tensor<1x4xi8>
}
func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
return %0 : tensor<1x4xi8>
}
func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
%0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
%0 = xla_hlo.constant dense<0> : tensor<2x3xi32>
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%2 = xla_hlo.constant dense<0> : tensor<3xi32>
%3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%8 = xla_hlo.constant dense<1> : tensor<3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32>
}
@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
%0 = xla_hlo.constant dense<0> : tensor<3xi32>
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%2 = xla_hlo.constant dense<0> : tensor<2x3xi32>
%3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
}
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
return %2 : tensor<2x3xf16>
}
@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> {
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
return %1 : tensor<1xi32>
}
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
return %1 : tensor<?xi32>
}
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32>
%2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
return %3 : tensor<1xi32>
}
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32>
%2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
}
func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
%3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return %3 : tensor<4x8xf32>


@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {


@ -18,12 +18,16 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
//===----------------------------------------------------------------------===//
// Binary op patterns.
// Note that these are legalized from chlo.broadcast_* ops, since those are
// semantically compatible with the corresponding TF ops. Depending on
// context, getting to these ops may require some raising.
//===----------------------------------------------------------------------===//
// Check that two values can be broadcasted together
@ -31,36 +35,45 @@ def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
"types must be broadcastable">;
foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
[HLO_DivOp, TF_DivOp],
[HLO_ShiftLeftOp, TF_LeftShiftOp],
[HLO_MaxOp, TF_MaximumOp],
[HLO_MinOp, TF_MinimumOp],
[HLO_MulOp, TF_MulOp],
[HLO_PowOp, TF_PowOp],
[HLO_SubOp, TF_SubOp],
[HLO_Atan2Op, TF_Atan2Op],
[HLO_RemOp, TF_ModOp]] in
def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r),
foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op],
[HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp],
[HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp],
[HLO_MaxOp, HLOClient_BroadcastMaxOp, TF_MaximumOp],
[HLO_MinOp, HLOClient_BroadcastMinOp, TF_MinimumOp],
[HLO_MulOp, HLOClient_BroadcastMulOp, TF_MulOp],
[HLO_PowOp, HLOClient_BroadcastPowOp, TF_PowOp],
[HLO_SubOp, HLOClient_BroadcastSubOp, TF_SubOp],
[HLO_Atan2Op, HLOClient_BroadcastAtan2Op, TF_Atan2Op],
[HLO_RemOp, HLOClient_BroadcastRemOp, TF_ModOp]] in {
def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>;
def : Pat<(fromToBinPair[1] $l, $r, $_), (fromToBinPair[2] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
}
foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_BitwiseAndOp],
[HLO_OrOp, HLOClient_BroadcastOrOp, TF_BitwiseOrOp],
[HLO_XorOp, HLOClient_BroadcastXorOp, TF_BitwiseXorOp]] in {
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>;
def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[2] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
}
foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_LogicalAndOp],
[HLO_OrOp, HLOClient_BroadcastOrOp, TF_LogicalOrOp]] in {
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>;
def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $_), (pair[2] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
}
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>;
def : Pat<(HLOClient_BroadcastShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>;
def : Pat<(HLOClient_BroadcastShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
foreach pair = [[HLO_AndOp, TF_BitwiseAndOp],
[HLO_OrOp, TF_BitwiseOrOp],
[HLO_XorOp, TF_BitwiseXorOp]] in
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
foreach pair = [[HLO_AndOp, TF_LogicalAndOp],
[HLO_OrOp, TF_LogicalOrOp]] in
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>;
def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>;
@ -117,16 +130,23 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim),
//===----------------------------------------------------------------------===//
// Compare op patterns.
// Note that these are legalized from chlo.broadcast_* ops, since those are
// semantically compatible with the corresponding TF ops. Depending on
// context, getting to these ops may require some raising.
//===----------------------------------------------------------------------===//
foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in
def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in {
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>;
}
foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
[TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
[TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in
def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in {
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>;
}


@ -185,6 +185,16 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
// BroadcastCompareOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dimensions,
StringAttr comparison_direction) {
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
builder.getI1Type(), broadcast_dimensions);
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
comparison_direction);
}
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,


@ -360,6 +360,11 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
HLO_ComparisonDirectionAttr:$comparison_direction
);
let results = (outs HLO_PredTensor);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
>];
}
#endif // CHLO_OPS


@ -1401,89 +1401,25 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
namespace {
// Gets the resulting type from a broadcast between two types.
static Type GetBroadcastType(Builder* builder, Type x, Type y,
Type element_type,
DenseIntElementsAttr broadcast_dimensions) {
// Updates the element type of a (presumed) tensor type 'x', returning either
// a permuted UnrankedTensorType or RankedTensorType.
static Type UpdateResultElementType(Builder* builder, Type x,
Type element_type) {
auto x_ranked = x.dyn_cast<RankedTensorType>();
auto y_ranked = y.dyn_cast<RankedTensorType>();
if (!x_ranked || !y_ranked) {
if (!x_ranked) {
return UnrankedTensorType::get(element_type);
}
auto shape_x = x_ranked.getShape();
auto shape_y = y_ranked.getShape();
if (shape_x.size() == shape_y.size()) {
llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
for (int i = 0; i < shape_x.size(); i++) {
auto x_val = shape_x[i];
auto y_val = shape_y[i];
if (x_val == -1 || y_val == -1) {
out_shape[i] = -1;
} else {
out_shape[i] = std::max(x_val, y_val);
}
}
return RankedTensorType::get(out_shape, element_type);
}
// Return unranked tensor for invalid broadcast dimensions.
if (!broadcast_dimensions) return UnrankedTensorType::get(element_type);
auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
shape_large.end());
// Update according to the broadcast dimensions.
for (auto index_pair : llvm::enumerate(broadcast_dimensions.getIntValues())) {
auto old_value = out_shape[index_pair.value().getSExtValue()];
auto new_value = shape_small[index_pair.index()];
if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
out_shape[index_pair.value().getSExtValue()] = new_value;
}
}
return RankedTensorType::get(out_shape, element_type);
return RankedTensorType::get(shape_x, element_type);
}
} // namespace
#define BINARY_BUILDER(Op) \
void Op::build(OpBuilder& builder, OperationState& result, Value left, \
Value right, DenseIntElementsAttr broadcast_dimensions) { \
auto type = GetBroadcastType(&builder, left.getType().cast<ShapedType>(), \
right.getType().cast<ShapedType>(), \
getElementTypeOrSelf(right.getType()), \
broadcast_dimensions); \
return Op::build(builder, result, type, left, right, \
broadcast_dimensions); \
}
BINARY_BUILDER(AddOp);
BINARY_BUILDER(AndOp);
BINARY_BUILDER(Atan2Op);
BINARY_BUILDER(DivOp);
BINARY_BUILDER(MaxOp);
BINARY_BUILDER(MinOp);
BINARY_BUILDER(MulOp);
BINARY_BUILDER(OrOp);
BINARY_BUILDER(PowOp);
BINARY_BUILDER(RemOp);
BINARY_BUILDER(ShiftLeftOp);
BINARY_BUILDER(ShiftRightArithmeticOp);
BINARY_BUILDER(ShiftRightLogicalOp);
BINARY_BUILDER(SubOp);
BINARY_BUILDER(XorOp);
#undef BINARY_BUILDER
template <typename Op, typename ElementType = Type, typename ValType,
typename Convert>
static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
if (!attrs[0] || !attrs[1]) return {};
if (op->broadcast_dimensions().hasValue()) return {};
DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
@ -1893,12 +1829,10 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
Value rhs, DenseIntElementsAttr broadcast_dimensions,
StringAttr comparison_direction) {
auto new_type = GetBroadcastType(&builder, lhs.getType(), rhs.getType(),
builder.getI1Type(), broadcast_dimensions);
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
comparison_direction);
Value rhs, StringAttr comparison_direction) {
auto new_type =
UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
build(builder, result, new_type, lhs, rhs, comparison_direction);
}
#define GET_OP_CLASSES


@ -241,15 +241,9 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface])> {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
HLO_Tensor:$rhs
);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value left, Value right, "
"DenseIntElementsAttr broadcast_dimensions"
>];
let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
@ -270,15 +264,15 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
}
def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp {
[Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AddOp {
let hasFolder = 1;
}
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op;
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
def HLO_ComplexOp: HLO_Op<"complex",
[NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>,
[NoSideEffect, SameOperandsAndResultShape]>,
BASE_HLO_ComplexOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];
@ -289,39 +283,39 @@ def HLO_ComplexOp: HLO_Op<"complex",
}
def HLO_DivOp : HLO_BinaryElementwiseOp<"divide",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp {
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp {
}
def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp {
[Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp {
}
def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp {
[Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp {
}
def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp {
[Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MulOp {
let hasFolder = 1;
}
def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp;
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp;
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp;
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp;
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightArithmeticOp;
def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp;
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightLogicalOp;
def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp {
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_SubOp {
let hasFolder = 1;
}
@ -331,11 +325,11 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class HLO_BinaryLogicalElementwiseOp<string mnemonic> :
HLO_BinaryElementwiseOp<mnemonic, [Commutative, NoSideEffect]> {
HLO_BinaryElementwiseOp<
mnemonic, [Commutative, NoSideEffect, SameOperandsAndResultType]> {
let arguments = (ins
HLO_PredOrIntTensor:$lhs,
HLO_PredOrIntTensor:$rhs,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
HLO_PredOrIntTensor:$rhs
);
}
@ -617,23 +611,18 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
}
def HLO_CompareOp: HLO_Op<"compare",
[NoSideEffect, SameOperandsElementType]>, BASE_HLO_CompareOp {
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>,
BASE_HLO_CompareOp {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction
);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value left, Value right, "
"DenseIntElementsAttr broadcast_dimensions, "
"StringAttr comparison_direction"
>];
let results = (outs HLO_PredTensor);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
"StringAttr comparison_direction"
>];
}


@ -209,7 +209,6 @@ StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::CompareOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
/*broadcast_dimensions=*/mlir::DenseIntElementsAttr(),
builder_.getStringAttr(ComparisonDirectionToString(direction)));
return MakeXlaOp(op.getResult());
}


@ -0,0 +1,334 @@
// Note that binary elementwise tests are run with chlo legalization enabled
// (unlike the rest), since this is the primary use case for such ops and
// verification of shapes and broadcasts is desired.
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" %s | FileCheck %s --dump-input-on-failure
//===----------------------------------------------------------------------===//
// Binary op legalizations.
// Most of these expand from the same pattern. Full semantics are
// verified for tf.Add and pattern application only for the rest.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @add
func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32>
// CHECK-NEXT: return %[[SUM1]] : tensor<2xi32>
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %1: tensor<2xi32>
}
// CHECK-LABEL: func @broadcast_add
// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check
// patterns unambiguous and more interesting (once broadcastable trait is
// fixed upstream).
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0: tensor<1x2xi32>
}
// CHECK-LABEL: func @broadcast_multi_dim_add
// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream
// broadcastable bug is fixed (helps make the CHECK matching unambiguous)
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1]
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
return %0: tensor<4x4x4x4xi32>
}
// CHECK-LABEL: func @add_dynamic
func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32>
%0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>
}
// CHECK-LABEL: func @div
func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// CHECK-LABEL: func @shift_left
func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
%0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: func @div_unranked
func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: tf.Div
%0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>
}
// CHECK-LABEL: func @maximum
func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @minimum
func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// CHECK-LABEL: func @real_div
func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
%0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// CHECK-LABEL: func @sub
func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// CHECK-LABEL: func @shift_right
func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: func @shift_right_unsigned
func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> {
// CHECK: tf.RightShift
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8>
return %0 : tensor<4xui8>
}
// CHECK-LABEL: func @broadcast_shift_right_unsigned
func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> {
// CHECK: tf.RightShift
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8>
return %0 : tensor<2x4xui8>
}
// CHECK-LABEL: func @and
func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
// CHECK-NEXT: xla_hlo.and
%0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @and_unranked
func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> {
// CHECK: tf.LogicalAnd
%0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @or
func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
// CHECK-NEXT: xla_hlo.or
%0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @bitwise_or
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: xla_hlo.or
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @bitwise_and
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: xla_hlo.and
%0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @pow
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-NEXT: xla_hlo.power
%0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0: tensor<2xf32>
}
//===----------------------------------------------------------------------===//
// Equality op legalizations.
// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are
// verified for tf.Equal and pattern application only for tf.NotEqual
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @equal
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @equal_dynamic
func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}
// CHECK-LABEL: func @equal_broadcast
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}
// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error
func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}
// CHECK-LABEL: func @equal_incompatible_shape_broadcastable
func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
// CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}
// CHECK-LABEL: func @equal_incompatible_shape_dynamic
func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> {
// CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<?xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic
func @equal_incompatible_shape_both_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> {
// CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<?xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @equal_unranked
func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> {
// CHECK: "tf.Equal"
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @notequal
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"}
%0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
//===----------------------------------------------------------------------===//
// Compare op legalizations.
// These expand from the same pattern. Full semantics are checked for
// tf.Greater. Others just check that the pattern applied.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @greater
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @broadcast_greater
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}
// CHECK-LABEL: func @greater_dynamic
func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}
// CHECK-LABEL: func @greater_unranked
func @greater_unranked(%arg0: tensor<*xi32>) -> tensor<*xi1> {
// CHECK: "tf.Greater"
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @greater_equal
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
%0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @less
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
%0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @less_equal
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
%0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

File diff suppressed because it is too large.


@ -1,4 +1,4 @@
// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s
// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
@ -42,40 +42,6 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
return %4 : tensor<4xi32>
}
// Broadcasting is not currently supported.
// TODO(suderman):Future pass should take all broadcasted binary ops and convert
// them to separate broadcast and binary op.
// CHECK-LABEL: func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
// CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {
name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: %1 = "xla_hlo.multiply"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%1 = "xla_hlo.multiply"(%0, %arg1) {
name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: %2 = "xla_hlo.subtract"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%2 = "xla_hlo.subtract"(%1, %arg1) {
name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: %3 = "xla_hlo.divide"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%3 = "xla_hlo.divide"(%2, %arg1) {
name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: %4 = "xla_hlo.remainder"(%3, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%4 = "xla_hlo.remainder"(%3, %arg1) {
broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: return %4 : tensor<4x4xf32>
return %4 : tensor<4x4xf32>
}
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>


@ -1,4 +1,4 @@
// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s
// RUN: xla-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
@ -15,21 +15,6 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @add_broadcast
func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}
// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@ -60,21 +45,6 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @sub_broadcast
func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.subtract"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.subtract"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.subtract"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}
// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@ -109,25 +79,6 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @mul_broadcast
func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.multiply"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
// CHECK: return [[VAL2]], [[VAL5]] : tensor<1x2xf32>, tensor<1x2xf32>
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}
// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@ -186,45 +137,6 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// -----
// CHECK-LABEL: @div_broadcast
func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real - lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.multiply"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.multiply"(%arg1, %arg2)
// CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.multiply"(%arg0, [[VAL0]])
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.divide"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.divide"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.divide"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}
// -----
// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)


@ -1,225 +1,5 @@
// RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s
// CHECK-LABEL: @addBroadcastRhs
func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @addBroadcastLhs
func @addBroadcastLhs(%arg0: tensor<4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @addBroadcastEqual
func @addBroadcastEqual(%arg0: tensor<4x1xf32>, %arg1: tensor<1x4xf32>) -> tensor<4x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<4x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
// -----
// CHECK-LABEL: @addBroadcastMultidimension
func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x1x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>, tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
return %0 : tensor<1x1x4xf32>
}
// -----
// CHECK-LABEL: @addBroadcastBothArgs
func @addBroadcastBothArgs(%arg0: tensor<1x2xf32>, %arg1: tensor<3x2x1xf32>) -> tensor<3x2x2xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x2x2xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x1xf32>) -> tensor<3x2x2xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<3x2x2xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>, tensor<3x2x1xf32>) -> tensor<3x2x2xf32>
return %0 : tensor<3x2x2xf32>
}
// -----
// CHECK-LABEL: @addBroadcastScalar
func @addBroadcastScalar(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @addUnranked
func @addUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<*xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// CHECK-LABEL: @atan2BroadcastRhs
func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.atan2 %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.atan2"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @divBroadcastRhs
func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.divide %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @maxBroadcastRhs
func @maxBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.maximum %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.maximum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @minBroadcastRhs
func @minBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.minimum %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.minimum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @mulBroadcastRhs
func @mulBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.multiply %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @powBroadcastRhs
func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.power %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.power"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @remainderBroadcastRhs
func @remainderBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.remainder %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @shiftLeftBroadcastRhs
func @shiftLeftBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_left %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_left"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @shiftRightArithmeticBroadcastRhs
func @shiftRightArithmeticBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_arithmetic %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @shiftRightLogicalBroadcastRhs
func @shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_logical %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_right_logical"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @subBroadcastRhs
func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.subtract %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @andBroadcastRhs
func @andBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.and %arg0, %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
// CHECK-LABEL: @orBroadcastRhs
func @orBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.or %arg0, %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
// CHECK-LABEL: @xorBroadcastRhs
func @xorBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.xor %arg0, %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.xor"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
@ -229,63 +9,3 @@ func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>
%0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @compareBroadcastRhs
func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xi1> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = "xla_hlo.compare"(%arg0, %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>
return %0 : tensor<1x4xi1>
}
// -----
// CHECK-LABEL: @dynamicCompareBroadcastRhs
func @dynamicCompareBroadcastRhs(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xi1> {
// CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32>
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor<?x?xf32>
// CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor<?xf32>
// CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index
// CHECK-NEXT: %[[DIM1:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index
// CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex>
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: "xla_hlo.compare"(%[[BROADCAST0]], %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
}
// -----
// CHECK-LABEL: @dynamicBroadcastAdd
func @dynamicBroadcastAdd(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32>
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor<?x?xf32>
// CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor<?xf32>
// CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index
// CHECK-NEXT: %[[DIM1:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index
// CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex>
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<?x?xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @dynamicBroadcastAddScalar
func @dynamicBroadcastAddScalar(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?x?xf32> {
// CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32>
// CHECK-NEXT: %[[DIM1:.*]] = dim %arg0, 1 : tensor<?x?xf32>
// CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex>
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<?x?xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}


@ -1,4 +1,4 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure
// CHECK: HloModule
func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token {
@ -96,34 +96,6 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor
// -----
// CHECK: HloModule
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
// Same rank degenerate broadcast
// CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0)
// CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]])
// CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]])
// CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1)
// CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]])
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
// Broadcast up rank
// CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2}
// CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2)
// CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]])
%1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
// Broadcast up rank + degenerate broadcast
// CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2}
// CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]])
// CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2}
// CHECK: ROOT
// CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]])
%2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
return %2 : tensor<2x3x4xi32>
}
// -----
// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>


@ -1,4 +1,4 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s --dump-input-on-failure
HloModule main
@ -20,29 +20,6 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
// This test is more thorough than those of the other binary ops to test
// their shared functionality.
// CHECK-LABEL: func @test_add
%test_add (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
%Arg_3.4 = f32[] parameter(3)
// Add two tensors
// CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"}
%add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
// Add two scalars
// CHECK-NEXT: xla_hlo.add %arg2, %arg3
%add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4)
// Add a tensor and scalar
// CHECK-NEXT: "xla_hlo.add"(%0, %1)
ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4)
}
// CHECK-LABEL: func @test_after_all
// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token
%test_after_all (token0: token[], token1: token[] ) -> token[] {
@ -159,11 +136,11 @@ add {
}
// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> {
%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] {
// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1> {
%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] {
%Arg_0.1 = f32[3] parameter(0)
%Arg_1.2 = f32[3] parameter(1)
%Arg_2.3 = f32[1] parameter(2)
%Arg_2.3 = f32[3] parameter(2)
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ
@ -172,7 +149,7 @@ add {
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE
// Requires broadcast of compatible tensors.
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1>
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
}
@ -280,19 +257,19 @@ add {
ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1}
}
// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf64> {
%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] {
// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64> {
%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.3 = f64[4] convert(f32[4] %Arg_0.1)
// CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f64>
%convert.4 = f64[] convert(f32[] %Arg_1.2)
// CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.4 = f64[4] convert(f32[4] %Arg_1.2)
// CHECK-NEXT: "xla_hlo.add"(%0, %1)
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4)
// CHECK-NEXT: xla_hlo.add %0, %1
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4)
}
// CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {


@ -163,8 +163,7 @@ struct HloBinaryElementwiseAdaptor {
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
/*broadcast_dimensions=*/nullptr);
broadcasted_lhs, broadcasted_rhs);
}
};
@ -183,9 +182,9 @@ struct HloCompareAdaptor {
Type result_type, Value broadcasted_lhs,
Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<xla_hlo::CompareOp>(
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
/*broadcast_dimensions=*/nullptr, from_op.comparison_direction());
return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction());
}
};


@ -67,8 +67,9 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF &) {}
explicit LegalizeTF(bool allow_partial_conversion) {
explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) {
allow_partial_conversion_ = allow_partial_conversion;
legalize_chlo_ = legalize_chlo;
}
/// Performs the lowering to XLA dialect.
@ -79,6 +80,11 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
*this, "allow-partial-conversion",
llvm::cl::desc("Allow operations that can't be legalized."),
llvm::cl::init(false)};
Option<bool> legalize_chlo_{
*this, "legalize-chlo",
llvm::cl::desc(
"Also legalizes intermediate chlo ops to hlo (default true)"),
llvm::cl::init(true)};
};
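// Invocation sketch (the pass's registered command-line name is assumed
// here; only the "legalize-chlo" option string is taken from the code above):
//   tf-opt -xla-legalize-tf=legalize-chlo=false input.mlir
// keeps the intermediate xla_chlo.broadcast_* ops in the output, while the
// default (legalize-chlo=true) also expands them to xla_hlo ops.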
/// Returns if the given TF data format string is the default format.
@ -362,6 +368,154 @@ static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder);
}
// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Gets the result type of a broadcast between two statically shaped tensor
// types. This is to be used for legacy lowerings that rely on both
// non-left-padded broadcasting and static shapes. Its use should not be
// permitted in new code.
// May return nullptr on invalid static broadcast dimensions.
// ABSL_DEPRECATED()
static RankedTensorType GetStaticBroadcastType(
RankedTensorType x, RankedTensorType y,
DenseIntElementsAttr broadcast_dimensions_attr) {
auto element_type = x.getElementType();
auto shape_x = x.getShape();
auto shape_y = y.getShape();
if (shape_x.size() == shape_y.size()) {
llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
for (int i = 0; i < shape_x.size(); i++) {
auto x_val = shape_x[i];
auto y_val = shape_y[i];
out_shape[i] = std::max(x_val, y_val);
}
return RankedTensorType::get(out_shape, element_type);
}
auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
llvm::SmallVector<int64_t, 4> broadcast_dimensions;
// Explicit broadcast dimensions.
for (const APInt &int_value : broadcast_dimensions_attr) {
broadcast_dimensions.push_back(int_value.getSExtValue());
}
if (broadcast_dimensions.size() != shape_small.size()) {
return nullptr;
}
llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
shape_large.end());
// Update according to the broadcast dimensions.
for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
auto old_value = out_shape[index_pair.value()];
auto new_value = shape_small[index_pair.index()];
out_shape[index_pair.value()] = std::max(old_value, new_value);
}
return RankedTensorType::get(out_shape, element_type);
}
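// Worked example (hypothetical shapes): x = tensor<2x4xf32>, y = tensor<4xf32>,
// broadcast_dimensions = [1] maps y's single dimension onto dimension 1 of x,
// giving tensor<2x4xf32>. If broadcast_dimensions does not have exactly one
// entry per dimension of the smaller operand, nullptr is returned.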
// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Applies static binary broadcasting to a binary elementwise op.
// This is a legacy helper to provide general broadcasting support in legacy,
// static shaped code that relies on non-left-padded broadcasting semantics.
template <typename BinaryOp>
static Value StaticBinaryBroadcast(Location loc, Value x, Value y,
DenseIntElementsAttr broadcast_dims,
OpBuilder &builder) {
auto x_type = x.getType().cast<RankedTensorType>();
auto y_type = y.getType().cast<RankedTensorType>();
auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims);
if (!result_type) {
emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type
<< " with broadcast_dims = " << broadcast_dims;
return nullptr;
}
auto larger_broadcast_dims =
GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder);
if (x_type.getRank() < y_type.getRank()) {
if (x_type != result_type) {
x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims);
}
if (y_type != result_type) {
y = builder.create<BroadcastInDimOp>(loc, result_type, y,
larger_broadcast_dims);
}
} else {
if (x_type != result_type) {
x = builder.create<BroadcastInDimOp>(loc, result_type, x,
larger_broadcast_dims);
}
if (y_type != result_type) {
y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims);
}
}
return builder.create<BinaryOp>(loc, x, y);
}
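// Usage sketch (types assumed, continuing the example above): broadcasting
// y = tensor<4xf32> against x = tensor<2x4xf32> with broadcast_dims = [1]
// emits
//   %y_b = "xla_hlo.broadcast_in_dim"(%y)
//            {broadcast_dimensions = dense<1> : tensor<1xi64>}
//            : (tensor<4xf32>) -> tensor<2x4xf32>
// and then the plain binary op on two tensor<2x4xf32> values; x is left
// untouched because it already has the result type.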
// Gets a 1D tensor type suitable for expressing extents of the given tensor
// value type. If the value type is ranked, the result will be statically
// shaped. Otherwise, it will have a dynamic dimension.
static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) {
Builder b(value_type.getContext());
int64_t dim = value_type.hasRank() ? value_type.getRank() : -1;
return RankedTensorType::get({dim}, b.getIndexType());
}
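// For instance, a ranked tensor<4x?x3xf32> yields tensor<3xindex>, while an
// unranked tensor<*xf32> yields tensor<?xindex>.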
// Broadcasts a 'lower_rank_value' to the shape of a 'higher_rank_value'
// by assuming that the lower-ranked value's shape is a broadcast-compatible
// prefix of the higher-ranked value's shape.
// Values must be RankedTensorType (this restriction derives from the
// broadcast_dimensions attribute on DynamicBroadcastInDim).
//
// Example:
// CommonPrefixBroadcast(tensor<4x3x256>, tensor<4x3>) will broadcast the
// lower rank value to [4, 3, 256] (i.e. the opposite of numpy-style
// implicit broadcasting).
static Value CommonPrefixBroadcast(Location loc, Value higher_rank_value,
Value lower_rank_value, OpBuilder &builder) {
Value higher_rank_shape =
builder.create<shape::ShapeOfOp>(loc, higher_rank_value);
auto result_extents_type =
GetExtentsTensorTypeFor(higher_rank_value.getType().cast<TensorType>());
Value result_extents = builder.create<shape::ToExtentTensorOp>(
loc, result_extents_type, higher_rank_shape);
auto lower_rank_type = lower_rank_value.getType().cast<RankedTensorType>();
auto lower_rank = lower_rank_type.getRank();
auto prefix_dims = GetI64ElementsAttrForSeq(0, lower_rank, &builder);
return builder.create<DynamicBroadcastInDimOp>(
loc, higher_rank_value.getType(), lower_rank_value, result_extents,
prefix_dims);
}
// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D
// value (broadcast_from) along that feature dimension. This is a shortcut
// for the cases where a 1D tensor must be broadcast along a specific feature
// dimension, which can vary based on data layout, etc.
//
// The extent of `broadcast_from` dim0 must be equal to the extent of the
// feature_dim of `broadcast_to`.
//
// Example:
// [1x2x3x4], [2], 1 -> [1x2x3x4]
// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for
// consistency. Possibly also rename for clarity.
static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to,
Value broadcast_from, int64_t feature_dim,
OpBuilder &builder) {
auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder);
auto to_type = broadcast_to.getType().cast<RankedTensorType>();
auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
auto result_extents_type = GetExtentsTensorTypeFor(to_type);
auto result_extents = builder.create<shape::ToExtentTensorOp>(
loc, result_extents_type, result_shape);
return builder.create<DynamicBroadcastInDimOp>(
loc, to_type, broadcast_from, result_extents, broadcast_dims);
}
// Creates a batch dot using xla_hlo::DotGeneralOp.
Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
bool transpose_rhs, int64_t num_batch_dims,
@ -407,8 +561,7 @@ static void BuildReduceBody(Type element_type, Region *body,
Location loc = body->getLoc();
auto reducer =
builder->create<Op>(loc, block->getArgument(0), block->getArgument(1),
/*broadcast_dimensions=*/nullptr);
builder->create<Op>(loc, block->getArgument(0), block->getArgument(1));
builder->create<ReturnOp>(loc, reducer.getResult());
}
@ -508,8 +661,7 @@ static void CreateWhile32(Location loc, int num_iterations,
loc, builder->getI32IntegerAttr(num_iterations));
StringAttr compare_direction = StringAttr::get("LT", builder->getContext());
Value compare = builder->create<xla_hlo::CompareOp>(
loc, loop_iv, upper_limit,
/*broadcast_dimensions=*/nullptr, compare_direction);
loc, loop_iv, upper_limit, compare_direction);
builder->create<xla_hlo::ReturnOp>(loc, compare);
}
@ -539,9 +691,9 @@ static void CreateWhile32(Location loc, int num_iterations,
// Increment the loop induction variable by one.
auto one =
builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
auto no_broadcast_dims = GetI64ElementsAttr({}, builder);
auto plus_one = builder->create<xla_hlo::AddOp>(loc, old_values[0], one,
no_broadcast_dims);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
auto plus_one = builder->create<xla_chlo::BroadcastAddOp>(
loc, old_values[0], one, scalar_broadcast_dims);
// Prepend with the updated loop induction variable.
new_values.insert(new_values.begin(), plus_one);
@ -566,21 +718,6 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
}
//===----------------------------------------------------------------------===//
// Bias op utilities.
//===----------------------------------------------------------------------===//
// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd.
// Requires input to have ranked tensor.
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
StringAttr format,
Value input) {
auto inputType = input.getType().cast<RankedTensorType>();
size_t featureDim = GetFeatureDimension(format, inputType);
RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64));
return DenseIntElementsAttr::get(type, featureDim);
}
//===----------------------------------------------------------------------===//
// MatMul op utilities.
//===----------------------------------------------------------------------===//
@ -743,8 +880,7 @@ static void BuildArgMinMaxReductionBody(Type input_element_type,
StringAttr compare_direction =
StringAttr::get(direction, builder->getContext());
Value compare = builder->create<CompareOp>(
loc, block->getArgument(0), block->getArgument(2),
/*broadcast_dimensions=*/nullptr, compare_direction);
loc, block->getArgument(0), block->getArgument(2), compare_direction);
Value selected_input = builder->create<SelectOp>(
loc, input_type, compare, block->getArgument(0), block->getArgument(2));
@ -860,8 +996,7 @@ static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
StringAttr compare_direction =
StringAttr::get(direction, builder->getContext());
Value compare = builder->create<xla_hlo::CompareOp>(
loc, block->getArgument(0), block->getArgument(1),
/*broadcast_dimensions=*/nullptr, compare_direction);
loc, block->getArgument(0), block->getArgument(1), compare_direction);
builder->create<xla_hlo::ReturnOp>(loc, compare);
}
@ -900,6 +1035,27 @@ NamedAttribute GetConvDimensionNumbersAttr(
feature_dim, spatial_dims, builder->getContext()));
}
// Converts a TF::BiasAddOp to HLO.
// This differs from a normal TF::AddOp with respect to how the data_format
// is handled, which can optionally require a general broadcast of the
// 'bias' term in a way that is not compatible with the standard left-padded
// broadcast semantics (i.e. NCHW will broadcast into dimension 1).
// The correct 'bias' broadcast will be synthesized manually.
class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TF::BiasAddOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto feature_dim = GetFeatureDimension(
op.data_formatAttr(), op.value().getType().cast<RankedTensorType>());
auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
feature_dim, rewriter);
rewriter.replaceOpWithNewOp<AddOp>(op, op.value(), bias_broadcast);
return success();
}
};
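// For example (illustrative shapes): with data_format = "NCHW",
//   %0 = "tf.BiasAdd"(%value, %bias) {data_format = "NCHW"}
//          : (tensor<1x32x10x10xf32>, tensor<32xf32>) -> tensor<1x32x10x10xf32>
// lowers to a dynamic_broadcast_in_dim of %bias along feature dimension 1
// followed by a plain xla_hlo.add, instead of relying on implicit left-padded
// broadcasting (which would target the trailing dimension).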
// Converts the TensorFlow conv op in template to the generic HLO conv op by
// converting TensorFlow op attributes to HLO op attributes.
//
@ -1161,7 +1317,6 @@ class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
rewriter.getI64IntegerAttr(1));
Value compare = rewriter.create<CompareOp>(
op.getLoc(), iota0, iota1,
/*broadcast_dimensions=*/nullptr,
StringAttr::get("EQ", rewriter.getContext()));
Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(),
0, &rewriter);
@ -1274,33 +1429,35 @@ class ConvertFusedBatchNormGradBase
non_feature_dims.push_back(i);
}
auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter);
auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
// scratch1 = rsqrt(var + epsilon)
RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
auto epsilon = rewriter.create<ConstOp>(
loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
auto add_op = rewriter.create<AddOp>(loc, var, epsilon.getResult(),
no_broadcast_dims);
auto add_op = rewriter.create<xla_chlo::BroadcastAddOp>(
loc, var, epsilon.getResult(), scalar_broadcast_dims);
Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
// scratch2 = sum(y_backprop * (x - mean))
auto sub_op = rewriter.create<SubOp>(loc, act, mean, broadcast_dims);
auto weighted_grad =
rewriter.create<MulOp>(loc, grad, sub_op, no_broadcast_dims);
auto sub_op = rewriter.create<xla_hlo::SubOp>(
loc, act,
Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
auto weighted_grad = rewriter.create<xla_hlo::MulOp>(loc, grad, sub_op);
Value scratch2 =
ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
// x_backprop = y_backprop * (scale * scratch1)
auto scaled_grad =
rewriter.create<MulOp>(loc, op.scale(), scratch1, no_broadcast_dims);
x_backprop =
rewriter.create<MulOp>(loc, grad, scaled_grad, broadcast_dims);
rewriter.create<xla_hlo::MulOp>(loc, op.scale(), scratch1);
x_backprop = rewriter.create<xla_hlo::MulOp>(
loc, grad,
Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
rewriter));
// scale_backprop = scratch2 * scratch1
scale_backprop =
rewriter.create<MulOp>(loc, scratch1, scratch2, no_broadcast_dims);
scale_backprop = rewriter.create<xla_hlo::MulOp>(loc, scratch1, scratch2);
// offset_backprop = sum(y_backprop)
offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
@ -1396,7 +1553,7 @@ class ConvertFusedBatchNormV3Op
auto factor_const_op = rewriter.create<xla_hlo::ConstOp>(
op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
Value corrected_variance = rewriter.create<xla_hlo::MulOp>(
Value corrected_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), batch_variance.getType(), batch_variance,
factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());
@ -1416,24 +1573,26 @@ class ConvertFusedBatchNormV3Op
rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
// new_running_mean = alpha * old_mean + beta * batch_mean.
auto alpha_mul_old_mean = rewriter.create<MulOp>(
auto alpha_mul_old_mean = rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), op.mean().getType(), alpha, op.mean(),
/*broadcast_dimensions=*/DenseIntElementsAttr());
auto beta_mul_batch_mean = rewriter.create<MulOp>(
auto beta_mul_batch_mean = rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), batch_mean.getType(), beta, batch_mean,
/*broadcast_dimensions=*/DenseIntElementsAttr());
batch_mean = rewriter.create<AddOp>(
batch_mean = rewriter.create<xla_chlo::BroadcastAddOp>(
op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
/*broadcast_dimensions=*/DenseIntElementsAttr());
// new_running_variance = alpha * old_variance + beta * batch_variance.
auto alpha_mul_old_variance = rewriter.create<MulOp>(
auto alpha_mul_old_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), op.variance().getType(), alpha, op.variance(),
/*broadcast_dimensions=*/DenseIntElementsAttr());
auto beta_mul_batch_variance = rewriter.create<MulOp>(
op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
auto beta_mul_batch_variance =
rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), corrected_variance.getType(), beta,
corrected_variance,
/*broadcast_dimensions=*/DenseIntElementsAttr());
corrected_variance = rewriter.create<AddOp>(
corrected_variance = rewriter.create<xla_chlo::BroadcastAddOp>(
op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
/*broadcast_dimensions=*/DenseIntElementsAttr());
}
@ -1586,10 +1745,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
// Divide by the number of elements in the window.
Value divisor =
GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
auto batch_dims =
GetI64ElementsAttrForSeq(0, input_type.getRank(), &rewriter);
Value result = rewriter.create<DivOp>(op.getLoc(), result_type, reduce,
divisor, batch_dims);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
Value result = rewriter.create<xla_chlo::BroadcastDivOp>(
op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims);
// Convert back if we enlarged the element type's bitwidth.
if (input_element_type != sum_element_type)
@ -1759,16 +1917,14 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
op.getLoc(), type, scalar_one,
GetI64ElementsAttr(type.getShape(), &rewriter));
auto scaled_input = rewriter.create<MulOp>(
op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
auto scaled_input =
rewriter.create<xla_hlo::MulOp>(op.getLoc(), operand, constant_ones);
auto tanh_op =
rewriter.create<TanhOp>(op.getLoc(), operand.getType(), scaled_input);
auto mul_op =
rewriter.create<MulOp>(op.getLoc(), tanh_op, constant_ones,
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
rewriter.create<xla_hlo::MulOp>(op.getLoc(), tanh_op, constant_ones);
auto add_op =
rewriter.create<AddOp>(op.getLoc(), mul_op, constant_ones,
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
rewriter.create<xla_hlo::AddOp>(op.getLoc(), mul_op, constant_ones);
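// Net effect (assuming scalar_one above holds the constant 0.5): this
// computes sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5, the tanh-based identity.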
rewriter.replaceOp(op, add_op.getResult());
return success();
@ -1807,20 +1963,18 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Value logits = op.logits();
// Softmax converter requires ranked type because the XLA reduce ops used
// while lowering require a dimensions attribute to reduce along.
// Note that the input and output shapes are equivalent, so we use 'logits'
// and its type for shape calculations.
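// In formula form, the two lowerings below compute
//   softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
//   log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))
// with the max and sum reductions taken over the last dimension.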
Value logits = op.logits();
RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>();
if (!type) return failure();
auto loc = op.getLoc();
int rank = type.getRank();
// Note that the TensorFlow Softmax op verifies that the input rank is
// greater than or equal to one so both of the following sequences are
// valid.
auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter);
// greater than or equal to one so the following sequence is valid.
auto reduce_dim = rewriter.create<TF::ConstOp>(
loc, GetI64ElementsAttr({rank - 1}, &rewriter));
@ -1833,8 +1987,10 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
auto max_logits =
rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
/*keep_dims=*/rewriter.getBoolAttr(false));
auto shifted_logits =
rewriter.create<SubOp>(loc, type, logits, max_logits, batch_dims);
auto max_logits_broadcast =
CommonPrefixBroadcast(loc, logits, max_logits, rewriter);
auto shifted_logits = rewriter.create<xla_hlo::SubOp>(loc, type, logits,
max_logits_broadcast);
// Exponentiate the inputs.
Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits);
@ -1847,9 +2003,12 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
if (use_log) {
Value log = rewriter.create<LogOp>(loc, sum);
rewriter.replaceOpWithNewOp<SubOp>(op, shifted_logits, log, batch_dims);
auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter);
rewriter.replaceOpWithNewOp<xla_hlo::SubOp>(op, shifted_logits,
log_broadcast);
} else {
rewriter.replaceOpWithNewOp<DivOp>(op, exp, sum, batch_dims);
auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter);
rewriter.replaceOpWithNewOp<xla_hlo::DivOp>(op, exp, sum_broadcast);
}
return success();
}
@ -1896,7 +2055,7 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
auto dim = rewriter.create<GetDimensionSizeOp>(
op.getLoc(), result_type, input,
rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
size = rewriter.create<MulOp>(
size = rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), size->getResult(0), dim.getResult(),
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
}
@ -2582,10 +2741,10 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<MulOp>(
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, op.delta(),
xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
rewriter.replaceOpWithNewOp<AddOp>(
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
op, result_type, scaled, op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
return success();
@ -2633,7 +2792,7 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
int64_t num = (*num_attr.begin()).getSExtValue();
// Calculate the scaling that needs to be applied to the iota.
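// In closed form: step = (stop - start) / (num - 1) when num > 1, and the
// result is result[i] = start + i * step, matching tf.LinSpace semantics.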
auto step_numerator = rewriter.create<SubOp>(
auto step_numerator = rewriter.create<xla_chlo::BroadcastSubOp>(
op.getLoc(), op.start().getType(), op.stop(), op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
Value step_denominator = rewriter.create<ConvertOp>(
@ -2641,11 +2800,11 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
if (num > 1) {
Value one = GetScalarConstOfType(result_type.getElementType(),
op.getLoc(), 1, &rewriter);
step_denominator = rewriter.create<SubOp>(
step_denominator = rewriter.create<xla_chlo::BroadcastSubOp>(
op.getLoc(), step_denominator.getType(), step_denominator, one,
xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
}
auto step = rewriter.create<DivOp>(
auto step = rewriter.create<xla_chlo::BroadcastDivOp>(
op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
xla::getBroadcastDimensionsAttr(&rewriter, step_numerator,
step_denominator));
@ -2653,10 +2812,10 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
// Scale the iota and add the offset.
auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<MulOp>(
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, step,
xla::getBroadcastDimensionsAttr(&rewriter, iota, step));
rewriter.replaceOpWithNewOp<AddOp>(
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
op, result_type, scaled, op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
return success();
@ -2732,8 +2891,8 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
auto divisor = GetScalarConstOfType(reduce_element_type, loc,
divisor_count, &rewriter);
auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
result = rewriter.create<DivOp>(loc, result, divisor.getResult(),
broadcast_dims);
result = rewriter.create<xla_chlo::BroadcastDivOp>(
loc, result, divisor.getResult(), broadcast_dims);
}
result = rewriter.create<ConvertOp>(loc, result, element_type);
@ -3118,7 +3277,6 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
auto reducer = rewriter.create<CompareOp>(
loc, block->getArgument(0), block->getArgument(1),
StringAttr::get("GE", rewriter.getContext()));
rewriter.create<ReturnOp>(loc, reducer.getResult());
}
@ -3544,13 +3702,20 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
output_dims.insert(output_dims.begin() + axis, depth);
Location loc = op.getLoc();
// The iota result is the effective output shape of the computation,
// and indices must be broadcast into it. At this point, this computation
// would need to be reworked quite a bit to support dynamic shapes, so
// just using static broadcasting.
auto index_type = RankedTensorType::get(output_dims, element_type);
auto iota = rewriter.create<IotaOp>(
loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis));
auto broadcast_indices = rewriter.create<BroadcastInDimOp>(
loc, index_type, op.indices(),
GetI64ElementsAttr(broadcast_dims, &rewriter));
Value compare = rewriter.create<xla_hlo::CompareOp>(
loc, broadcast_indices, iota,
StringAttr::get("EQ", rewriter.getContext()));
Value on_value = rewriter.create<BroadcastOp>(
loc, op.getType(), op.on_value(),
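For illustration, with indices = [1, 2], depth = 3, axis = 1: the iota over the 2x3 output is [[0, 1, 2], [0, 1, 2]], the indices broadcast to [[1, 1, 1], [2, 2, 2]], and the EQ compare marks positions (0,1) and (1,2), which then select on_value (off_value everywhere else).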
@ -4396,7 +4561,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
rewriter.getI64IntegerAttr(1));
Value compare = rewriter.create<CompareOp>(
op.getLoc(), iota0, iota1,
StringAttr::get("EQ", rewriter.getContext()));
Value identity_matrix =
rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
@ -4430,8 +4594,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
batch_dims.size(), precision_config, &rewriter);
a_update = BatchDot(op.getLoc(), y, false, a_update, false,
batch_dims.size(), precision_config, &rewriter);
a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update);
a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k},
&rewriter);
@ -4442,8 +4605,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
batch_dims.size(), precision_config, &rewriter);
q_update = BatchDot(op.getLoc(), q_update, false, y, true,
batch_dims.size(), precision_config, &rewriter);
q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update);
q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter);
}
// full_matrices is false when only a partial result is needed. Slice to the
@ -4505,34 +4667,31 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
Value iota = builder->create<IotaOp>(
loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
builder->getI64IntegerAttr(0));
Value gtk = builder->create<xla_chlo::BroadcastCompareOp>(
loc, iota, k, GetI64ElementsAttr({}, builder),
StringAttr::get("GT", builder->getContext()));
gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
Value x_after_k = builder->create<xla_chlo::BroadcastMulOp>(
loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
// sigma = np.dot(x[k+1:], x[k+1:])
auto sigma = builder->create<ReduceOp>(
loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder));
BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder);
// mu = np.sqrt(x[k]*x[k] + sigma)
Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha);
Value mu = builder->create<SqrtOp>(
loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));
Value sigma_is_zero = builder->create<xla_chlo::BroadcastCompareOp>(
loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
Value alpha_is_negative = builder->create<xla_chlo::BroadcastCompareOp>(
loc, alpha, zero, GetI64ElementsAttr({}, builder),
StringAttr::get("LT", builder->getContext()));
auto batch_size_one = builder->create<BroadcastOp>(
loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
Value signed_mu = builder->create<xla_chlo::BroadcastMulOp>(
loc,
builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
batch_size_one,
@ -4541,21 +4700,16 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
*beta = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
alpha, signed_mu);
*tau = builder->create<DivOp>(
loc, builder->create<SubOp>(loc, *beta, alpha), *beta);
Value zero_tau = builder->create<BroadcastOp>(
loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder));
*tau = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
zero_tau, *tau);
Value divisor = builder->create<SubOp>(loc, alpha, *beta);
divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
batch_size_one, divisor);
Value eqk = builder->create<xla_chlo::BroadcastCompareOp>(
loc, iota, k, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
@ -4568,10 +4722,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
// Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
// If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
// Note that the add performs a degenerate broadcast.
*v = builder->create<xla_chlo::BroadcastAddOp>(
loc, e_k,
StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
GetI64ElementsAttr(batch_dim_ids, builder),
*builder),
/*broadcast_dimensions=*/nullptr);
}
@ -4645,10 +4801,10 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
precision, builder);
vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
precision, builder);
auto tau_x_vva = StaticBinaryBroadcast<xla_hlo::MulOp>(
loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
*builder);
a = builder->create<SubOp>(loc, a, tau_x_vva);
// It is more precise to populate column 'k' explicitly, rather than
// computing it implicitly by applying the Householder transformation.
@ -4657,12 +4813,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
auto iota = builder->create<IotaOp>(
loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
builder->getI64IntegerAttr(0));
Value predecessor_mask = builder->create<xla_chlo::BroadcastCompareOp>(
loc, iota, j, GetI64ElementsAttr({}, builder),
StringAttr::get("LT", builder->getContext()));
predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
a_type.getElementType());
Value mask = builder->create<xla_chlo::BroadcastCompareOp>(
loc, iota, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
@ -4674,14 +4830,14 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
mask,
GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
builder));
Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>(
loc, x, predecessor_mask,
GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder);
Value masked_beta = StaticBinaryBroadcast<MulOp>(
loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder),
*builder);
Value new_x =
builder->create<AddOp>(loc, predecessor_masked_x, masked_beta);
// Update a[:,j]
llvm::SmallVector<int64_t, 4> dim_ids(num_dims);
std::iota(dim_ids.begin(), dim_ids.end(), 0);
@ -4692,7 +4848,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc,
RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
builder->getI64IntegerAttr(minor_dim + 1));
Value xa_mask = builder->create<xla_chlo::BroadcastCompareOp>(
loc, iota_mn, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
@ -4708,11 +4864,11 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
builder));
auto vs_update = builder->create<SelectOp>(
loc, vs.getType(), xa_mask,
StaticBinaryBroadcast<AddOp>(
loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
*builder),
vs_zeros);
vs = builder->create<AddOp>(loc, vs, vs_update);
// taus[j] = tau
llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size());
@ -4729,17 +4885,16 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc, taus.getType(), taus_zeros,
GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
builder));
Value taus_mask = builder->create<xla_chlo::BroadcastCompareOp>(
loc, iota_n, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
auto taus_update = builder->create<SelectOp>(
loc, taus.getType(), taus_mask,
StaticBinaryBroadcast<AddOp>(
loc, taus_zeros, tau,
GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
taus_zeros);
taus = builder->create<AddOp>(loc, taus, taus_update);
new_values->assign({a, vs, taus});
};
@ -4796,8 +4951,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
j = builder->create<AddOp>(
loc, j,
GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1,
builder));
// vs has shape [..., m, 1]
auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder);
// beta has shape [..., 1]
@ -4816,7 +4970,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc, vs.getType(), zero,
GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
builder));
auto compare = builder->create<xla_chlo::BroadcastCompareOp>(
loc, iota_mn, j, GetI64ElementsAttr({}, builder),
StringAttr::get("GE", builder->getContext()));
auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
@ -4831,13 +4985,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
// z = -beta * (v + wyv)
auto neg_beta = builder->create<NegOp>(loc, beta);
auto v_wyv = builder->create<AddOp>(loc, v, wyv);
auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
beta_broadcast_dims.push_back(n_index);
auto z = StaticBinaryBroadcast<MulOp>(
loc, neg_beta, v_wyv,
GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter);
w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder);
new_values->assign({w, vs, taus});
@ -4855,8 +5008,9 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
auto neg_beta = rewriter->create<NegOp>(loc, beta);
auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
beta_broadcast_dims.push_back(n_index);
auto bv = StaticBinaryBroadcast<MulOp>(
loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter),
*rewriter);
w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter);
SmallVector<Value, 4> while_output;
@ -4912,7 +5066,8 @@ void EmitLegalizationErrors(Operation *op,
// Performs the lowering to XLA dialect.
void LegalizeTF::runOnFunction() {
if (failed(
legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_)))
signalPassFailure();
}
@ -4923,7 +5078,8 @@ static PassRegistration<LegalizeTF> pass(
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
bool legalize_chlo) {
MLIRContext *context = op->getContext();
// Add lowering patterns to the list.
@ -4936,19 +5092,19 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
TF::PopulateLoweringTFPatterns(context, &patterns);
patterns.insert<
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
@ -4959,10 +5115,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
// Populate with CHLO->HLO lowerings to account for TF ops legalized to
// CHLO first.
if (legalize_chlo) {
xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
}
ConversionTarget target(*context);
if (legalize_chlo) {
target.addIllegalDialect<xla_chlo::XlaHloClientDialect>();
} else {
target.addLegalDialect<xla_chlo::XlaHloClientDialect>();
}
target.addLegalDialect<XlaHloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<shape::ShapeDialect>();
@ -4988,8 +5150,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
bool allow_partial_conversion, bool legalize_chlo) {
return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo);
}
} // end namespace xla_hlo
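A minimal usage sketch of the new pass entry point (this just spells out the default flag values declared above; includes are assumed from this file's context):

  // Build the TF->HLO legalization pass, lowering xla_chlo.broadcast_* client
  // ops to xla_hlo in the same pass. Passing legalize_chlo=false instead keeps
  // the CHLO dialect legal so a later pass can expand the broadcasts.
  std::unique_ptr<OperationPass<FuncOp>> MakeDefaultLegalizeTFPass() {
    return createLegalizeTFPass(/*allow_partial_conversion=*/false,
                                /*legalize_chlo=*/true);
  }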

View File

@ -73,21 +73,6 @@ def : Pattern<
// HLO and XLA doesn't support Assertions.
def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>;
//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//
@ -114,7 +99,8 @@ foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp],
def LowerRightShiftSigned :
Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLOClient_BroadcastShiftRightArithmeticOp $l, $r,
(BinBroadcastDimensions $l, $r)),
[(SignedIntTensor $r)]>;
// TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op
@ -126,10 +112,11 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
//
// return floor(div(x, y))
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLO_FloorOp
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
[(IEEEFloatTensor $l)]>;
// Performs a substitution of FloorDiv for integer tensors, which requires
// additional correction for a negative numerator / denominator. Equivalent
// pseudocode is shown below:
//
@ -150,16 +137,16 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
// broadcast attributes.
def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
(HLO_SelectOp
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
(HLO_CompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)),
(HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
(BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ),
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)),
(HLOClient_BroadcastDivOp
(HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l),
(HLOClient_BroadcastSubOp (HLO_AbsOp $r),
(HLO_ConstOp (ConstantSplat<"1"> $r)),
(NullDenseIntElementsAttr)),
(BinBroadcastDimensions $l, $r))),
@ -175,20 +162,20 @@ def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
// broadcast attributes.
def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
(HLO_SelectOp
(HLOClient_BroadcastAndOp
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)),
(HLO_ConstOp:$l_zeros (ConstantSplat<"0"> $l)),
(BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE),
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastCompareOp:$r_cmp $r,
(HLO_ConstOp:$r_zeros (ConstantSplat<"0"> $r)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
(HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros,
(BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT),
(BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE),
(NullDenseIntElementsAttr)),
(HLOClient_BroadcastAddOp $r,
$rem, (BinBroadcastDimensions $r, $rem)), $rem)>;
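As a worked example of the select above (values for illustration): FloorMod(7, -3) first computes rem(7, -3) = 1; since rem is non-zero and its sign differs from the divisor's, the pattern adds the divisor, giving 1 + (-3) = -2, which matches floor-division semantics (7 - (-3) * floor(7 / -3) = 7 - 9 = -2).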
//===----------------------------------------------------------------------===//
@ -406,39 +393,36 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_
(HLO_SelectOp:$num_lower_or_m
(HLO_CompareOp
$num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)),
HLO_COMPARISON_DIRECTION_LT
),
$m_dim,
$num_lower
),
(HLO_SelectOp:$num_upper_or_n
(HLO_CompareOp
$num_upper, $zero, HLO_COMPARISON_DIRECTION_LT
),
$n_dim,
$num_upper
),
(HLO_SelectOp
(HLO_AndOp
(HLOClient_BroadcastCompareOp
(HLO_NegOp
(createConvertOp $op, $num_lower_or_m, $input)
),
(HLO_SubOp:$offset
(createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input),
(NullDenseIntElementsAttr)
(createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input)
),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
),
(HLOClient_BroadcastCompareOp
$offset,
(createConvertOp
$op, $num_upper_or_n, $input
),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
)
),
$input,
(HLO_ConstOp (ConstantSplat<"0"> $input))
@ -462,7 +446,8 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value),
// TODO(hinsu): Lower unsigned and quantized types after supporting
// them in GetScalarOfType.
def : Pat<(TF_ReluOp AnyRankedTensor:$input),
(HLOClient_BroadcastMaxOp
(HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input,
(BinBroadcastDimensions $zero, $input)),
[(TF_SintOrFpTensor $input)]>;
@ -485,7 +470,7 @@ def : Pat<(TF_Relu6Op AnyRankedTensor:$input),
// to create splat tensor of dynamic shape in HLO.
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features),
(HLO_SelectOp
(HLOClient_BroadcastCompareOp $features,
(HLO_ConstOp (GetScalarOfType<0> $features)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT),
$gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>;
@ -598,7 +583,6 @@ def : Pat<(TF_SignOp $x),
(HLO_CompareOp
$x,
$x,
HLO_COMPARISON_DIRECTION_NE
),
(HLO_ConstOp (ConstantSplat<"0"> $x)),
@ -641,8 +625,6 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2),
//===----------------------------------------------------------------------===//
def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLO_MulOp
(HLO_MulOp $r, $l),
(HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l)),
[(IEEEFloatTensor $l)]>;

View File

@ -36,47 +36,36 @@ def IsSameSizePred : CPred<
def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;
def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r),
(AndOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(AddFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(SubFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(MulFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(DivFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(RemFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(AddIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(SubIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(MulIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(SignedDivIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(SignedRemIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;

View File

@ -28,70 +28,62 @@ include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
// and imaginary components.
foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in
def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs,
HLO_ComplexTensor:$rhs),
(HLO_ComplexOp
(elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)),
(elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>;
// Complex multiplication results in a cross product multiplication between the
// real and imaginary components such that:
// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
HLO_ComplexTensor:$rhs),
(HLO_ComplexOp
(HLO_SubOp
(HLO_MulOp
(HLO_RealOp:$lhs_real $lhs),
(HLO_RealOp:$rhs_real $rhs)),
(HLO_MulOp
(HLO_ImagOp:$lhs_imag $lhs),
(HLO_ImagOp:$rhs_imag $rhs))),
(HLO_AddOp
(HLO_MulOp $lhs_real, $rhs_imag),
(HLO_MulOp $lhs_imag, $rhs_real)))>;
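Numeric check of the cross-product expansion (illustrative): (1 + 2i) * (3 + 4i) gives real = 1*3 - 2*4 = -5 and imag = 2*3 + 1*4 = 10, i.e. -5 + 10i.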
// Multiplication between a complex and real tensor can be distributed by
// applying the real multiplicant to both the real and complex component.
//
// Note that the source pattern is not legal according to the HLO dialect but
// instead handles intermediates generated by other patterns.
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
(HLO_ComplexOp
(HLO_MulOp (HLO_RealOp $lhs), $rhs),
(HLO_MulOp (HLO_ImagOp $lhs), $rhs))>;
def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs),
(HLO_ComplexOp
(HLO_MulOp $lhs, (HLO_RealOp $rhs)),
(HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>;
// Division is performed by normalizing the denominator by multiplying by the
// conjugate of the rhs.
// numerator = lhs * conj(rhs)
// denominator = rhs * conj(rhs)
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
(HLO_DivOp
(HLO_MulOp:$num $lhs,
(HLO_ComplexOp:$conj
(HLO_RealOp $rhs),
(HLO_NegOp (HLO_ImagOp $rhs)))),
(HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>;
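This is the usual conjugate normalization: (a + bi) / (c + di) = (a + bi)(c - di) / (c*c + d*d), so only the real-valued denominator rhs * conj(rhs) needs an actual division.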
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
(HLO_ComplexOp
(HLO_DivOp (HLO_RealOp $lhs), $rhs),
(HLO_DivOp (HLO_ImagOp $lhs), $rhs))>;
// Absolute value is evaluated as:
@ -100,11 +92,8 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
(HLO_ComplexOp
(HLO_SqrtOp
(HLO_AddOp
(HLO_MulOp (HLO_RealOp:$real $val), $real),
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
(HLO_ConstOp (ConstantSplat<"0"> $real)))>;
// Exponential can be lowered to an exponential on the real component and a
@ -117,5 +106,4 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
(HLO_ExpOp (HLO_RealOp $val)),
(HLO_ComplexOp
(HLO_CosOp (HLO_ImagOp:$imag $val)),
(HLO_SinOp $imag)))>;
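This is Euler's formula, exp(a + bi) = exp(a) * (cos(b) + i*sin(b)): the pattern takes exp of the real part and multiplies it by the complex (cos, sin) pair built from the imaginary part.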

View File

@ -28,259 +28,6 @@ namespace xla_hlo {
namespace {
// Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays
// must be the same shape. Alternatively, as a restricted form of broadcasting,
// min and/or max can be a scalar of type T."
@ -322,63 +69,10 @@ struct ClampWithBroadcastConvert : public OpRewritePattern<ClampOp> {
}
};
} // namespace
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
ConversionTarget *conversionTarget) {
conversionTarget->addDynamicallyLegalOp<ClampOp>([](ClampOp op) {
return op.max().getType() == op.operand().getType() &&
op.min().getType() == op.operand().getType();
@ -387,30 +81,10 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context,
void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// ClampOp. This op has a special case where it accepts either same-shaped
// inputs or scalars (a restricted form of broadcasting). This makes the
// broadcast explicit.
patterns->insert<ClampWithBroadcastConvert>(context);
}
} // namespace xla_hlo

View File

@ -36,7 +36,7 @@ namespace xla_hlo {
/// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is
/// false, emits an error if there is any operation that can't be legalized.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
bool allow_partial_conversion = false, bool legalize_chlo = true);
/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
/// specified device type.
@ -50,7 +50,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
/// dialect using the conversion patterns registered by the HLO dialect. When
/// allow_partial_conversion is false, emits an error if there is any operation
/// that can't be legalized.
LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false,
bool legalize_chlo = true);
/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();

View File

@ -135,8 +135,8 @@ class UnfuseBatchNormInferencePattern
if (!epsilon) {
return failure();
}
Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
bn_op.variance(), epsilon);
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
// Broadcast all terms.
@ -160,13 +160,13 @@ class UnfuseBatchNormInferencePattern
// Compute:
// scale * (input - mean) / stddev + offset
Value result = rewriter.create<xla_hlo::SubOp>(
bn_op.getLoc(), bn_op.operand(), broadcast_mean);
result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
broadcast_scale);
result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
broadcast_stddev);
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
broadcast_offset);
return success();
}
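A quick numeric check of the unfused form scale * (input - mean) / stddev + offset (illustrative values): with variance = 0.96, epsilon = 0.04, mean = 2, scale = 3, offset = 1, an input of 4 gives stddev = sqrt(0.96 + 0.04) = 1 and result = 3 * (4 - 2) / 1 + 1 = 7.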