Remove implicit broadcasting from xla_hlo binary elementwise ops.

* Migrates legalize-tf conversions to either:
  * convert through the chlo.broadcast_* ops (majority), or
  * have special-case broadcasting for unsupported or non-optimal broadcast modes.
* This was done one by one, QC'ing each, and many bugs/inefficiencies/ambiguous broadcasting modes were corrected.
* Looks like it may be missing a rule for legalizing complex types (will check on Monday).
* I considered splitting this up, but it was actually pretty important to make the ops stricter to flush out all cases (best done as an atomic change).
* Stricter conversions fixed a number of cases where shapes were dropping to unranked or unknown (and needn't have been).
* With this change, most of the binary ops and many of the resulting tf2xla expansions correctly support dynamic shapes via the shape dialect.
* I verified this with the small set of IREE test cases that support dynamic shapes and will expand coverage once this lands.
* There is some test fallout outside of the xla directory that I will fix up on Monday.

PiperOrigin-RevId: 312316083
Change-Id: I6d246d80cddb84f2dfd62817c7166f53c1f6cdec
Commit b7735095de (parent 273617ad91)
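To make the new expansion concrete, here is a minimal sketch of the lowering path exercised by the tests in this change (SSA names and the exact shape-dialect types are illustrative, not taken from the source): a broadcasting tf.Add legalizes to xla_chlo.broadcast_add, and chlo legalization then makes the broadcast explicit through the shape dialect before a strict, same-shape xla_hlo.add:

func @broadcast_add_expansion(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
  // Reify operand shapes and compute the broadcasted result shape.
  %lhs_shape = shape.const_shape [1]
  %rhs_shape = shape.const_shape [1, 2]
  %result_shape = "shape.broadcast"(%lhs_shape, %rhs_shape) : (!shape.shape, !shape.shape) -> !shape.shape
  %extents = "shape.to_extent_tensor"(%result_shape) : (!shape.shape) -> tensor<2xindex>
  // Broadcast both operands to the common shape; the final add has no
  // implicit broadcast left to resolve.
  %lhs = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %extents) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x2xi32>
  %rhs = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %extents) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xi32>, tensor<2xindex>) -> tensor<1x2xi32>
  %sum = xla_hlo.add %lhs, %rhs : tensor<1x2xi32>
  return %sum : tensor<1x2xi32>
}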
@@ -2,17 +2,17 @@
 func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
-  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
+  %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
   return %0 : tensor<1x32x10x32xi32>
 }

 func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
-  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
+  %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
   return %0 : tensor<1x32x10x32xi32>
 }

 func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
-  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
+  %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
   return %0 : tensor<?x?x?x?xi32>
 }
@@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 }

 func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+  %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0 : tensor<1x2xi32>
 }

 func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
-  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
+  %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
   return %0 : tensor<4x4x4x4xi32>
 }
@@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 }

 func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+  %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0 : tensor<1x2xi32>
 }
@@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
 }

 func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+  %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
   return %0 : tensor<?x?xi32>
 }
@@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 }

 func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+  %0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0 : tensor<1x2xi32>
 }
@@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 }

 func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+  %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0 : tensor<1x2xi32>
 }
@@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 }

 func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
-  %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+  %0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0 : tensor<1x2xi32>
 }
@@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
 }

 func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> {
-  %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
+  %0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
   return %0 : tensor<2x4xi32>
 }
@@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
 }

 func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }

 func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
-  %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
+  %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
   return %0 : tensor<?xi1>
 }
@@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
 }

 func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }

 func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
-  %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
+  %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
   return %0 : tensor<?xi1>
 }
@@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
 }

 func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
-  %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
+  %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
   return %0 : tensor<1x4xi8>
 }

 func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
-  %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
+  %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }
@@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
 }

 func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
-  %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
+  %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
   return %0 : tensor<1x4xi8>
 }

 func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
-  %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
+  %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }
@@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {

 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
   %0 = xla_hlo.constant dense<0> : tensor<2x3xi32>
-  %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
+  %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
   %2 = xla_hlo.constant dense<0> : tensor<3xi32>
-  %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
-  %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
-  %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
+  %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
+  %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
+  %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
   %6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
   %7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
   %8 = xla_hlo.constant dense<1> : tensor<3xi32>
   %9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
-  %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
+  %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
   %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
   %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
-  %13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
+  %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
   %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %14 : tensor<2x3xi32>
 }
@@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
   %0 = xla_hlo.constant dense<0> : tensor<3xi32>
   %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
   %2 = xla_hlo.constant dense<0> : tensor<2x3xi32>
-  %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
-  %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
-  %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
+  %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
+  %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   %6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
   %7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
   %8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
   %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
-  %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
   %12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
   %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
@@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 }

 func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
-  %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
-  %1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
+  %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
+  %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
   %2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
   return %2 : tensor<2x3xf16>
 }
@@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
 }

 func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
   return %0 : tensor<?xi1>
 }

 func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }

 func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }

 func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
   return %0 : tensor<?xi1>
 }
@@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
 }

 func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }

 func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }

 func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
   return %0 : tensor<?xi1>
 }
@@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
 }

 func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }
@@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
 }

 func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }
@@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
 }

 func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }
@@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
 }

 func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   return %0 : tensor<1x2xi1>
 }
@@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> {

 func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
   %0 = xla_hlo.constant dense<0> : tensor<i32>
-  %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
+  %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
   return %1 : tensor<1xi32>
 }

 func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
   %0 = xla_hlo.constant dense<0> : tensor<i32>
-  %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
+  %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
   return %1 : tensor<?xi32>
 }

 func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
   %0 = xla_hlo.constant dense<0> : tensor<i32>
   %1 = xla_hlo.constant dense<6> : tensor<i32>
-  %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
-  %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
+  %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
+  %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
   return %3 : tensor<1xi32>
 }

 func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
   %0 = xla_hlo.constant dense<0> : tensor<i32>
   %1 = xla_hlo.constant dense<6> : tensor<i32>
-  %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
-  %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
+  %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
+  %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
   return %3 : tensor<?xi32>
 }

 func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
   %0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
-  %1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+  %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
   %2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
   %3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
   return %3 : tensor<4x8xf32>
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

 namespace mlir {
@@ -18,12 +18,16 @@ limitations under the License.
 include "mlir/IR/OpBase.td"
 include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
+include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td"
 include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"

 def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;

 //===----------------------------------------------------------------------===//
 // Binary op patterns.
+// Note that these are legalized from chlo.broadcast_* ops, since those are
+// semantically compatible with the corresponding TF ops. Depending on
+// context, getting to these ops may require some raising.
 //===----------------------------------------------------------------------===//

 // Check that two values can be broadcasted together
@@ -31,36 +35,45 @@ def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
 def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
     "types must be broadcastable">;

-foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
-                         [HLO_DivOp, TF_DivOp],
-                         [HLO_ShiftLeftOp, TF_LeftShiftOp],
-                         [HLO_MaxOp, TF_MaximumOp],
-                         [HLO_MinOp, TF_MinimumOp],
-                         [HLO_MulOp, TF_MulOp],
-                         [HLO_PowOp, TF_PowOp],
-                         [HLO_SubOp, TF_SubOp],
-                         [HLO_Atan2Op, TF_Atan2Op],
-                         [HLO_RemOp, TF_ModOp]] in
-  def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r),
+foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op],
+                         [HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp],
+                         [HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp],
+                         [HLO_MaxOp, HLOClient_BroadcastMaxOp, TF_MaximumOp],
+                         [HLO_MinOp, HLOClient_BroadcastMinOp, TF_MinimumOp],
+                         [HLO_MulOp, HLOClient_BroadcastMulOp, TF_MulOp],
+                         [HLO_PowOp, HLOClient_BroadcastPowOp, TF_PowOp],
+                         [HLO_SubOp, HLOClient_BroadcastSubOp, TF_SubOp],
+                         [HLO_Atan2Op, HLOClient_BroadcastAtan2Op, TF_Atan2Op],
+                         [HLO_RemOp, HLOClient_BroadcastRemOp, TF_ModOp]] in {
+  def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>;
+  def : Pat<(fromToBinPair[1] $l, $r, $_), (fromToBinPair[2] $l, $r),
            [(AreBroadcastCompatible $l, $r)]>;
+}

-foreach pair = [[HLO_AndOp, TF_BitwiseAndOp],
-                [HLO_OrOp, TF_BitwiseOrOp],
-                [HLO_XorOp, TF_BitwiseXorOp]] in
-  def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r),
+foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_BitwiseAndOp],
+                [HLO_OrOp, HLOClient_BroadcastOrOp, TF_BitwiseOrOp],
+                [HLO_XorOp, HLOClient_BroadcastXorOp, TF_BitwiseXorOp]] in {
+  def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>;
+  def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[2] $l, $r),
            [(AreBroadcastCompatible $l, $r)]>;
+}

-foreach pair = [[HLO_AndOp, TF_LogicalAndOp],
-                [HLO_OrOp, TF_LogicalOrOp]] in
-  def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r),
+foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_LogicalAndOp],
+                [HLO_OrOp, HLOClient_BroadcastOrOp, TF_LogicalOrOp]] in {
+  def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>;
+  def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $_), (pair[2] $l, $r),
            [(AreBroadcastCompatible $l, $r)]>;
+}

-def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
+def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>;
+def : Pat<(HLOClient_BroadcastShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
          [(AreBroadcastCompatible $l, $r)]>;
-def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
+def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>;
+def : Pat<(HLOClient_BroadcastShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
          [(AreBroadcastCompatible $l, $r)]>;

-def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
+def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>;
+def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
          [(AreBroadcastCompatible $l, $r)]>;

 def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>;
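As an illustrative instance of the pattern pairs above (operand shapes and SSA names are assumed for the example, not taken from the change): when AreBroadcastCompatible holds, the broadcasting client op raises to the corresponding TF op, e.g.

  %0 = "xla_chlo.broadcast_add"(%l, %r) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

rewrites to

  %0 = "tf.AddV2"(%l, %r) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

while the strict, attribute-free xla_hlo.add maps to tf.AddV2 unconditionally.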
@@ -117,16 +130,23 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim),

 //===----------------------------------------------------------------------===//
 // Compare op patterns.
+// Note that these are legalized from chlo.broadcast_* ops, since those are
+// semantically compatible with the corresponding TF ops. Depending on
+// context, getting to these ops may require some raising.
 //===----------------------------------------------------------------------===//

 foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
-             [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in
-  def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
+             [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in {
+  def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
            [(AreBroadcastCompatible $l, $r)]>;
+  def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>;
+}

 foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
                 [TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
                 [TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
-                [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in
-  def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
+                [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in {
+  def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
            [(AreBroadcastCompatible $l, $r)]>;
+  def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>;
+}
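Likewise for compares, an illustrative instance of the patterns above (shapes assumed):

  %0 = "xla_chlo.broadcast_compare"(%l, %r) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>

legalizes, for broadcast-compatible operand types, to

  %0 = "tf.Equal"(%l, %r) {incompatible_shape_error = true} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>

with ConstBoolAttrTrue supplying the incompatible_shape_error attribute.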
@@ -185,6 +185,16 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
 // BroadcastCompareOp (has custom type inference due to different result type).
 //===----------------------------------------------------------------------===//

+void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
+                               Value lhs, Value rhs,
+                               DenseIntElementsAttr broadcast_dimensions,
+                               StringAttr comparison_direction) {
+  auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
+                                   builder.getI1Type(), broadcast_dimensions);
+  build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
+        comparison_direction);
+}
+
 LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
     MLIRContext* context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
@@ -360,6 +360,11 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
     HLO_ComparisonDirectionAttr:$comparison_direction
   );
   let results = (outs HLO_PredTensor);
+
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
+    "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
+  >];
 }

 #endif  // CHLO_OPS
@@ -1401,89 +1401,25 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//

 namespace {
-// Gets the resulting type from a broadcast between two types.
-static Type GetBroadcastType(Builder* builder, Type x, Type y,
-                             Type element_type,
-                             DenseIntElementsAttr broadcast_dimensions) {
+// Updates the element type of a (presumed) tensor type 'x', returning either
+// a permuted UnrankedTensorType or RankedTensorType.
+static Type UpdateResultElementType(Builder* builder, Type x,
+                                    Type element_type) {
   auto x_ranked = x.dyn_cast<RankedTensorType>();
-  auto y_ranked = y.dyn_cast<RankedTensorType>();
-  if (!x_ranked || !y_ranked) {
+  if (!x_ranked) {
     return UnrankedTensorType::get(element_type);
   }

   auto shape_x = x_ranked.getShape();
-  auto shape_y = y_ranked.getShape();
-
-  if (shape_x.size() == shape_y.size()) {
-    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
-    for (int i = 0; i < shape_x.size(); i++) {
-      auto x_val = shape_x[i];
-      auto y_val = shape_y[i];
-      if (x_val == -1 || y_val == -1) {
-        out_shape[i] = -1;
-      } else {
-        out_shape[i] = std::max(x_val, y_val);
-      }
-    }
-    return RankedTensorType::get(out_shape, element_type);
-  }
-
-  // Return unranked tensor for invalid broadcast dimensions.
-  if (!broadcast_dimensions) return UnrankedTensorType::get(element_type);
-
-  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
-  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
-
-  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
-                                          shape_large.end());
-
-  // Update according to the broadcast dimensions.
-  for (auto index_pair : llvm::enumerate(broadcast_dimensions.getIntValues())) {
-    auto old_value = out_shape[index_pair.value().getSExtValue()];
-    auto new_value = shape_small[index_pair.index()];
-    if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
-      out_shape[index_pair.value().getSExtValue()] = new_value;
-    }
-  }
-
-  return RankedTensorType::get(out_shape, element_type);
+  return RankedTensorType::get(shape_x, element_type);
 }
 }  // namespace

-#define BINARY_BUILDER(Op)                                                    \
-  void Op::build(OpBuilder& builder, OperationState& result, Value left,      \
-                 Value right, DenseIntElementsAttr broadcast_dimensions) {    \
-    auto type = GetBroadcastType(&builder, left.getType().cast<ShapedType>(), \
-                                 right.getType().cast<ShapedType>(),          \
-                                 getElementTypeOrSelf(right.getType()),       \
-                                 broadcast_dimensions);                       \
-    return Op::build(builder, result, type, left, right,                      \
-                     broadcast_dimensions);                                   \
-  }
-
-BINARY_BUILDER(AddOp);
-BINARY_BUILDER(AndOp);
-BINARY_BUILDER(Atan2Op);
-BINARY_BUILDER(DivOp);
-BINARY_BUILDER(MaxOp);
-BINARY_BUILDER(MinOp);
-BINARY_BUILDER(MulOp);
-BINARY_BUILDER(OrOp);
-BINARY_BUILDER(PowOp);
-BINARY_BUILDER(RemOp);
-BINARY_BUILDER(ShiftLeftOp);
-BINARY_BUILDER(ShiftRightArithmeticOp);
-BINARY_BUILDER(ShiftRightLogicalOp);
-BINARY_BUILDER(SubOp);
-BINARY_BUILDER(XorOp);
-
-#undef BINARY_BUILDER

 template <typename Op, typename ElementType = Type, typename ValType,
           typename Convert>
 static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
   if (!attrs[0] || !attrs[1]) return {};
-  if (op->broadcast_dimensions().hasValue()) return {};

   DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
   DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
@@ -1893,12 +1829,10 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//

 void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
-                      Value rhs, DenseIntElementsAttr broadcast_dimensions,
-                      StringAttr comparison_direction) {
-  auto new_type = GetBroadcastType(&builder, lhs.getType(), rhs.getType(),
-                                   builder.getI1Type(), broadcast_dimensions);
-  build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
-        comparison_direction);
+                      Value rhs, StringAttr comparison_direction) {
+  auto new_type =
+      UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
+  build(builder, result, new_type, lhs, rhs, comparison_direction);
 }

 #define GET_OP_CLASSES
@@ -241,15 +241,9 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
     HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface])> {
   let arguments = (ins
     HLO_Tensor:$lhs,
-    HLO_Tensor:$rhs,
-    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
+    HLO_Tensor:$rhs
   );

-  let builders = [OpBuilder<
-    "OpBuilder &builder, OperationState &result, Value left, Value right, "
-    "DenseIntElementsAttr broadcast_dimensions"
-  >];
-
   let extraClassDeclaration = [{
     static LogicalResult inferReturnTypeComponents(
       MLIRContext* context, Optional<Location> location, ValueRange operands,
@@ -270,15 +264,15 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
 }

 def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
-    [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp {
+    [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AddOp {
   let hasFolder = 1;
 }

 def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;

 def HLO_ComplexOp: HLO_Op<"complex",
-    [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>,
+    [NoSideEffect, SameOperandsAndResultShape]>,
     BASE_HLO_ComplexOp {
   let builders = [OpBuilder<
     "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];
@@ -289,39 +283,39 @@ def HLO_ComplexOp: HLO_Op<"complex",
 }

 def HLO_DivOp : HLO_BinaryElementwiseOp<"divide",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp {
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp {
 }

 def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum",
-    [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp {
+    [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp {
 }

 def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum",
-    [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp {
+    [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp {
 }

 def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply",
-    [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp {
+    [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MulOp {
   let hasFolder = 1;
 }

 def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;

 def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;

 def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;

 def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightArithmeticOp;

 def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightLogicalOp;

 def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp {
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_SubOp {
   let hasFolder = 1;
 }
@@ -331,11 +325,11 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",

 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
 class HLO_BinaryLogicalElementwiseOp<string mnemonic> :
-    HLO_BinaryElementwiseOp<mnemonic, [Commutative, NoSideEffect]> {
+    HLO_BinaryElementwiseOp<
+      mnemonic, [Commutative, NoSideEffect, SameOperandsAndResultType]> {
   let arguments = (ins
     HLO_PredOrIntTensor:$lhs,
-    HLO_PredOrIntTensor:$rhs,
-    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
+    HLO_PredOrIntTensor:$rhs
   );
 }
@@ -617,23 +611,18 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
 }

 def HLO_CompareOp: HLO_Op<"compare",
-    [NoSideEffect, SameOperandsElementType]>, BASE_HLO_CompareOp {
+    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>,
+    BASE_HLO_CompareOp {
   let arguments = (ins
     HLO_Tensor:$lhs,
     HLO_Tensor:$rhs,
-    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
     HLO_ComparisonDirectionAttr:$comparison_direction
   );
-  let builders = [OpBuilder<
-    "OpBuilder &builder, OperationState &result, Value left, Value right, "
-    "DenseIntElementsAttr broadcast_dimensions, "
-    "StringAttr comparison_direction"
-  >];
   let results = (outs HLO_PredTensor);

   let builders = [OpBuilder<
     "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
-    "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
+    "StringAttr comparison_direction"
   >];
 }
@@ -209,7 +209,6 @@ StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
                       shape, builder_));
   auto op = builder_.create<mlir::xla_hlo::CompareOp>(
       loc_, ty, GetValue(lhs), GetValue(rhs),
-      /*broadcast_dimensions=*/mlir::DenseIntElementsAttr(),
       builder_.getStringAttr(ComparisonDirectionToString(direction)));
   return MakeXlaOp(op.getResult());
 }
@ -0,0 +1,334 @@
|
||||
// Note that binary elementwise tests are run with chlo legalization enabled
|
||||
// (unlike the rest), since this is the primary use case for such ops and
|
||||
// verification of shapes and broadcasts is desired.
|
||||
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary op legalizations.
|
||||
// Most of these expand from the same pattern. Full semantics are
|
||||
// verified for tf.Add and pattern application only for the rest.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @add
|
||||
func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
|
||||
// CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32>
|
||||
// CHECK-NEXT: return %[[SUM1]] : tensor<2xi32>
|
||||
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
%1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
return %1: tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @broadcast_add
|
||||
// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check
|
||||
// patterns unambiguous and more interesting (once broadcastable trait is
|
||||
// fixed upstream).
|
||||
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1]
|
||||
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2]
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0: tensor<1x2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @broadcast_multi_dim_add
|
||||
// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream
|
||||
// broadcastable bug is fixed (helps make the CHECK matching unambiguous)
|
||||
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
|
||||
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1]
|
||||
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}
|
||||
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
|
||||
return %0: tensor<4x4x4x4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @add_dynamic
|
||||
func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
|
||||
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
|
||||
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
|
||||
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32>
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %0: tensor<?x?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @div
|
||||
func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
|
||||
// CHECK-NEXT: return %0 : tensor<2xi32>
|
||||
%0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shift_left
|
||||
func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
|
||||
%0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @div_unranked
|
||||
func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// CHECK: tf.Div
|
||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %0: tensor<?x?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maximum
|
||||
func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
|
||||
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @minimum
|
||||
func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
|
||||
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mul
|
||||
func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32>
|
||||
// CHECK-NEXT: return %0 : tensor<2xi32>
|
||||
%0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @real_div
|
||||
func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
|
||||
%0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sub
|
||||
func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32>
|
||||
// CHECK-NEXT: return %0 : tensor<2xi32>
|
||||
%0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shift_right
|
||||
func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
|
||||
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shift_right_unsigned
|
||||
func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> {
|
||||
// CHECK: tf.RightShift
|
||||
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8>
|
||||
return %0 : tensor<4xui8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @broadcast_shift_right_unsigned
|
||||
func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> {
|
||||
// CHECK: tf.RightShift
|
||||
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8>
|
||||
return %0 : tensor<2x4xui8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @and
|
||||
func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: xla_hlo.and
|
||||
%0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @and_unranked
|
||||
func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> {
|
||||
// CHECK: tf.LogicalAnd
|
||||
%0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1>
|
||||
return %0: tensor<*xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @or
|
||||
func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: xla_hlo.or
|
||||
%0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitwise_or
|
||||
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK-NEXT: xla_hlo.or
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}

// CHECK-LABEL: func @bitwise_and
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: xla_hlo.and
%0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}

// CHECK-LABEL: func @pow
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-NEXT: xla_hlo.power
%0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0: tensor<2xf32>
}

//===----------------------------------------------------------------------===//
// Equality op legalizations.
// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are
// verified for tf.Equal and pattern application only for tf.NotEqual
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @equal
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

// CHECK-LABEL: func @equal_dynamic
func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}

// CHECK-LABEL: func @equal_broadcast
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}

// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error
func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}

// CHECK-LABEL: func @equal_incompatible_shape_broadcastable
func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
// CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}

// CHECK-LABEL: func @equal_incompatible_shape_dynamic
func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> {
// CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<?xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}

// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic
func @equal_incompatible_shape_both_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> {
// CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<?xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}

// CHECK-LABEL: func @equal_unranked
func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> {
// CHECK: "tf.Equal"
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}

// CHECK-LABEL: func @notequal
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"}
%0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

//===----------------------------------------------------------------------===//
// Compare op legalizations.
// These expand from the same pattern. Full semantics are checked for
// tf.Greater. Others just check that the pattern applied.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @greater
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

// CHECK-LABEL: func @broadcast_greater
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}

// CHECK-LABEL: func @greater_dynamic
func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1> {
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}

// CHECK-LABEL: func @greater_unranked
func @greater_unranked(%arg0: tensor<*xi32>) -> tensor<*xi1> {
// CHECK: "tf.Greater"
|
||||
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
|
||||
return %0: tensor<*xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @greater_equal
|
||||
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
|
||||
%0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @less
|
||||
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
|
||||
%0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @less_equal
|
||||
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
|
||||
%0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
File diff suppressed because it is too large
@ -1,4 +1,4 @@
// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s
// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
@ -42,40 +42,6 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
return %4 : tensor<4xi32>
}

// Broadcasting is not currently supported.
// TODO(suderman): Future pass should take all broadcasted binary ops and convert
// them to separate broadcast and binary op.
// CHECK-LABEL: func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
// CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {
name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>

// CHECK-NEXT: %1 = "xla_hlo.multiply"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%1 = "xla_hlo.multiply"(%0, %arg1) {
name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>

// CHECK-NEXT: %2 = "xla_hlo.subtract"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%2 = "xla_hlo.subtract"(%1, %arg1) {
name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>

// CHECK-NEXT: %3 = "xla_hlo.divide"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%3 = "xla_hlo.divide"(%2, %arg1) {
name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>

// CHECK-NEXT: %4 = "xla_hlo.remainder"(%3, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
%4 = "xla_hlo.remainder"(%3, %arg1) {
broadcast_dimensions = dense<1> : tensor<1xi64>} :
(tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>

// CHECK-NEXT: return %4 : tensor<4x4xf32>
return %4 : tensor<4x4xf32>
}

// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
@ -1,4 +1,4 @@
// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s
// RUN: xla-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
@ -15,21 +15,6 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}

// CHECK-LABEL: @add_broadcast
func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}

// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@ -60,21 +45,6 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}

// CHECK-LABEL: @sub_broadcast
func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.subtract"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.subtract"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.subtract"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}

// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@ -109,25 +79,6 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}

// CHECK-LABEL: @mul_broadcast
func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.multiply"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)

// CHECK: return %2, %5 : tensor<1x2xf32>, tensor<1x2xf32>
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}

// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@ -186,45 +137,6 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %

// -----

// CHECK-LABEL: @div_broadcast
func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)

// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.multiply"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]

// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
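// (Editor's note, not part of the original test: taken together, the checks
// in this function verify the standard identity
// (a+bi) / (c+di) = ((a*c + b*d) + (b*c - a*d)i) / (c*c + d*d),
// with the negated rhs.imag from [[VAL0]] supplying the conjugate terms.)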
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]

// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.multiply"(%arg1, %arg2)
// CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.multiply"(%arg0, [[VAL0]])
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]

// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.divide"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.divide"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.divide"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)

%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)

// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}

// -----

// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@ -1,225 +1,5 @@
// RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s

// CHECK-LABEL: @addBroadcastRhs
func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @addBroadcastLhs
func @addBroadcastLhs(%arg0: tensor<4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @addBroadcastEqual
func @addBroadcastEqual(%arg0: tensor<4x1xf32>, %arg1: tensor<1x4xf32>) -> tensor<4x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<4x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}

// -----

// CHECK-LABEL: @addBroadcastMultidimension
func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x1x4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>, tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
return %0 : tensor<1x1x4xf32>
}

// -----

// CHECK-LABEL: @addBroadcastBothArgs
func @addBroadcastBothArgs(%arg0: tensor<1x2xf32>, %arg1: tensor<3x2x1xf32>) -> tensor<3x2x2xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x2x2xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x1xf32>) -> tensor<3x2x2xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<3x2x2xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>, tensor<3x2x1xf32>) -> tensor<3x2x2xf32>
return %0 : tensor<3x2x2xf32>
}

// -----

// CHECK-LABEL: @addBroadcastScalar
func @addBroadcastScalar(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

// -----

// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

// -----

// CHECK-LABEL: @addUnranked
func @addUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<*xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @atan2BroadcastRhs
func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.atan2 %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.atan2"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @divBroadcastRhs
func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.divide %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @maxBroadcastRhs
func @maxBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.maximum %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.maximum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @minBroadcastRhs
func @minBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.minimum %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.minimum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @mulBroadcastRhs
func @mulBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.multiply %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @powBroadcastRhs
func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.power %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.power"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @remainderBroadcastRhs
func @remainderBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.remainder %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @shiftLeftBroadcastRhs
func @shiftLeftBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_left %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_left"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @shiftRightArithmeticBroadcastRhs
func @shiftRightArithmeticBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_arithmetic %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @shiftRightLogicalBroadcastRhs
func @shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_logical %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_right_logical"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @subBroadcastRhs
func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.subtract %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @andBroadcastRhs
func @andBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.and %arg0, %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}

// -----

// CHECK-LABEL: @orBroadcastRhs
func @orBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.or %arg0, %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}

// -----

// CHECK-LABEL: @xorBroadcastRhs
func @xorBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.xor %arg0, %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.xor"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}

// -----

// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
@ -229,63 +9,3 @@ func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>
%0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

// -----

// CHECK-LABEL: @compareBroadcastRhs
func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xi1> {
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = "xla_hlo.compare"(%arg0, %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>
return %0 : tensor<1x4xi1>
}

// -----

// CHECK-LABEL: @dynamicCompareBroadcastRhs
func @dynamicCompareBroadcastRhs(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xi1> {
// CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32>
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor<?x?xf32>
// CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor<?xf32>
// CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index
// CHECK-NEXT: %[[DIM1:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index
// CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex>
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: "xla_hlo.compare"(%[[BROADCAST0]], %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
}

// -----

// CHECK-LABEL: @dynamicBroadcastAdd
func @dynamicBroadcastAdd(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32>
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor<?x?xf32>
// CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor<?xf32>
// CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index
// CHECK-NEXT: %[[DIM1:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index
// CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex>
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<?x?xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: @dynamicBroadcastAddScalar
func @dynamicBroadcastAddScalar(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?x?xf32> {
// CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32>
// CHECK-NEXT: %[[DIM1:.*]] = dim %arg0, 1 : tensor<?x?xf32>
// CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex>
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<?x?xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@ -1,4 +1,4 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure

// CHECK: HloModule
func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token {
@ -96,34 +96,6 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor

// -----

// CHECK: HloModule
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
// Same rank degenerate broadcast
// CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0)
// CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]])
// CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]])
// CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1)
// CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]])
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>

// Broadcast up rank
// CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2}
// CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2)
// CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]])
%1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>

// Broadcast up rank + degenerate broadcast
// CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2}
// CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]])
// CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2}
// CHECK: ROOT
// CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]])
%2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
return %2 : tensor<2x3x4xi32>
}

// -----

// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>

@ -1,4 +1,4 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s --dump-input-on-failure

HloModule main

@ -20,29 +20,6 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}

// This test is more thorough than those of the other binary ops to test
// their shared functionality.
// CHECK-LABEL: func @test_add
%test_add (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
%Arg_3.4 = f32[] parameter(3)

// Add two tensors
// CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"}
%add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)

// Add two scalars
// CHECK-NEXT: xla_hlo.add %arg2, %arg3
%add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4)

// Add a tensor and scalar
// CHECK-NEXT: "xla_hlo.add"(%0, %1)
ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4)
}

// CHECK-LABEL: func @test_after_all
// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token
%test_after_all (token0: token[], token1: token[] ) -> token[] {
@ -159,11 +136,11 @@ add {
}


// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> {
%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] {
// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1> {
%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] {
%Arg_0.1 = f32[3] parameter(0)
%Arg_1.2 = f32[3] parameter(1)
%Arg_2.3 = f32[1] parameter(2)
%Arg_2.3 = f32[3] parameter(2)

// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ
@ -172,7 +149,7 @@ add {
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE

// Requires broadcast of compatible tensors.
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1>
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
}

@ -280,19 +257,19 @@ add {
ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1}
}

// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf64> {
%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] {
// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64> {
%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
%Arg_1.2 = f32[4] parameter(1)

// CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.3 = f64[4] convert(f32[4] %Arg_0.1)

// CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f64>
%convert.4 = f64[] convert(f32[] %Arg_1.2)
// CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.4 = f64[4] convert(f32[4] %Arg_1.2)

// CHECK-NEXT: "xla_hlo.add"(%0, %1)
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4)
// CHECK-NEXT: xla_hlo.add %0, %1
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4)
}

// CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {

@ -163,8 +163,7 @@ struct HloBinaryElementwiseAdaptor {
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
/*broadcast_dimensions=*/nullptr);
broadcasted_lhs, broadcasted_rhs);
}
};
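// (Editor's annotation: the /*broadcast_dimensions=*/nullptr argument is
// removed here because, per this change, xla_hlo binary elementwise ops no
// longer carry an implicit-broadcast attribute; the adaptors build ops from
// operands that have already been broadcast to matching shapes.)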
@ -183,9 +182,9 @@ struct HloCompareAdaptor {
Type result_type, Value broadcasted_lhs,
Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<xla_hlo::CompareOp>(
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
/*broadcast_dimensions=*/nullptr, from_op.comparison_direction());
return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction());
}
};

@ -67,8 +67,9 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF &) {}
explicit LegalizeTF(bool allow_partial_conversion) {
explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) {
allow_partial_conversion_ = allow_partial_conversion;
legalize_chlo_ = legalize_chlo;
}

/// Performs the lowering to XLA dialect.
@ -79,6 +80,11 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
*this, "allow-partial-conversion",
llvm::cl::desc("Allow operations that can't be legalized."),
llvm::cl::init(false)};
Option<bool> legalize_chlo_{
*this, "legalize-chlo",
llvm::cl::desc(
"Also legalizes intermediate chlo ops to hlo (default true)"),
llvm::cl::init(true)};
};

/// Returns true if the given TF data format string is the default format.
@ -362,6 +368,154 @@ static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Gets the resulting type from a broadcast between two types for statically
// shaped types. This is to be used for legacy lowerings that both use
// non-left-padded broadcasting and static shapes. Its use should not be
// permitted in new code.
// May return nullptr on invalid static broadcast dimensions.
// ABSL_DEPRECATED()
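// Illustrative example (editor's annotation, not in the original source):
// GetStaticBroadcastType(tensor<2x3x4xf32>, tensor<3x1xf32>, [1, 2]) yields
// tensor<2x3x4xf32>: out_shape starts as [2, 3, 4], and dimensions 1 and 2
// are max'd against the smaller shape [3, 1].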
static RankedTensorType GetStaticBroadcastType(
RankedTensorType x, RankedTensorType y,
DenseIntElementsAttr broadcast_dimensions_attr) {
auto element_type = x.getElementType();
auto shape_x = x.getShape();
auto shape_y = y.getShape();

if (shape_x.size() == shape_y.size()) {
llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
for (int i = 0; i < shape_x.size(); i++) {
auto x_val = shape_x[i];
auto y_val = shape_y[i];
out_shape[i] = std::max(x_val, y_val);
}
return RankedTensorType::get(out_shape, element_type);
}

auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

llvm::SmallVector<int64_t, 4> broadcast_dimensions;
// Explicit broadcast dimensions.
for (const APInt &int_value : broadcast_dimensions_attr) {
broadcast_dimensions.push_back(int_value.getSExtValue());
}
if (broadcast_dimensions.size() != shape_small.size()) {
return nullptr;
}
llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
shape_large.end());

// Update according to the broadcast dimensions.
for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
auto old_value = out_shape[index_pair.value()];
auto new_value = shape_small[index_pair.index()];
out_shape[index_pair.value()] = std::max(old_value, new_value);
}
return RankedTensorType::get(out_shape, element_type);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Applies static binary broadcasting to a binary elementwise op.
// This is a legacy helper to provide general broadcasting support in legacy,
// static shaped code that relies on non-left-padded broadcasting semantics.
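// Illustrative example (editor's annotation): broadcasting y = tensor<4xf32>
// into x = tensor<4x4xf32> with broadcast_dims = [1] expands y along
// dimension 1 to tensor<4x4xf32> before the binary op is created; x already
// matches the result type and is left untouched.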
template <typename BinaryOp>
static Value StaticBinaryBroadcast(Location loc, Value x, Value y,
DenseIntElementsAttr broadcast_dims,
OpBuilder &builder) {
auto x_type = x.getType().cast<RankedTensorType>();
auto y_type = y.getType().cast<RankedTensorType>();
auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims);
if (!result_type) {
emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type
<< " with broadcast_dims = " << broadcast_dims;
return nullptr;
}
auto larger_broadcast_dims =
GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder);
if (x_type.getRank() < y_type.getRank()) {
if (x_type != result_type) {
x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims);
}
if (y_type != result_type) {
y = builder.create<BroadcastInDimOp>(loc, result_type, y,
larger_broadcast_dims);
}
} else {
if (x_type != result_type) {
x = builder.create<BroadcastInDimOp>(loc, result_type, x,
larger_broadcast_dims);
}
if (y_type != result_type) {
y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims);
}
}
return builder.create<BinaryOp>(loc, x, y);
}

// Gets a 1D tensor type suitable for expressing extents of the given tensor
// value type. If the value type is ranked, the result will be statically
// shaped. Otherwise, it will have a dynamic dimension.
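// For example (editor's annotation): tensor<4x?xf32> maps to tensor<2xindex>,
// while an unranked tensor maps to tensor<?xindex>.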
static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) {
Builder b(value_type.getContext());
int64_t dim = value_type.hasRank() ? value_type.getRank() : -1;
return RankedTensorType::get({dim}, b.getIndexType());
}

// Broadcasts a 'lower_rank_value' to the shape of a 'higher_rank_value'
// by assuming that the shape of the lower ranked is a broadcast compatible
// prefix of the higher ranked.
// Values must be RankedTensorType (this restriction derives from the
// broadcast_dimensions attribute on DynamicBroadcastInDim).
//
// Example:
// CommonPrefixBroadcast(tensor<4x3x256>, tensor<4x3>) will broadcast the
// lower rank value to [4, 3, 256] (i.e. the opposite of numpy-style
// implicit broadcasting).
static Value CommonPrefixBroadcast(Location loc, Value higher_rank_value,
Value lower_rank_value, OpBuilder &builder) {
Value higher_rank_shape =
builder.create<shape::ShapeOfOp>(loc, higher_rank_value);
auto result_extents_type =
GetExtentsTensorTypeFor(higher_rank_value.getType().cast<TensorType>());
Value result_extents = builder.create<shape::ToExtentTensorOp>(
loc, result_extents_type, higher_rank_shape);

auto lower_rank_type = lower_rank_value.getType().cast<RankedTensorType>();
auto lower_rank = lower_rank_type.getRank();
auto prefix_dims = GetI64ElementsAttrForSeq(0, lower_rank, &builder);
return builder.create<DynamicBroadcastInDimOp>(
loc, higher_rank_value.getType(), lower_rank_value, result_extents,
prefix_dims);
}

// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D
// value (broadcast_from) along that feature dimension. This is a shortcut
// for the cases where a 1D tensor must be broadcast along a specific feature
// dimension, which can vary based on data layout, etc.
//
// The extent of `broadcast_from` dim0 must be equal to the extent of the
// feature_dim of `broadcast_to`.
//
// Example:
// [1x2x3x4], [2], 1 -> [1x2x3x4]
// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for
// consistency. Possibly also rename for clarity.
static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to,
Value broadcast_from, int64_t feature_dim,
OpBuilder &builder) {
auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder);
auto to_type = broadcast_to.getType().cast<RankedTensorType>();
auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
auto result_extents_type = GetExtentsTensorTypeFor(to_type);
auto result_extents = builder.create<shape::ToExtentTensorOp>(
loc, result_extents_type, result_shape);
return builder.create<DynamicBroadcastInDimOp>(
loc, to_type, broadcast_from, result_extents, broadcast_dims);
}

// Creates a batch dot using xla_hlo::DotGeneralOp.
Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
bool transpose_rhs, int64_t num_batch_dims,
@ -407,8 +561,7 @@ static void BuildReduceBody(Type element_type, Region *body,

Location loc = body->getLoc();
auto reducer =
builder->create<Op>(loc, block->getArgument(0), block->getArgument(1),
/*broadcast_dimensions=*/nullptr);
builder->create<Op>(loc, block->getArgument(0), block->getArgument(1));
builder->create<ReturnOp>(loc, reducer.getResult());
}

@ -508,8 +661,7 @@ static void CreateWhile32(Location loc, int num_iterations,
loc, builder->getI32IntegerAttr(num_iterations));
StringAttr compare_direction = StringAttr::get("LT", builder->getContext());
Value compare = builder->create<xla_hlo::CompareOp>(
loc, loop_iv, upper_limit,
/*broadcast_dimensions=*/nullptr, compare_direction);
loc, loop_iv, upper_limit, compare_direction);

builder->create<xla_hlo::ReturnOp>(loc, compare);
}
@ -539,9 +691,9 @@ static void CreateWhile32(Location loc, int num_iterations,
// Increment the loop induction variable by one.
auto one =
builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
auto no_broadcast_dims = GetI64ElementsAttr({}, builder);
auto plus_one = builder->create<xla_hlo::AddOp>(loc, old_values[0], one,
no_broadcast_dims);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
auto plus_one = builder->create<xla_chlo::BroadcastAddOp>(
loc, old_values[0], one, scalar_broadcast_dims);
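// (Editor's annotation: the empty broadcast_dimensions attribute marks
// 'one' as a scalar that the chlo broadcast op splats to the shape of the
// induction variable.)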
// Prepend with the updated loop induction variable.
new_values.insert(new_values.begin(), plus_one);

@ -566,21 +718,6 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
}

//===----------------------------------------------------------------------===//
// Bias op utilities.
//===----------------------------------------------------------------------===//

// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd.
// Requires input to have ranked tensor.
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
StringAttr format,
Value input) {
auto inputType = input.getType().cast<RankedTensorType>();
size_t featureDim = GetFeatureDimension(format, inputType);
RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64));
return DenseIntElementsAttr::get(type, featureDim);
}

//===----------------------------------------------------------------------===//
// MatMul op utilities.
//===----------------------------------------------------------------------===//
@ -743,8 +880,7 @@ static void BuildArgMinMaxReductionBody(Type input_element_type,
StringAttr compare_direction =
StringAttr::get(direction, builder->getContext());
Value compare = builder->create<CompareOp>(
loc, block->getArgument(0), block->getArgument(2),
/*broadcast_dimensions=*/nullptr, compare_direction);
loc, block->getArgument(0), block->getArgument(2), compare_direction);

Value selected_input = builder->create<SelectOp>(
loc, input_type, compare, block->getArgument(0), block->getArgument(2));
@ -860,8 +996,7 @@ static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
StringAttr compare_direction =
StringAttr::get(direction, builder->getContext());
Value compare = builder->create<xla_hlo::CompareOp>(
loc, block->getArgument(0), block->getArgument(1),
/*broadcast_dimensions=*/nullptr, compare_direction);
loc, block->getArgument(0), block->getArgument(1), compare_direction);

builder->create<xla_hlo::ReturnOp>(loc, compare);
}
@ -900,6 +1035,27 @@ NamedAttribute GetConvDimensionNumbersAttr(
feature_dim, spatial_dims, builder->getContext()));
}

// Converts a TF::BiasAddOp to HLO.
// This differs from a normal TF::AddOp with respect to how the data_format
// is handled, which can optionally require a general broadcast of the
// 'bias' term in a way that is not compatible with the standard left-padded
// broadcast semantics (i.e. NCHW will broadcast into dimension 1).
// The correct 'bias' broadcast will be synthesized manually.
|
||||
class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(TF::BiasAddOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto feature_dim = GetFeatureDimension(
|
||||
op.data_formatAttr(), op.value().getType().cast<RankedTensorType>());
|
||||
auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
|
||||
feature_dim, rewriter);
|
||||
rewriter.replaceOpWithNewOp<AddOp>(op, op.value(), bias_broadcast);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
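Broadcast1DToFeatureDim is the helper that synthesizes the bias broadcast. A plausible sketch of what it reduces to, assuming it maps the 1-D bias onto the feature dimension via broadcast_in_dim (the helper body below is an assumption for illustration, not the file's actual implementation):

// Sketch (assumption): broadcast a 1-D bias onto `feature_dim` of the
// value's shape, so NCHW broadcasts into dimension 1 and NHWC into the
// last dimension.
static Value BroadcastBiasToFeatureDim(Location loc, Value value, Value bias,
                                       int64_t feature_dim,
                                       PatternRewriter &rewriter) {
  auto value_type = value.getType().cast<RankedTensorType>();
  auto bias_type = bias.getType().cast<RankedTensorType>();
  auto broadcast_type =
      RankedTensorType::get(value_type.getShape(), bias_type.getElementType());
  auto dims = GetI64ElementsAttr({feature_dim}, &rewriter);
  return rewriter.create<BroadcastInDimOp>(loc, broadcast_type, bias, dims);
}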
// Converts the TensorFlow conv op in template to the generic HLO conv op by
// converting TensorFlow op attributes to HLO op attributes.
//
@ -1161,7 +1317,6 @@ class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
        rewriter.getI64IntegerAttr(1));
    Value compare = rewriter.create<CompareOp>(
        op.getLoc(), iota0, iota1,
        /*broadcast_dimensions=*/nullptr,
        StringAttr::get("EQ", rewriter.getContext()));
    Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(),
                                      0, &rewriter);
@ -1274,33 +1429,35 @@ class ConvertFusedBatchNormGradBase
      non_feature_dims.push_back(i);
    }
    auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
    auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter);
    auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
    auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);

    // scratch1 = rsqrt(var + epsilon)
    RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
    auto epsilon = rewriter.create<ConstOp>(
        loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
    auto add_op = rewriter.create<AddOp>(loc, var, epsilon.getResult(),
                                         no_broadcast_dims);
    auto add_op = rewriter.create<xla_chlo::BroadcastAddOp>(
        loc, var, epsilon.getResult(), scalar_broadcast_dims);

    Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);

    // scratch2 = sum(y_backprop * (x - mean))
    auto sub_op = rewriter.create<SubOp>(loc, act, mean, broadcast_dims);
    auto weighted_grad =
        rewriter.create<MulOp>(loc, grad, sub_op, no_broadcast_dims);
    auto sub_op = rewriter.create<xla_hlo::SubOp>(
        loc, act,
        Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
    auto weighted_grad = rewriter.create<xla_hlo::MulOp>(loc, grad, sub_op);
    Value scratch2 =
        ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);

    // x_backprop = y_backprop * (scale * scratch1)
    auto scaled_grad =
        rewriter.create<MulOp>(loc, op.scale(), scratch1, no_broadcast_dims);
    x_backprop =
        rewriter.create<MulOp>(loc, grad, scaled_grad, broadcast_dims);
        rewriter.create<xla_hlo::MulOp>(loc, op.scale(), scratch1);
    x_backprop = rewriter.create<xla_hlo::MulOp>(
        loc, grad,
        Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
                                rewriter));

    // scale_backprop = scratch2 * scratch1
    scale_backprop =
        rewriter.create<MulOp>(loc, scratch1, scratch2, no_broadcast_dims);
    scale_backprop = rewriter.create<xla_hlo::MulOp>(loc, scratch1, scratch2);

    // offset_backprop = sum(y_backprop)
    offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
@ -1396,7 +1553,7 @@ class ConvertFusedBatchNormV3Op
      auto factor_const_op = rewriter.create<xla_hlo::ConstOp>(
          op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));

      Value corrected_variance = rewriter.create<xla_hlo::MulOp>(
      Value corrected_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
          op.getLoc(), batch_variance.getType(), batch_variance,
          factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());

@ -1416,24 +1573,26 @@ class ConvertFusedBatchNormV3Op
          rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));

      // new_running_mean = alpha * old_mean + beta * batch_mean.
      auto alpha_mul_old_mean = rewriter.create<MulOp>(
      auto alpha_mul_old_mean = rewriter.create<xla_chlo::BroadcastMulOp>(
          op.getLoc(), op.mean().getType(), alpha, op.mean(),
          /*broadcast_dimensions=*/DenseIntElementsAttr());
      auto beta_mul_batch_mean = rewriter.create<MulOp>(
      auto beta_mul_batch_mean = rewriter.create<xla_chlo::BroadcastMulOp>(
          op.getLoc(), batch_mean.getType(), beta, batch_mean,
          /*broadcast_dimensions=*/DenseIntElementsAttr());
      batch_mean = rewriter.create<AddOp>(
      batch_mean = rewriter.create<xla_chlo::BroadcastAddOp>(
          op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
          /*broadcast_dimensions=*/DenseIntElementsAttr());

      // new_running_variance = alpha * old_variance + beta * batch_variance.
      auto alpha_mul_old_variance = rewriter.create<MulOp>(
      auto alpha_mul_old_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
          op.getLoc(), op.variance().getType(), alpha, op.variance(),
          /*broadcast_dimensions=*/DenseIntElementsAttr());
      auto beta_mul_batch_variance = rewriter.create<MulOp>(
          op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
          /*broadcast_dimensions=*/DenseIntElementsAttr());
      corrected_variance = rewriter.create<AddOp>(
      auto beta_mul_batch_variance =
          rewriter.create<xla_chlo::BroadcastMulOp>(
              op.getLoc(), corrected_variance.getType(), beta,
              corrected_variance,
              /*broadcast_dimensions=*/DenseIntElementsAttr());
      corrected_variance = rewriter.create<xla_chlo::BroadcastAddOp>(
          op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
          /*broadcast_dimensions=*/DenseIntElementsAttr());
    }
@ -1586,10 +1745,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
    // Divide by the number of elements in the window.
    Value divisor =
        GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
    auto batch_dims =
        GetI64ElementsAttrForSeq(0, input_type.getRank(), &rewriter);
    Value result = rewriter.create<DivOp>(op.getLoc(), result_type, reduce,
                                          divisor, batch_dims);
    auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
    Value result = rewriter.create<xla_chlo::BroadcastDivOp>(
        op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims);

    // Convert back if we enlarged the element type's bitwidth.
    if (input_element_type != sum_element_type)
@ -1759,16 +1917,14 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
        op.getLoc(), type, scalar_one,
        GetI64ElementsAttr(type.getShape(), &rewriter));

    auto scaled_input = rewriter.create<MulOp>(
        op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
    auto scaled_input =
        rewriter.create<xla_hlo::MulOp>(op.getLoc(), operand, constant_ones);
    auto tanh_op =
        rewriter.create<TanhOp>(op.getLoc(), operand.getType(), scaled_input);
    auto mul_op =
        rewriter.create<MulOp>(op.getLoc(), tanh_op, constant_ones,
                               /*DenseIntElementsAttr=*/DenseIntElementsAttr());
        rewriter.create<xla_hlo::MulOp>(op.getLoc(), tanh_op, constant_ones);
    auto add_op =
        rewriter.create<AddOp>(op.getLoc(), mul_op, constant_ones,
                               /*DenseIntElementsAttr=*/DenseIntElementsAttr());
        rewriter.create<xla_hlo::AddOp>(op.getLoc(), mul_op, constant_ones);

    rewriter.replaceOp(op, add_op.getResult());
    return success();
@ -1807,20 +1963,18 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Value logits = op.logits();

    // The softmax converter requires a ranked type because the XLA reduce
    // ops used while lowering require a dimensions attribute to reduce along.
    // Note that the input and output shapes are equivalent, so we use 'logits'
    // and its type for shape calculations.
    Value logits = op.logits();
    RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>();
    if (!type) return failure();

    auto loc = op.getLoc();
    int rank = type.getRank();

    // Note that the TensorFlow Softmax op verifies that the input rank is
    // greater than or equal to one so both of the following sequences are
    // valid.
    auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter);
    // greater than or equal to one so the following sequence is valid.
    auto reduce_dim = rewriter.create<TF::ConstOp>(
        loc, GetI64ElementsAttr({rank - 1}, &rewriter));

@ -1833,8 +1987,10 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
    auto max_logits =
        rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
                                   /*keep_dims=*/rewriter.getBoolAttr(false));
    auto shifted_logits =
        rewriter.create<SubOp>(loc, type, logits, max_logits, batch_dims);
    auto max_logits_broadcast =
        CommonPrefixBroadcast(loc, logits, max_logits, rewriter);
    auto shifted_logits = rewriter.create<xla_hlo::SubOp>(loc, type, logits,
                                                          max_logits_broadcast);

    // Exponentiate the inputs.
    Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits);
@ -1847,9 +2003,12 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {

    if (use_log) {
      Value log = rewriter.create<LogOp>(loc, sum);
      rewriter.replaceOpWithNewOp<SubOp>(op, shifted_logits, log, batch_dims);
      auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter);
      rewriter.replaceOpWithNewOp<xla_hlo::SubOp>(op, shifted_logits,
                                                  log_broadcast);
    } else {
      rewriter.replaceOpWithNewOp<DivOp>(op, exp, sum, batch_dims);
      auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter);
      rewriter.replaceOpWithNewOp<xla_hlo::DivOp>(op, exp, sum_broadcast);
    }
    return success();
  }
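CommonPrefixBroadcast is assumed here to broadcast the reduced value back over the leading dimensions it shares with the full-rank input. A sketch under that assumption (the helper body is illustrative, not the file's actual implementation):

// Sketch (assumption): broadcast a rank-(N-1) reduction result back to the
// shape of `full`, mapping its dimensions onto the shared leading prefix
// [0, N-2] of the result.
static Value BroadcastToCommonPrefix(Location loc, Value full, Value reduced,
                                     PatternRewriter &rewriter) {
  auto full_type = full.getType().cast<RankedTensorType>();
  auto reduced_type = reduced.getType().cast<RankedTensorType>();
  auto result_type = RankedTensorType::get(full_type.getShape(),
                                           reduced_type.getElementType());
  auto dims = GetI64ElementsAttrForSeq(0, reduced_type.getRank(), &rewriter);
  return rewriter.create<BroadcastInDimOp>(loc, result_type, reduced, dims);
}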
@ -1896,7 +2055,7 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
      auto dim = rewriter.create<GetDimensionSizeOp>(
          op.getLoc(), result_type, input,
          rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
      size = rewriter.create<MulOp>(
      size = rewriter.create<xla_chlo::BroadcastMulOp>(
          op.getLoc(), size->getResult(0), dim.getResult(),
          /*DenseIntElementsAttr=*/DenseIntElementsAttr());
    }
@ -2582,10 +2741,10 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {

    auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
                                        rewriter.getI64IntegerAttr(0));
    auto scaled = rewriter.create<MulOp>(
    auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
        op.getLoc(), result_type, iota, op.delta(),
        xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
    rewriter.replaceOpWithNewOp<AddOp>(
    rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
        op, result_type, scaled, op.start(),
        xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
    return success();
@ -2633,7 +2792,7 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
    int64_t num = (*num_attr.begin()).getSExtValue();

    // Calculate the scaling that needs to be applied to the iota.
    auto step_numerator = rewriter.create<SubOp>(
    auto step_numerator = rewriter.create<xla_chlo::BroadcastSubOp>(
        op.getLoc(), op.start().getType(), op.stop(), op.start(),
        xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
    Value step_denominator = rewriter.create<ConvertOp>(
@ -2641,11 +2800,11 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
    if (num > 1) {
      Value one = GetScalarConstOfType(result_type.getElementType(),
                                       op.getLoc(), 1, &rewriter);
      step_denominator = rewriter.create<SubOp>(
      step_denominator = rewriter.create<xla_chlo::BroadcastSubOp>(
          op.getLoc(), step_denominator.getType(), step_denominator, one,
          xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
    }
    auto step = rewriter.create<DivOp>(
    auto step = rewriter.create<xla_chlo::BroadcastDivOp>(
        op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
        xla::getBroadcastDimensionsAttr(&rewriter, step_numerator,
                                        step_denominator));
@ -2653,10 +2812,10 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
    // Scale the iota and add the offset.
    auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
                                        rewriter.getI64IntegerAttr(0));
    auto scaled = rewriter.create<MulOp>(
    auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
        op.getLoc(), result_type, iota, step,
        xla::getBroadcastDimensionsAttr(&rewriter, iota, step));
    rewriter.replaceOpWithNewOp<AddOp>(
    rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
        op, result_type, scaled, op.start(),
        xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
    return success();
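Both ConvertRangeOp and ConvertLinSpaceOp rely on xla::getBroadcastDimensionsAttr to pick the broadcast mapping. Its semantics follow the standard right-aligned numpy rule; a sketch of that computation (the helper below is a hypothetical re-derivation for illustration, not the library function itself):

// Sketch (assumption): right-aligned, numpy-style broadcast dimensions for
// the lower-ranked operand. For ranks 4 and 1 this yields [3]; for equal
// ranks the mapping is empty.
static DenseIntElementsAttr GetNumpyBroadcastDims(Builder *b, int64_t max_rank,
                                                  int64_t min_rank) {
  SmallVector<int64_t, 4> dims;
  for (int64_t i = max_rank - min_rank; i < max_rank; ++i) dims.push_back(i);
  auto type = RankedTensorType::get({static_cast<int64_t>(dims.size())},
                                    b->getIntegerType(64));
  return DenseIntElementsAttr::get(type, dims);
}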
@ -2732,8 +2891,8 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
      auto divisor = GetScalarConstOfType(reduce_element_type, loc,
                                          divisor_count, &rewriter);
      auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
      result = rewriter.create<DivOp>(loc, result, divisor.getResult(),
                                      broadcast_dims);
      result = rewriter.create<xla_chlo::BroadcastDivOp>(
          loc, result, divisor.getResult(), broadcast_dims);
    }

    result = rewriter.create<ConvertOp>(loc, result, element_type);
@ -3118,7 +3277,6 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {

      auto reducer = rewriter.create<CompareOp>(
          loc, block->getArgument(0), block->getArgument(1),
          /*broadcast_dimensions=*/nullptr,
          StringAttr::get("GE", rewriter.getContext()));
      rewriter.create<ReturnOp>(loc, reducer.getResult());
    }
@ -3544,13 +3702,20 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
    output_dims.insert(output_dims.begin() + axis, depth);

    Location loc = op.getLoc();

    // The iota result is the effective output shape of the computation,
    // and indices must be broadcast into it. At this point, this computation
    // would need to be reworked quite a bit to support dynamic shapes, so
    // just using static broadcasting.
    auto index_type = RankedTensorType::get(output_dims, element_type);
    Value compare = rewriter.create<CompareOp>(
        loc, op.indices(),
        rewriter.create<IotaOp>(
            loc, index_type,
            IntegerAttr::get(rewriter.getIntegerType(64), axis)),
        GetI64ElementsAttr(broadcast_dims, &rewriter),
    auto iota = rewriter.create<IotaOp>(
        loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis));
    auto broadcast_indices = rewriter.create<BroadcastInDimOp>(
        loc, index_type, op.indices(),
        GetI64ElementsAttr(broadcast_dims, &rewriter));

    Value compare = rewriter.create<xla_hlo::CompareOp>(
        loc, broadcast_indices, iota,
        StringAttr::get("EQ", rewriter.getContext()));
    Value on_value = rewriter.create<BroadcastOp>(
        loc, op.getType(), op.on_value(),
@ -4396,7 +4561,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
        rewriter.getI64IntegerAttr(1));
    Value compare = rewriter.create<CompareOp>(
        op.getLoc(), iota0, iota1,
        /*broadcast_dimensions=*/nullptr,
        StringAttr::get("EQ", rewriter.getContext()));
    Value identity_matrix =
        rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
@ -4430,8 +4594,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
                          batch_dims.size(), precision_config, &rewriter);
      a_update = BatchDot(op.getLoc(), y, false, a_update, false,
                          batch_dims.size(), precision_config, &rewriter);
      a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update,
                                       /*broadcast_dimensions=*/nullptr);
      a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update);
      a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k},
                                 &rewriter);

@ -4442,8 +4605,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
                          batch_dims.size(), precision_config, &rewriter);
      q_update = BatchDot(op.getLoc(), q_update, false, y, true,
                          batch_dims.size(), precision_config, &rewriter);
      q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update,
                                       /*broadcast_dimensions=*/nullptr);
      q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update);
      q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter);
    }
    // full_matrices is false when only a partial result is needed. Slice to the
@ -4505,34 +4667,31 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
    Value iota = builder->create<IotaOp>(
        loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
        builder->getI64IntegerAttr(0));
    Value gtk = builder->create<CompareOp>(
    Value gtk = builder->create<xla_chlo::BroadcastCompareOp>(
        loc, iota, k, GetI64ElementsAttr({}, builder),
        StringAttr::get("GT", builder->getContext()));
    gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
    Value x_after_k = builder->create<MulOp>(
    Value x_after_k = builder->create<xla_chlo::BroadcastMulOp>(
        loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
    Value x_after_k_sq = builder->create<MulOp>(
        loc, x_after_k, x_after_k, /*broadcast_dimensions=*/nullptr);
    Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
    // sigma = np.dot(x[k+1:], x[k+1:])
    auto sigma = builder->create<ReduceOp>(
        loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder));
    BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder);
    // mu = np.sqrt(x[k]*x[k] + sigma)
    Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha,
                                            /*broadcast_dimensions=*/nullptr);
    Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha);
    Value mu = builder->create<SqrtOp>(
        loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0),
                                    /*broadcast_dimensions=*/nullptr));
        loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));

    Value sigma_is_zero = builder->create<CompareOp>(
    Value sigma_is_zero = builder->create<xla_chlo::BroadcastCompareOp>(
        loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
        StringAttr::get("EQ", builder->getContext()));
    Value alpha_is_negative = builder->create<CompareOp>(
    Value alpha_is_negative = builder->create<xla_chlo::BroadcastCompareOp>(
        loc, alpha, zero, GetI64ElementsAttr({}, builder),
        StringAttr::get("LT", builder->getContext()));
    auto batch_size_one = builder->create<BroadcastOp>(
        loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
    Value signed_mu = builder->create<MulOp>(
    Value signed_mu = builder->create<xla_chlo::BroadcastMulOp>(
        loc,
        builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
                                  batch_size_one,
@ -4541,21 +4700,16 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
    *beta = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
                                      alpha, signed_mu);
    *tau = builder->create<DivOp>(
        loc,
        builder->create<SubOp>(loc, *beta, alpha,
                               /*broadcast_dimensions=*/nullptr),
        *beta,
        /*broadcast_dimensions=*/nullptr);
        loc, builder->create<SubOp>(loc, *beta, alpha), *beta);
    Value zero_tau = builder->create<BroadcastOp>(
        loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder));
    *tau = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
                                     zero_tau, *tau);
    Value divisor = builder->create<SubOp>(loc, alpha, *beta,
                                           /*broadcast_dimensions=*/nullptr);
    Value divisor = builder->create<SubOp>(loc, alpha, *beta);
    divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
                                        batch_size_one, divisor);

    Value eqk = builder->create<CompareOp>(
    Value eqk = builder->create<xla_chlo::BroadcastCompareOp>(
        loc, iota, k, GetI64ElementsAttr({}, builder),
        StringAttr::get("EQ", builder->getContext()));
    eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
@ -4568,10 +4722,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {

    // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
    // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
    *v = builder->create<AddOp>(
    // Note that the add performs a degenerate broadcast.
    *v = builder->create<xla_chlo::BroadcastAddOp>(
        loc, e_k,
        builder->create<DivOp>(loc, x_after_k, divisor,
                               GetI64ElementsAttr(batch_dim_ids, builder)),
        StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
                                     GetI64ElementsAttr(batch_dim_ids, builder),
                                     *builder),
        /*broadcast_dimensions=*/nullptr);
  }

@ -4645,10 +4801,10 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
                     precision, builder);
      vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
                     precision, builder);
      auto tau_x_vva = builder->create<MulOp>(
          loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder));
      a = builder->create<SubOp>(loc, a, tau_x_vva,
                                 /*broadcast_dimensions=*/nullptr);
      auto tau_x_vva = StaticBinaryBroadcast<xla_hlo::MulOp>(
          loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
          *builder);
      a = builder->create<SubOp>(loc, a, tau_x_vva);

      // It is more precise to populate column 'k' explicitly, rather than
      // computing it implicitly by applying the Householder transformation.
@ -4657,12 +4813,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
      auto iota = builder->create<IotaOp>(
          loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
          builder->getI64IntegerAttr(0));
      Value predecessor_mask = builder->create<CompareOp>(
      Value predecessor_mask = builder->create<xla_chlo::BroadcastCompareOp>(
          loc, iota, j, GetI64ElementsAttr({}, builder),
          StringAttr::get("LT", builder->getContext()));
      predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
                                                    a_type.getElementType());
      Value mask = builder->create<CompareOp>(
      Value mask = builder->create<xla_chlo::BroadcastCompareOp>(
          loc, iota, j, GetI64ElementsAttr({}, builder),
          StringAttr::get("EQ", builder->getContext()));
      mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
@ -4674,14 +4830,14 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
          mask,
          GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
                             builder));
      Value predecessor_masked_x = builder->create<MulOp>(
      Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>(
          loc, x, predecessor_mask,
          GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder));
      Value masked_beta = builder->create<MulOp>(
          loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder));
          GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder);
      Value masked_beta = StaticBinaryBroadcast<MulOp>(
          loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder),
          *builder);
      Value new_x =
          builder->create<AddOp>(loc, predecessor_masked_x, masked_beta,
                                 /*broadcast_dimensions=*/nullptr);
          builder->create<AddOp>(loc, predecessor_masked_x, masked_beta);
      // Update a[:,j]
      llvm::SmallVector<int64_t, 4> dim_ids(num_dims);
      std::iota(dim_ids.begin(), dim_ids.end(), 0);
@ -4692,7 +4848,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
          loc,
          RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
          builder->getI64IntegerAttr(minor_dim + 1));
      Value xa_mask = builder->create<CompareOp>(
      Value xa_mask = builder->create<xla_chlo::BroadcastCompareOp>(
          loc, iota_mn, j, GetI64ElementsAttr({}, builder),
          StringAttr::get("EQ", builder->getContext()));
      a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
@ -4708,11 +4864,11 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
                             builder));
      auto vs_update = builder->create<SelectOp>(
          loc, vs.getType(), xa_mask,
          builder->create<AddOp>(
              loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder)),
          StaticBinaryBroadcast<AddOp>(
              loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
              *builder),
          vs_zeros);
      vs = builder->create<AddOp>(loc, vs, vs_update,
                                  /*broadcast_dimensions=*/nullptr);
      vs = builder->create<AddOp>(loc, vs, vs_update);

      // taus[j] = tau
      llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size());
@ -4729,17 +4885,16 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
          loc, taus.getType(), taus_zeros,
          GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
                             builder));
      Value taus_mask = builder->create<CompareOp>(
      Value taus_mask = builder->create<xla_chlo::BroadcastCompareOp>(
          loc, iota_n, j, GetI64ElementsAttr({}, builder),
          StringAttr::get("EQ", builder->getContext()));
      auto taus_update = builder->create<SelectOp>(
          loc, taus.getType(), taus_mask,
          builder->create<AddOp>(
          StaticBinaryBroadcast<AddOp>(
              loc, taus_zeros, tau,
              GetI64ElementsAttr(tau_broadcast_dims, builder)),
              GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
          taus_zeros);
      taus = builder->create<AddOp>(loc, taus, taus_update,
                                    /*broadcast_dimensions=*/nullptr);
      taus = builder->create<AddOp>(loc, taus, taus_update);
      new_values->assign({a, vs, taus});
    };

@ -4796,8 +4951,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
      j = builder->create<AddOp>(
          loc, j,
          GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1,
                               builder),
          /*broadcast_dimensions=*/nullptr);
                               builder));
      // vs has shape [..., m, 1]
      auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder);
      // beta has shape [..., 1]
@ -4816,7 +4970,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
          loc, vs.getType(), zero,
          GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
                             builder));
      auto compare = builder->create<CompareOp>(
      auto compare = builder->create<xla_chlo::BroadcastCompareOp>(
          loc, iota_mn, j, GetI64ElementsAttr({}, builder),
          StringAttr::get("GE", builder->getContext()));
      auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
@ -4831,13 +4985,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {

      // z = -beta * (v + wyv)
      auto neg_beta = builder->create<NegOp>(loc, beta);
      auto v_wyv = builder->create<AddOp>(loc, v, wyv,
                                          /*broadcast_dimensions=*/nullptr);
      auto v_wyv = builder->create<AddOp>(loc, v, wyv);
      auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
      beta_broadcast_dims.push_back(n_index);
      auto z = builder->create<MulOp>(
      auto z = StaticBinaryBroadcast<MulOp>(
          loc, neg_beta, v_wyv,
          GetI64ElementsAttr(beta_broadcast_dims, builder));
          GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter);

      w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder);
      new_values->assign({w, vs, taus});
@ -4855,8 +5008,9 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
    auto neg_beta = rewriter->create<NegOp>(loc, beta);
    auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
    beta_broadcast_dims.push_back(n_index);
    auto bv = rewriter->create<MulOp>(
        loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter));
    auto bv = StaticBinaryBroadcast<MulOp>(
        loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter),
        *rewriter);
    w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter);

    SmallVector<Value, 4> while_output;
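The QR lowering above repeatedly uses StaticBinaryBroadcast in place of ops that previously carried a broadcast_dimensions attribute. A sketch of what such a helper can look like, assuming it materializes one operand to the statically known shape of the other and then emits a broadcast-free op (the body below is an assumption, not this change's actual implementation):

// Sketch (assumption): broadcast `rhs` to the static shape of `lhs` using
// the given dimension mapping, then build the binary op with no implicit
// broadcast remaining.
template <typename BinaryOp>
static Value StaticBinaryBroadcastSketch(Location loc, Value lhs, Value rhs,
                                         DenseIntElementsAttr broadcast_dims,
                                         OpBuilder &builder) {
  auto lhs_type = lhs.getType().cast<RankedTensorType>();
  auto rhs_type = rhs.getType().cast<RankedTensorType>();
  auto broadcast_type =
      RankedTensorType::get(lhs_type.getShape(), rhs_type.getElementType());
  Value broadcasted_rhs = builder.create<BroadcastInDimOp>(
      loc, broadcast_type, rhs, broadcast_dims);
  return builder.create<BinaryOp>(loc, lhs, broadcasted_rhs);
}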
@ -4912,7 +5066,8 @@ void EmitLegalizationErrors(Operation *op,

// Performs the lowering to XLA dialect.
void LegalizeTF::runOnFunction() {
  if (failed(legalizeTF(getFunction(), allow_partial_conversion_)))
  if (failed(
          legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_)))
    signalPassFailure();
}

@ -4923,7 +5078,8 @@ static PassRegistration<LegalizeTF> pass(

#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"

LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
                         bool legalize_chlo) {
  MLIRContext *context = op->getContext();

  // Add lowering patterns to the list.
@ -4936,19 +5092,19 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
  TF::PopulateLoweringTFPatterns(context, &patterns);
  patterns.insert<
      ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
      ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2DOp,
      ConvertConv3DOp, ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp,
      ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp,
      ConvertConv3DBackpropInputOp, ConvertCumsumOp, ConvertDiagPartOp,
      ConvertEinsumOp, ConvertFusedBatchNormGradOp,
      ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
      ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
      ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
      ConvertAvgPoolOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
      ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
      ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
      ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
      ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
      ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
      ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
      ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
      ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
      ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
      ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
      ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
      ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
      ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op,
      ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
      ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
      ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
      ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
@ -4959,10 +5115,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {

  // Populate with CHLO->HLO lowerings to account for TF ops legalized to
  // CHLO first.
  xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
  if (legalize_chlo) {
    xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
  }

  ConversionTarget target(*context);
  target.addIllegalDialect<xla_chlo::XlaHloClientDialect>();
  if (legalize_chlo) {
    target.addIllegalDialect<xla_chlo::XlaHloClientDialect>();
  } else {
    target.addLegalDialect<xla_chlo::XlaHloClientDialect>();
  }
  target.addLegalDialect<XlaHloDialect>();
  target.addLegalDialect<StandardOpsDialect>();
  target.addLegalDialect<shape::ShapeDialect>();
@ -4988,8 +5150,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
}

std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
    bool allow_partial_conversion) {
  return std::make_unique<LegalizeTF>(allow_partial_conversion);
    bool allow_partial_conversion, bool legalize_chlo) {
  return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo);
}

} // end namespace xla_hlo

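With the new signature, callers choose whether client ops are expanded to plain HLO in the same pass or left in the output for a later expansion. A minimal usage sketch (the pipeline function name is illustrative, and the pass-manager wiring is an assumption about the surrounding build code):

// Sketch: registering the pass with the new knob. With legalize_chlo=true
// the xla_chlo dialect is marked illegal and expanded here; with false the
// client ops survive for a downstream pass to lower.
void BuildTfToHloPipeline(mlir::OpPassManager &pm, bool legalize_chlo) {
  pm.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/false, legalize_chlo));
}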
@ -73,21 +73,6 @@ def : Pattern<
// HLO and XLA don't support assertions.
def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>;

//===----------------------------------------------------------------------===//
// Bias op patterns.
//===----------------------------------------------------------------------===//
def BiasAddFeatureDimension : NativeCodeCall<
    "getBiasFeatureDimension($_builder, $0, $1)">;

// $input needs to be a ranked tensor to identify index of the feature
// dimension depending on the data_format 'NHWC' or 'NCHW'.
// TODO(laurenzo): This should be converted to do explicit broadcasting since
// it can generate broadcast dimensions that are not compatible with the simple
// xla_chlo.add broadcast_dims.
def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format),
          (HLO_AddOp $input, $bias,
           (BiasAddFeatureDimension $data_format, $input))>;

//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//
@ -114,7 +99,8 @@ foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp],

def LowerRightShiftSigned :
  Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r),
      (HLO_ShiftRightArithmeticOp $l, $r, (BinBroadcastDimensions $l, $r)),
      (HLOClient_BroadcastShiftRightArithmeticOp $l, $r,
       (BinBroadcastDimensions $l, $r)),
      [(SignedIntTensor $r)]>;

// TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op
@ -126,10 +112,11 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
//
//   return floor(div(x, y))
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
          (HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r))),
          (HLO_FloorOp
           (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
          [(IEEEFloatTensor $l)]>;

// Performs a substitution of FloorDir for integer tensors, which required
// Performs a substitution of FloorDiv for integer tensors, which required
// additional correction for a negative numerator / denominator. Equivalent
// pseudocode is shown below:
//
@ -150,16 +137,16 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
// broadcast attributes.
def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
          (HLO_SelectOp
           (HLO_CompareOp
            (HLO_CompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)),
           (HLOClient_BroadcastCompareOp
            (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)),
             (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
            (HLO_CompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)),
            (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)),
             (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
            (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ),
           (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r)),
           (HLO_DivOp
            (HLO_NegOp:$neg (HLO_AddOp (HLO_AbsOp $l),
                             (HLO_SubOp (HLO_AbsOp $r),
           (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)),
           (HLOClient_BroadcastDivOp
            (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l),
                             (HLOClient_BroadcastSubOp (HLO_AbsOp $r),
                              (HLO_ConstOp (ConstantSplat<"1"> $r)),
                              (NullDenseIntElementsAttr)),
                             (BinBroadcastDimensions $l, $r))),
@ -175,20 +162,20 @@ def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
// broadcast attributes.
def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
          (HLO_SelectOp
           (HLO_AndOp
            (HLO_CompareOp
             (HLO_RemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)),
           (HLOClient_BroadcastAndOp
            (HLOClient_BroadcastCompareOp
             (HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)),
             (HLO_ConstOp:$l_zeros (ConstantSplat<"0"> $l)),
             (BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE),
            (HLO_CompareOp
             (HLO_CompareOp:$r_cmp $r,
            (HLOClient_BroadcastCompareOp
             (HLOClient_BroadcastCompareOp:$r_cmp $r,
              (HLO_ConstOp:$r_zeros (ConstantSplat<"0"> $r)),
              (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
             (HLO_CompareOp:$rem_cmp $rem, $r_zeros,
             (HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros,
              (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT),
             (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE),
            (NullDenseIntElementsAttr)),
           (HLO_AddOp $r,
           (HLOClient_BroadcastAddOp $r,
            $rem, (BinBroadcastDimensions $r, $rem)), $rem)>;

//===----------------------------------------------------------------------===//
@ -406,39 +393,36 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_
        (HLO_SelectOp:$num_lower_or_m
         (HLO_CompareOp
          $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)),
          (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT
          HLO_COMPARISON_DIRECTION_LT
         ),
         $m_dim,
         $num_lower
        ),
        (HLO_SelectOp:$num_upper_or_n
         (HLO_CompareOp
          $num_upper, $zero,
          (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT
          $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT
         ),
         $n_dim,
         $num_upper
        ),
        (HLO_SelectOp
         (HLO_AndOp
          (HLO_CompareOp
          (HLOClient_BroadcastCompareOp
           (HLO_NegOp
            (createConvertOp $op, $num_lower_or_m, $input)
           ),
           (HLO_SubOp:$offset
            (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input),
            (NullDenseIntElementsAttr)
            (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input)
           ),
           (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
          ),
          (HLO_CompareOp
          (HLOClient_BroadcastCompareOp
           $offset,
           (createConvertOp
            $op, $num_upper_or_n, $input
           ),
           (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
          ),
          (BinBroadcastDimensions $offset, $input)
         )
        ),
        $input,
        (HLO_ConstOp (ConstantSplat<"0"> $input))
@ -462,8 +446,9 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value),
// TODO(hinsu): Lower unsigned and quantized types after supporting
// them in GetScalarOfType.
def : Pat<(TF_ReluOp AnyRankedTensor:$input),
          (HLO_MaxOp (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input,
           (BinBroadcastDimensions $zero, $input)),
          (HLOClient_BroadcastMaxOp
           (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input,
           (BinBroadcastDimensions $zero, $input)),
          [(TF_SintOrFpTensor $input)]>;

// TODO(hinsu): Lower unsigned and quantized types after supporting
@ -485,7 +470,7 @@ def : Pat<(TF_Relu6Op AnyRankedTensor:$input),
// to create splat tensor of dynamic shape in HLO.
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features),
          (HLO_SelectOp
           (HLO_CompareOp $features,
           (HLOClient_BroadcastCompareOp $features,
            (HLO_ConstOp (GetScalarOfType<0> $features)),
            (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT),
           $gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>;
@ -598,7 +583,6 @@ def : Pat<(TF_SignOp $x),
           (HLO_CompareOp
            $x,
            $x,
            (NullDenseIntElementsAttr),
            HLO_COMPARISON_DIRECTION_NE
           ),
           (HLO_ConstOp (ConstantSplat<"0"> $x)),
@ -641,8 +625,6 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2),
//===----------------------------------------------------------------------===//
def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
          (HLO_MulOp
           (HLO_MulOp $r, $l, (NullDenseIntElementsAttr)),
           (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l,
            (NullDenseIntElementsAttr)),
           (NullDenseIntElementsAttr)),
           (HLO_MulOp $r, $l),
           (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l)),
          [(IEEEFloatTensor $l)]>;

@ -36,47 +36,36 @@ def IsSameSizePred : CPred<
def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;


def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r),
          (AndOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (AddFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (SubFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (MulFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (DivFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (RemFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (AddIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (SubIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (MulIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (SignedDivIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r,
           IsNullAttr:$broadcast_dimensions),
def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (SignedRemIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;

@ -28,70 +28,62 @@ include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
// and imaginary components.
foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in
  def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs,
             HLO_ComplexTensor:$rhs, $broadcast_dimensions),
             HLO_ComplexTensor:$rhs),
            (HLO_ComplexOp
             (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs),
              $broadcast_dimensions),
             (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs),
              $broadcast_dimensions))>;
             (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)),
             (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>;

// Complex multiplication results in a cross product multiplication between the
// real and imaginary components such that:
//   result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
//   result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
           HLO_ComplexTensor:$rhs, $broadcast_dimensions),
           HLO_ComplexTensor:$rhs),
          (HLO_ComplexOp
           (HLO_SubOp
            (HLO_MulOp
             (HLO_RealOp:$lhs_real $lhs),
             (HLO_RealOp:$rhs_real $rhs),
             $broadcast_dimensions),
             (HLO_RealOp:$rhs_real $rhs)),
            (HLO_MulOp
             (HLO_ImagOp:$lhs_imag $lhs),
             (HLO_ImagOp:$rhs_imag $rhs),
             $broadcast_dimensions),
            (NullDenseIntElementsAttr)),
             (HLO_ImagOp:$rhs_imag $rhs))),
           (HLO_AddOp
            (HLO_MulOp $lhs_real, $rhs_imag, $broadcast_dimensions),
            (HLO_MulOp $lhs_imag, $rhs_real, $broadcast_dimensions),
            (NullDenseIntElementsAttr)))>;
            (HLO_MulOp $lhs_real, $rhs_imag),
            (HLO_MulOp $lhs_imag, $rhs_real)))>;

// Multiplication between a complex and real tensor can be distributed by
// applying the real multiplicand to both the real and imaginary components.
//
// Note that the source pattern is not legal according to the HLO dialect but
// instead handles intermediates generated by other patterns.
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions),
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
          (HLO_ComplexOp
           (HLO_MulOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions),
           (HLO_MulOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>;
           (HLO_MulOp (HLO_RealOp $lhs), $rhs),
           (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>;

def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions),
def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs),
          (HLO_ComplexOp
           (HLO_MulOp $lhs, (HLO_RealOp $rhs), $broadcast_dimensions),
           (HLO_MulOp $lhs, (HLO_ImagOp $rhs), $broadcast_dimensions))>;
           (HLO_MulOp $lhs, (HLO_RealOp $rhs)),
           (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>;


// Division is performed by normalizing the denominator by multiplying by the
// conjugate of the rhs.
//   numerator = lhs * conj(rhs)
//   denominator = rhs * conj(rhs)
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions),
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
          (HLO_DivOp
           (HLO_MulOp:$num $lhs,
            (HLO_ComplexOp:$conj
             (HLO_RealOp $rhs),
             (HLO_NegOp (HLO_ImagOp $rhs))),
            $broadcast_dimensions),
           (HLO_RealOp:$den (HLO_MulOp $rhs, $conj, $broadcast_dimensions)),
           (BinBroadcastDimensions $num, $den))>;
             (HLO_NegOp (HLO_ImagOp $rhs)))),
           (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>;


def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions),
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
          (HLO_ComplexOp
           (HLO_DivOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions),
           (HLO_DivOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>;
           (HLO_DivOp (HLO_RealOp $lhs), $rhs),
           (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>;


// Absolute value is evaluated as:
@ -100,11 +92,8 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
          (HLO_ComplexOp
           (HLO_SqrtOp
            (HLO_AddOp
             (HLO_MulOp (HLO_RealOp:$real $val), $real,
              (NullDenseIntElementsAttr)),
             (HLO_MulOp (HLO_ImagOp:$imag $val), $imag,
              (NullDenseIntElementsAttr)),
             (NullDenseIntElementsAttr))),
             (HLO_MulOp (HLO_RealOp:$real $val), $real),
             (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
           (HLO_ConstOp (ConstantSplat<"0"> $real)))>;

// Exponential can be lowered to an exponential on the real component and a
@ -117,5 +106,4 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
           (HLO_ExpOp (HLO_RealOp $val)),
           (HLO_ComplexOp
            (HLO_CosOp (HLO_ImagOp:$imag $val)),
            (HLO_SinOp $imag)),
           (NullDenseIntElementsAttr))>;
            (HLO_SinOp $imag)))>;

@ -28,259 +28,6 @@ namespace xla_hlo {

namespace {

// Returns a 1-d i64 elements attribute populated with numbers from start to
// end, excluding.
static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
                                                     Builder *builder) {
  int size = end - start;

  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);

  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}

// Helper function for OpRewritePattern classes to materialize broadcasts on
// LHS and RHS arguments to a binary op.
//
// Returns true and sets out_lhs and out_rhs to BroadcastInDimOps if successful,
// returns false otherwise.
template <typename SrcOp>
bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter,
                                       Value *out_lhs, Value *out_rhs) {
  // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args,
  // replacing the original LHS and RHS args in the source op with the results
  // of the broadcasts.
  //
  // If the higher dimensional argument does not actually need the broadcast,
  // a canonicalization pass should be able to remove that op later.
  Value lhs = op.lhs();
  Value rhs = op.rhs();

  auto op_ranked_type = op.getType().template dyn_cast<RankedTensorType>();
  auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
  auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
  if (!op_ranked_type || !lhs_ranked_type || !rhs_ranked_type) {
    // Unranked, can't determine at this point how to perform the broadcast.
    return false;
  }

  // Dynamic result shape, can't use BroadcastInDimOp.
  assert(op_ranked_type.hasStaticShape() &&
         "dynamic shape requires DynamicBroadcastInDim");

  auto lhs_rank = lhs_ranked_type.getRank();
  auto rhs_rank = rhs_ranked_type.getRank();
  ArrayRef<int64_t> op_shape = op_ranked_type.getShape();

  // BroadcastInDimOp must have the same element type for operands and results,
  // so preserve the original output shape and the original input element type.
  // For example, `SrcOp (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`:
  //   broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32>
  //   broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32>
  //   SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
  if (lhs_ranked_type.getShape() != op_ranked_type.getShape()) {
    auto type =
        RankedTensorType::get(op_shape, lhs_ranked_type.getElementType());
    DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, lhs_rank, rewriter);
    if (lhs_rank < rhs_rank) {
      attr = op.broadcast_dimensions().getValue();
    }

    lhs =
        rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, lhs, attr);
  }

  if (rhs_ranked_type.getShape() != op_ranked_type.getShape()) {
    auto type =
        RankedTensorType::get(op_shape, rhs_ranked_type.getElementType());
    DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, rhs_rank, rewriter);
    if (rhs_rank < lhs_rank) {
      attr = op.broadcast_dimensions().getValue();
    }

    rhs =
        rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, rhs, attr);
  }

  *out_lhs = lhs;
  *out_rhs = rhs;
  return true;
}

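For reference, rewrite patterns consumed this now-removed helper roughly as follows (SrcOp and the pattern name are hypothetical, and the exact replacement arguments are an assumption about the old op signatures):

// Sketch (assumption): both operands are materialized to the result shape,
// then the op is recreated with no implicit broadcast left to resolve.
template <typename SrcOp>
struct BinaryOpWithBroadcastConvert : public OpRewritePattern<SrcOp> {
  using OpRewritePattern<SrcOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const override {
    Value new_lhs, new_rhs;
    if (!CreateStaticBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs))
      return failure();
    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), new_lhs, new_rhs,
                                       /*broadcast_dimensions=*/nullptr);
    return success();
  }
};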
// Helper template to generate code for computing the result shape of a
|
||||
// broadcasted operation. This ultimately should be subsumed by functions
|
||||
// from the shape dialect.
|
||||
// Assumes that large and small are the operand values of `op` and that they
|
||||
// have a ranked tensory type with rank(large) >= rank(small).
|
||||
template <typename SrcOp>
|
||||
std::vector<Value> ComputeBroadcastedShape(SrcOp op, Value small, Value large,
|
||||
PatternRewriter *rewriter) {
|
||||
auto loc = op.getLoc();
|
||||
auto larger_ranked_type = large.getType().cast<RankedTensorType>();
|
||||
auto output_rank = larger_ranked_type.getRank();
|
||||
|
||||
constexpr int kExpandShape = -1;
|
||||
|
||||
std::vector<Value> shape_values;
|
||||
shape_values.reserve(output_rank);
|
||||
std::vector<int> indexes(output_rank, kExpandShape);
|
||||
DenseIntElementsAttr broadcast_dimensions =
|
||||
op.broadcast_dimensions().getValue();
|
||||
// Compute a mapping from output dimensions to their corresponding input
|
||||
// dimensions in the smaller ranked operand.
|
||||
for (auto pair : llvm::enumerate(broadcast_dimensions.getIntValues())) {
|
||||
indexes.at(pair.value().getLimitedValue()) = pair.index();
|
||||
}
|
||||
|
||||
// Compute the broadcasted shape of the result using numpy style broadcasting
|
||||
// semantics. The result shape at a position is the shape of the larger
|
||||
// operand at that position if the no dimension of the smaller operand is
|
||||
// mapped to it.
|
||||
// If both operands contribute to an output dimension, their shape has to
|
||||
// either be the same in that dimension or it can be 1, in which case the
|
||||
// shape of the other operand is used.
|
||||
  for (int i = 0; i < output_rank; ++i) {
    if (indexes[i] == kExpandShape) {
      // The smaller shape gets expanded to the larger one in this case.
      shape_values.push_back(rewriter->create<mlir::DimOp>(loc, large, i));
      continue;
    }
    // Compute the result extent depending on whether the larger operand's
    // dimension is 1: if so, use the smaller operand's extent, otherwise the
    // larger one. This does not check that the broadcast operation actually
    // is correct. In particular, we do not check that both extents are the
    // same when neither of them is 1.
    ConstantOp one = rewriter->create<mlir::ConstantOp>(
        loc, rewriter->getIntegerAttr(rewriter->getIndexType(), 1));
    DimOp lrg_dim = rewriter->create<mlir::DimOp>(loc, large, i);
    DimOp sml_dim = rewriter->create<mlir::DimOp>(loc, small, indexes[i]);
    CmpIOp compare =
        rewriter->create<mlir::CmpIOp>(loc, CmpIPredicate::eq, lrg_dim, one);
    shape_values.push_back(
        rewriter->create<mlir::SelectOp>(loc, compare, sml_dim, lrg_dim));
  }

  return shape_values;
}
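
As an illustration, a hedged sketch (MLIR standard dialect, operand names and shapes hypothetical) of the per-dimension IR this helper emits, for one dimension expanded from the smaller operand and one dimension fed by both operands:

// Output dim 0 is not mapped from the smaller operand: take the larger extent.
%d0 = dim %large, 0 : tensor<?x?xf32>
// Output dim 1 is fed by both operands: if the larger extent is 1, take the
// smaller operand's extent, otherwise the larger's.
%c1 = constant 1 : index
%lrg = dim %large, 1 : tensor<?x?xf32>
%sml = dim %small, 0 : tensor<?xf32>
%is_one = cmpi "eq", %lrg, %c1 : index
%d1 = select %is_one, %sml, %lrg : index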

// Helper function for OpRewritePattern classes to materialize dynamic
// broadcasts on LHS and RHS arguments to a binary op.
//
// Returns true and sets out_lhs and out_rhs to the materialized dynamic
// broadcasts for the LHS and RHS arguments; otherwise returns false.
template <typename SrcOp>
bool CreateDynamicBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter,
                                        Value *out_lhs, Value *out_rhs) {
  if (!op.broadcast_dimensions().hasValue()) {
    // Note: the op may still have an implicit broadcast on it, such as
    // for (tensor<1xf32>, tensor<4xf32>).
    return false;
  }

  // Insert DynamicBroadcastInDimOps for the left-hand-side and right-hand-side
  // args, replacing the original LHS and RHS args in the source op with the
  // results of the broadcasts.
  Value lhs = op.lhs();
  Value rhs = op.rhs();

  auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
  auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
  if (!lhs_ranked_type || !rhs_ranked_type) {
    // Unranked, can't determine at this point how to perform the broadcast.
    return false;
  }

  auto lhs_rank = lhs_ranked_type.getRank();
  auto rhs_rank = rhs_ranked_type.getRank();

  // Set broadcast_dimensions to [0, ..., rank - 1] for the higher rank arg.
  // Use the original op.broadcast_dimensions for the lower rank arg.
  auto higher_rank_broadcast_dims =
      GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter);
  DenseIntElementsAttr lhs_broadcast_dims;
  DenseIntElementsAttr rhs_broadcast_dims;
  std::vector<Value> shape_elements;
  if (lhs_rank > rhs_rank) {
    lhs_broadcast_dims = higher_rank_broadcast_dims;
    rhs_broadcast_dims = op.broadcast_dimensions().getValue();
    shape_elements = ComputeBroadcastedShape<SrcOp>(op, rhs, lhs, rewriter);
  } else if (lhs_rank < rhs_rank) {
    lhs_broadcast_dims = op.broadcast_dimensions().getValue();
    rhs_broadcast_dims = higher_rank_broadcast_dims;
    shape_elements = ComputeBroadcastedShape<SrcOp>(op, lhs, rhs, rewriter);
  } else {
    // This shouldn't happen for legal ops. If the broadcast_dimensions
    // attribute is set, the ranks should be different.
    // TODO(scotttodd): Add a custom verification for ops and assert here.
    return false;
  }

  // DynamicBroadcastInDimOp preserves the element type and produces a ranked
  // tensor in which every dimension is dynamic. The rank of the output is the
  // length of the output shape argument.
  SmallVector<int64_t, 4> op_shape(shape_elements.size(),
                                   RankedTensorType::kDynamicSize);
  auto lhs_type =
      RankedTensorType::get(op_shape, lhs_ranked_type.getElementType());
  auto rhs_type =
      RankedTensorType::get(op_shape, rhs_ranked_type.getElementType());

  // We need a way to turn a list of scalars into a vector. While the Standard
  // dialect does not have one, use the XLA_HLO variant.
  int shape_size = shape_elements.size();
  Type shape_element_type = shape_elements.front().getType();
  Value shape_value = rewriter->create<ScalarsToDimensionTensorOp>(
      op.getLoc(), RankedTensorType::get({shape_size}, shape_element_type),
      shape_elements);

  *out_lhs = rewriter->createOrFold<DynamicBroadcastInDimOp>(
      op.getLoc(), lhs_type, lhs, shape_value, lhs_broadcast_dims);
  *out_rhs = rewriter->createOrFold<DynamicBroadcastInDimOp>(
      op.getLoc(), rhs_type, rhs, shape_value, rhs_broadcast_dims);
  return true;
}
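
Putting it together, a hedged sketch of the IR this dynamic path materializes for an add of tensor<?x?xf32> and tensor<?xf32> with broadcast_dimensions = [1] (value names hypothetical; %d0 and %d1 stand for the shape scalars computed by ComputeBroadcastedShape):

%shape = "xla_hlo.scalars_to_dimension_tensor"(%d0, %d1) : (index, index) -> tensor<2xindex>
%l = "xla_hlo.dynamic_broadcast_in_dim"(%lhs, %shape) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%r = "xla_hlo.dynamic_broadcast_in_dim"(%rhs, %shape) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%0 = "xla_hlo.add"(%l, %r) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>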

template <typename SrcOp>
bool CreateBroadcastForBinaryOp(SrcOp op, PatternRewriter *rewriter,
                                Value *out_lhs, Value *out_rhs) {
  auto op_ranked_type = op.getType().template dyn_cast<RankedTensorType>();
  if (!op_ranked_type) return false;

  if (op_ranked_type.hasStaticShape()) {
    if (!CreateStaticBroadcastsForBinaryOp(op, rewriter, out_lhs, out_rhs)) {
      return false;
    }
  } else {
    if (!CreateDynamicBroadcastsForBinaryOp(op, rewriter, out_lhs, out_rhs)) {
      return false;
    }
  }
  return true;
}

template <typename SrcOp>
struct BinaryOpWithBroadcastConvert : public OpRewritePattern<SrcOp> {
  explicit BinaryOpWithBroadcastConvert(MLIRContext *context)
      : OpRewritePattern<SrcOp>(context) {}

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const override {
    Value new_lhs;
    Value new_rhs;

    if (!CreateBroadcastForBinaryOp(op, &rewriter, &new_lhs, &new_rhs))
      return failure();

    // Replace the original op with a new one that uses the new args.
    // New args are broadcasts, so no dims are needed on the replacement op.
    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), new_lhs, new_rhs,
                                       /*broadcast_dims=*/nullptr);
    return success();
  }
};

// Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays
// must be the same shape. Alternatively, as a restricted form of broadcasting,
// min and/or max can be a scalar of type T."
@ -322,63 +69,10 @@ struct ClampWithBroadcastConvert : public OpRewritePattern<ClampOp> {
  }
};
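
For example, a hedged MLIR sketch (value names hypothetical) of a clamp with scalar min/max before and after this pattern:

// Before: scalar min/max, ClampOp's restricted broadcast form.
%0 = "xla_hlo.clamp"(%min, %x, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// After: the scalars are broadcast to the operand's shape.
%bmin = "xla_hlo.broadcast_in_dim"(%min) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32>
%bmax = "xla_hlo.broadcast_in_dim"(%max) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32>
%1 = "xla_hlo.clamp"(%bmin, %x, %bmax) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>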

// Specialized class for CompareOp, as it has an additional builder argument.
struct CompareWithBroadcastConvert : public OpRewritePattern<CompareOp> {
  explicit CompareWithBroadcastConvert(MLIRContext *context)
      : OpRewritePattern<CompareOp>(context) {}

  LogicalResult matchAndRewrite(CompareOp op,
                                PatternRewriter &rewriter) const override {
    Value new_lhs;
    Value new_rhs;

    if (!CreateBroadcastForBinaryOp(op, &rewriter, &new_lhs, &new_rhs))
      return failure();

    rewriter.replaceOpWithNewOp<CompareOp>(op, op.getType(), new_lhs, new_rhs,
                                           /*broadcast_dims=*/nullptr,
                                           op.comparison_direction());
    return success();
  }
};

}  // namespace

void SetupMaterializeBroadcastsLegality(MLIRContext *context,
                                        ConversionTarget *conversionTarget) {
#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType)            \
  conversionTarget->addDynamicallyLegalOp<OpType>([](OpType op) {  \
    if (op.broadcast_dimensions().hasValue()) return false;        \
    auto l = op.lhs().getType().cast<ShapedType>();                \
    auto r = op.rhs().getType().cast<ShapedType>();                \
    if (!l.hasRank() || !r.hasRank()) return false;                \
    return l.getShape() == r.getShape();                           \
  });
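
Under this predicate, for example (a hedged MLIR sketch, value names hypothetical), the first op below stays legal while the second must be rewritten:

// Legal: no broadcast_dimensions, both operands ranked with equal shapes.
%0 = "xla_hlo.add"(%a, %b) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
// Illegal: carries broadcast_dimensions, so the materialization patterns rewrite it.
%1 = "xla_hlo.add"(%a, %c) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>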

  // Binary elementwise ops.
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(DivOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MaxOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MinOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MulOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(PowOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(RemOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftLeftOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightArithmeticOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightLogicalOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(SubOp);

  // Binary logical elementwise ops.
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AndOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OrOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(XorOp);

  // CompareOp.
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(CompareOp);

#undef ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST

  conversionTarget->addDynamicallyLegalOp<ClampOp>([](ClampOp op) {
    return op.max().getType() == op.operand().getType() &&
           op.min().getType() == op.operand().getType();
@ -387,30 +81,10 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context,

void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns) {
  // Binary elementwise ops.
  patterns->insert<BinaryOpWithBroadcastConvert<AddOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<Atan2Op>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<DivOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<MaxOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<MinOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<MulOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<PowOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<RemOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<ShiftLeftOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightArithmeticOp>>(
      context);
  patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightLogicalOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<SubOp>>(context);

  // Binary logical elementwise ops.
  patterns->insert<BinaryOpWithBroadcastConvert<AndOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<OrOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<XorOp>>(context);

  // ClampOp. It can have a restricted form of broadcasting.
  // ClampOp. This op has a special case where it accepts either same-shaped
  // inputs or scalars (a restricted form of broadcasting). This makes the
  // broadcast explicit.
  patterns->insert<ClampWithBroadcastConvert>(context);
  // CompareOp. Note the specialized class instead of using the template.
  patterns->insert<CompareWithBroadcastConvert>(context);
}

} // namespace xla_hlo

@ -36,7 +36,7 @@ namespace xla_hlo {

/// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is
/// false, emits an error if there is any operation that can't be legalized.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
    bool allow_partial_conversion = false);
    bool allow_partial_conversion = false, bool legalize_chlo = true);

/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
/// specified device type.

@ -50,7 +50,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();

/// dialect using the conversion patterns registered by the HLO dialect. When
/// allow_partial_conversion is false, emits an error if there is any operation
/// that can't be legalized.
LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false);
LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false,
                         bool legalize_chlo = true);

/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();

@ -135,8 +135,8 @@ class UnfuseBatchNormInferencePattern
    if (!epsilon) {
      return failure();
    }
    Value stddev = rewriter.create<xla_hlo::AddOp>(
        bn_op.getLoc(), bn_op.variance(), epsilon, /*broadcast_dims=*/nullptr);
    Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
                                                   bn_op.variance(), epsilon);
    stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);

    // Broadcast all terms.
@ -160,13 +160,13 @@ class UnfuseBatchNormInferencePattern
    // Compute:
    //   scale * (input - mean) / stddev + offset
    Value result = rewriter.create<xla_hlo::SubOp>(
        bn_op.getLoc(), bn_op.operand(), broadcast_mean, nullptr);
        bn_op.getLoc(), bn_op.operand(), broadcast_mean);
    result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
                                             broadcast_scale, nullptr);
                                             broadcast_scale);
    result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
                                             broadcast_stddev, nullptr);
    rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result, broadcast_offset,
                                                nullptr);
                                             broadcast_stddev);
    rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
                                                broadcast_offset);

    return success();
  }
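
Restating the expansion above as a formula (standard batch-norm inference, matching the comment in the code):

$$y = \mathit{scale} \cdot \frac{x - \mathit{mean}}{\sqrt{\mathit{variance} + \epsilon}} + \mathit{offset}$$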