From b7735095de23aa2aac940a984a24f25f8c26395c Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Tue, 19 May 2020 11:17:02 -0700
Subject: [PATCH] Remove implicit broadcasting from xla_hlo binary elementwise
 ops.

* Migrates legalize-tf conversions to either:
  * convert through the chlo.broadcast_* ops (majority), or
  * have special-case broadcasting for unsupported or non-optimal broadcast
    modes.
* This was done one by one, QC'ing each, and many bugs, inefficiencies, and
  ambiguous broadcasting modes were corrected.
* Looks like it may be missing a rule for legalizing complex types (will
  check on Monday).
* I considered splitting this up, but it was actually pretty important to
  make the ops more strict to flush out all cases (best done as an atomic
  change).
* Stricter conversions fixed a number of cases where shapes were dropping to
  unranked or unknown (and needn't be).
* With this change, most of the binary ops and many of the resulting tf2xla
  expansions correctly support dynamic shapes via the shape dialect.
* I verified this with the small set of IREE test cases which support dynamic
  shapes, and will expand coverage once this lands.
* There is some test fallout outside of the xla directory that I will fix up
  on Monday.

PiperOrigin-RevId: 312316083
Change-Id: I6d246d80cddb84f2dfd62817c7166f53c1f6cdec
---
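Illustrative note (a sketch, not part of the diff below; %lhs/%rhs are
placeholder names): with this change a broadcasting TF op such as tf.Add no
longer legalizes to an xla_hlo op carrying an implicit broadcast_dimensions
attribute. It legalizes to xla_chlo.broadcast_add and, when legalize-chlo is
enabled, expands through the shape dialect, as in the add_dynamic test in
this patch. Roughly:

  %0 = "tf.Add"(%lhs, %rhs) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>

lowers to something like:

  // Compute the runtime result shape from both operand shapes.
  %lhs_shape = shape.shape_of %lhs
  %rhs_shape = shape.shape_of %rhs
  %result_shape = "shape.broadcast"(%lhs_shape, %rhs_shape)
  %extents = "shape.to_extent_tensor"(%result_shape)
  // Explicitly broadcast both operands, then apply the non-broadcasting op.
  %lhs_bcast = "xla_hlo.dynamic_broadcast_in_dim"(%lhs, %extents)
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
  %rhs_bcast = "xla_hlo.dynamic_broadcast_in_dim"(%rhs, %extents)
      {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
  %sum = xla_hlo.add %lhs_bcast, %rhs_bcast : tensor<?x?xi32>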
"xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } @@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0 : tensor<4x4x4x4xi32> } @@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } @@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -88,7 +88,7 @@ func @sub(%arg0: 
tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + %0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } @@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } @@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) 
{broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor) -> tensor { func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { %0 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %2 = xla_hlo.constant dense<0> : tensor<3xi32> - %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> - %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> %7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %8 = xla_hlo.constant dense<1> : tensor<3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<3xi32> - %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %14 : tensor<2x3xi32> } @@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 %0 = xla_hlo.constant dense<0> : tensor<3xi32> %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %2 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> 
tensor<2x3xi32> + %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> %7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %8 = xla_hlo.constant dense<1> : tensor<2x3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32> - %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32> @@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { } func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> return %2 : tensor<2x3xf16> } @@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, 
tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = 
"xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> { func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> return %1 : tensor<1xi32> } func @relu_unranked(%arg0: tensor) -> tensor { %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %1 : tensor } func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = xla_hlo.constant dense<0> : tensor %1 = xla_hlo.constant dense<6> : tensor - %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> - %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> return %3 : tensor<1xi32> } func @relu6_unranked(%arg0: tensor) -> tensor { %0 = xla_hlo.constant dense<0> : tensor %1 = xla_hlo.constant dense<6> : tensor - %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %3 : tensor } func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor + %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor %2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> %3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return %3 : tensor<4x8xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 50f77cd9c3d..b1cbc41a03e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index f3371989b73..6fd7556084d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -18,12 +18,16 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; //===----------------------------------------------------------------------===// // Binary op patterns. +// Note that these are legalized from chlo.broadcast_* ops, since those are +// semantically compatible with the corresponding TF ops. Depending on +// context, getting to these ops may require some raising. //===----------------------------------------------------------------------===// // Check that two values can be broadcasted together @@ -31,36 +35,45 @@ def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; def AreBroadcastCompatible : Constraint, "types must be broadcastable">; -foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op], - [HLO_DivOp, TF_DivOp], - [HLO_ShiftLeftOp, TF_LeftShiftOp], - [HLO_MaxOp, TF_MaximumOp], - [HLO_MinOp, TF_MinimumOp], - [HLO_MulOp, TF_MulOp], - [HLO_PowOp, TF_PowOp], - [HLO_SubOp, TF_SubOp], - [HLO_Atan2Op, TF_Atan2Op], - [HLO_RemOp, TF_ModOp]] in - def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r), +foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op], + [HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp], + [HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp], + [HLO_MaxOp, HLOClient_BroadcastMaxOp, TF_MaximumOp], + [HLO_MinOp, HLOClient_BroadcastMinOp, TF_MinimumOp], + [HLO_MulOp, HLOClient_BroadcastMulOp, TF_MulOp], + [HLO_PowOp, HLOClient_BroadcastPowOp, TF_PowOp], + [HLO_SubOp, HLOClient_BroadcastSubOp, TF_SubOp], + [HLO_Atan2Op, HLOClient_BroadcastAtan2Op, TF_Atan2Op], + [HLO_RemOp, HLOClient_BroadcastRemOp, TF_ModOp]] in { + def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>; + def : Pat<(fromToBinPair[1] $l, $r, $_), (fromToBinPair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -foreach pair = [[HLO_AndOp, TF_BitwiseAndOp], - [HLO_OrOp, TF_BitwiseOrOp], - [HLO_XorOp, TF_BitwiseXorOp]] in - def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r), +foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_BitwiseAndOp], + [HLO_OrOp, HLOClient_BroadcastOrOp, TF_BitwiseOrOp], + [HLO_XorOp, HLOClient_BroadcastXorOp, TF_BitwiseXorOp]] in { + def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>; + def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -foreach pair = [[HLO_AndOp, TF_LogicalAndOp], - [HLO_OrOp, TF_LogicalOrOp]] in - def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r), +foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_LogicalAndOp], + [HLO_OrOp, 
HLOClient_BroadcastOrOp, TF_LogicalOrOp]] in { + def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>; + def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $_), (pair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), +def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(HLOClient_BroadcastShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), +def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(HLOClient_BroadcastShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), +def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>; +def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; @@ -117,16 +130,23 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim), //===----------------------------------------------------------------------===// // Compare op patterns. +// Note that these are legalized from chlo.broadcast_* ops, since those are +// semantically compatible with the corresponding TF ops. Depending on +// context, getting to these ops may require some raising. //===----------------------------------------------------------------------===// foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ], - [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in - def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), + [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in { + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), [(AreBroadcastCompatible $l, $r)]>; + def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>; +} foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE], [TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT], [TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE], - [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in - def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), + [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in { + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), [(AreBroadcastCompatible $l, $r)]>; + def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>; +} diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc index 5322668aa2e..26db4549a2a 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc @@ -185,6 +185,16 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( // BroadcastCompareOp (has custom type inference due to different result type). 
 //===----------------------------------------------------------------------===//
 
+void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
+                               Value lhs, Value rhs,
+                               DenseIntElementsAttr broadcast_dimensions,
+                               StringAttr comparison_direction) {
+  auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
+                                   builder.getI1Type(), broadcast_dimensions);
+  build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
+        comparison_direction);
+}
+
 LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
     MLIRContext* context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td
index f9672c1a95a..febc99f6b72 100644
--- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td
@@ -360,6 +360,11 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
     HLO_ComparisonDirectionAttr:$comparison_direction
   );
   let results = (outs HLO_PredTensor);
+
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
+    "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
+  >];
 }
 
 #endif  // CHLO_OPS
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index b6036ee2130..03928467cff 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -1401,89 +1401,25 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 namespace {
-// Gets the resulting type from a broadcast between two types.
-static Type GetBroadcastType(Builder* builder, Type x, Type y,
-                             Type element_type,
-                             DenseIntElementsAttr broadcast_dimensions) {
+
+// Updates the element type of a (presumed) tensor type 'x', returning either
+// an UnrankedTensorType or a RankedTensorType with the given element type.
+static Type UpdateResultElementType(Builder* builder, Type x,
+                                    Type element_type) {
   auto x_ranked = x.dyn_cast<RankedTensorType>();
-  auto y_ranked = y.dyn_cast<RankedTensorType>();
-  if (!x_ranked || !y_ranked) {
+  if (!x_ranked) {
     return UnrankedTensorType::get(element_type);
   }
 
   auto shape_x = x_ranked.getShape();
-  auto shape_y = y_ranked.getShape();
-
-  if (shape_x.size() == shape_y.size()) {
-    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
-    for (int i = 0; i < shape_x.size(); i++) {
-      auto x_val = shape_x[i];
-      auto y_val = shape_y[i];
-      if (x_val == -1 || y_val == -1) {
-        out_shape[i] = -1;
-      } else {
-        out_shape[i] = std::max(x_val, y_val);
-      }
-    }
-    return RankedTensorType::get(out_shape, element_type);
-  }
-
-  // Return unranked tensor for invalid broadcast dimensions.
-  if (!broadcast_dimensions) return UnrankedTensorType::get(element_type);
-
-  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
-  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
-
-  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
-                                          shape_large.end());
-
-  // Update according to the broadcast dimensions.
-  for (auto index_pair : llvm::enumerate(broadcast_dimensions.getIntValues())) {
-    auto old_value = out_shape[index_pair.value().getSExtValue()];
-    auto new_value = shape_small[index_pair.index()];
-    if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
-      out_shape[index_pair.value().getSExtValue()] = new_value;
-    }
-  }
-
-  return RankedTensorType::get(out_shape, element_type);
+  return RankedTensorType::get(shape_x, element_type);
 }
 
 }  // namespace
 
-#define BINARY_BUILDER(Op)                                                    \
-  void Op::build(OpBuilder& builder, OperationState& result, Value left,     \
-                 Value right, DenseIntElementsAttr broadcast_dimensions) {   \
-    auto type = GetBroadcastType(&builder, left.getType().cast<TensorType>(), \
-                                 right.getType().cast<TensorType>(),         \
-                                 getElementTypeOrSelf(right.getType()),      \
-                                 broadcast_dimensions);                      \
-    return Op::build(builder, result, type, left, right,                     \
-                     broadcast_dimensions);                                  \
-  }
-
-BINARY_BUILDER(AddOp);
-BINARY_BUILDER(AndOp);
-BINARY_BUILDER(Atan2Op);
-BINARY_BUILDER(DivOp);
-BINARY_BUILDER(MaxOp);
-BINARY_BUILDER(MinOp);
-BINARY_BUILDER(MulOp);
-BINARY_BUILDER(OrOp);
-BINARY_BUILDER(PowOp);
-BINARY_BUILDER(RemOp);
-BINARY_BUILDER(ShiftLeftOp);
-BINARY_BUILDER(ShiftRightArithmeticOp);
-BINARY_BUILDER(ShiftRightLogicalOp);
-BINARY_BUILDER(SubOp);
-BINARY_BUILDER(XorOp);
-
-#undef BINARY_BUILDER
-
 template <typename Op, typename ElementType = Type, typename ValType,
           typename Convert>
 static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
   if (!attrs[0] || !attrs[1]) return {};
-  if (op->broadcast_dimensions().hasValue()) return {};
 
   DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
   DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
@@ -1893,12 +1829,10 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
-                      Value rhs, DenseIntElementsAttr broadcast_dimensions,
-                      StringAttr comparison_direction) {
-  auto new_type = GetBroadcastType(&builder, lhs.getType(), rhs.getType(),
-                                   builder.getI1Type(), broadcast_dimensions);
-  build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
-        comparison_direction);
+                      Value rhs, StringAttr comparison_direction) {
+  auto new_type =
+      UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
+  build(builder, result, new_type, lhs, rhs, comparison_direction);
 }
 
 #define GET_OP_CLASSES
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 5d46140c3ea..99801f1618e 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -241,15 +241,9 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
     HLO_Op<mnemonic, traits> {
   let arguments = (ins
       HLO_Tensor:$lhs,
-      HLO_Tensor:$rhs,
-      OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
+      HLO_Tensor:$rhs
   );
 
-  let builders = [OpBuilder<
-    "OpBuilder &builder, OperationState &result, Value left, Value right, "
-    "DenseIntElementsAttr broadcast_dimensions"
-  >];
-
   let extraClassDeclaration = [{
     static LogicalResult inferReturnTypeComponents(
       MLIRContext* context, Optional<Location> location, ValueRange operands,
@@ -270,15 +264,15 @@
 }
 
 def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
-    [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp {
+    [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AddOp {
   let hasFolder = 1;
 }
 
 def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
 
 def HLO_ComplexOp: HLO_Op<"complex",
-    [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>,
+    [NoSideEffect, SameOperandsAndResultShape]>,
     BASE_HLO_ComplexOp {
   let builders = [OpBuilder<
     "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];
@@ -289,39 +283,39 @@ def HLO_ComplexOp: HLO_Op<"complex",
 }
 
 def HLO_DivOp : HLO_BinaryElementwiseOp<"divide",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp {
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp {
 }
 
 def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum",
-    [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp {
+    [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp {
 }
 
 def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum",
-    [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp {
+    [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp {
 }
 
 def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply",
-    [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp {
+    [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MulOp {
   let hasFolder = 1;
 }
 
 def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
 
 def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
 
 def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
 
 def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightArithmeticOp;
 
 def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightLogicalOp;
 
 def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
-    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp {
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_SubOp {
  let hasFolder = 1;
 }
 
@@ -331,11 +325,11 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
 
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
 class HLO_BinaryLogicalElementwiseOp<string mnemonic> :
-    HLO_BinaryElementwiseOp<mnemonic, [Commutative, NoSideEffect]> {
+    HLO_BinaryElementwiseOp<
+      mnemonic, [Commutative, NoSideEffect, SameOperandsAndResultType]> {
   let arguments = (ins
       HLO_PredOrIntTensor:$lhs,
-      HLO_PredOrIntTensor:$rhs,
-      OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
+      HLO_PredOrIntTensor:$rhs
  );
 }
 
@@ -617,23 +611,18 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
 }
 
 def HLO_CompareOp: HLO_Op<"compare",
-    [NoSideEffect, SameOperandsElementType]>, BASE_HLO_CompareOp {
+    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>,
+    BASE_HLO_CompareOp {
   let arguments = (ins
       HLO_Tensor:$lhs,
       HLO_Tensor:$rhs,
-      OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
      HLO_ComparisonDirectionAttr:$comparison_direction
   );
-  let builders = [OpBuilder<
-    "OpBuilder &builder, OperationState &result, Value left, Value right, "
-    "DenseIntElementsAttr broadcast_dimensions, "
-    "StringAttr comparison_direction"
-  >];
   let results = (outs HLO_PredTensor);
   let builders = [OpBuilder<
     "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
-    "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
+    "StringAttr comparison_direction"
  >];
 }
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 461c357e509..774caab77fb 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -209,7 +209,6 @@ StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
                       shape, builder_));
   auto op = builder_.create<mlir::xla_hlo::CompareOp>(
       loc_, ty, GetValue(lhs), GetValue(rhs),
-      /*broadcast_dimensions=*/mlir::DenseIntElementsAttr(),
      builder_.getStringAttr(ComparisonDirectionToString(direction)));
   return MakeXlaOp(op.getResult());
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir
new file mode 100644
index 00000000000..c114b8c50a5
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir
@@ -0,0 +1,334 @@
+// Note that binary elementwise tests are run with chlo legalization enabled
+// (unlike the rest), since this is the primary use case for such ops and
+// verification of shapes and broadcasts is desired.
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" %s | FileCheck %s --dump-input-on-failure
+
+//===----------------------------------------------------------------------===//
+// Binary op legalizations.
+// Most of these expand from the same pattern. Full semantics are
+// verified for tf.Add; for the rest, only pattern application is verified.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @add
+func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+  // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
+  // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32>
+  // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32>
+  %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  return %1: tensor<2xi32>
+}
+
+// CHECK-LABEL: func @broadcast_add
+// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check
+// patterns unambiguous and more interesting (once broadcastable trait is
+// fixed upstream).
+func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0: tensor<1x2xi32> +} + +// CHECK-LABEL: func @broadcast_multi_dim_add +// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream +// broadcastable bug is fixed (helps make the CHECK matching unambiguous) +func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] + %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + return %0: tensor<4x4x4x4xi32> +} + +// CHECK-LABEL: func @add_dynamic +func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %4, %5 : tensor + %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @div +func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @shift_left +func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @div_unranked +func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { + // CHECK: tf.Div + %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor + return %0: tensor 
+} + +// CHECK-LABEL: func @maximum +func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @minimum +func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @mul +func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @real_div +func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @sub +func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @shift_right +func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @shift_right_unsigned +func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> + return %0 : tensor<4xui8> +} + +// CHECK-LABEL: func @broadcast_shift_right_unsigned +func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> + return %0 : tensor<2x4xui8> +} + +// CHECK-LABEL: func @and +func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: xla_hlo.and + %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @and_unranked +func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { + // CHECK: tf.LogicalAnd + %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @or +func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @bitwise_or +func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @bitwise_and +func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.and + %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: 
tensor<4xi32> +} + +// CHECK-LABEL: func @pow +func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NEXT: xla_hlo.power + %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0: tensor<2xf32> +} + +//===----------------------------------------------------------------------===// +// Equality op legalizations. +// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are +// verified for tf.Equal and pattern application only for tf.NotEqual +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @equal +func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @equal_dynamic +func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @equal_broadcast +func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error +func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_broadcastable +func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @equal_incompatible_shape_dynamic +func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: 
tensor) -> tensor<*xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic +func @equal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @equal_unranked +func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Equal" + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @notequal +func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} + %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +//===----------------------------------------------------------------------===// +// Compare op legalizations. +// These expand from the same pattern. Full semantics are checked for +// tf.Greater. Others just check that the pattern applied. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @greater +func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @broadcast_greater +func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @greater_dynamic +func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor, tensor) -> tensor + 
return %0: tensor +} + +// CHECK-LABEL: func @greater_unranked +func @greater_unranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Greater" + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @greater_equal +func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} + %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @less +func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} + %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @less_equal +func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} + %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index d5440a024ab..bfa96413e7c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1,4 +1,11 @@ -// RUN: tf-opt -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s +// This test runs twice: +// 1. Through FileCheck with chlo legalization disabled, since verifying +// the emitted chlo ops produces more useful tests. +// 2. With chlo legalization enabled, verifying diagnostics to pick up any +// issues with the full lowering (this can catch some broadcasting corner +// cases, which emit a warning). //===----------------------------------------------------------------------===// // BatchNorm op legalizations. 
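// ----------------------------------------------------------------------------
// A note on the rewritten CHECK patterns in the hunks that follow: they all
// verify one expansion shape, in which a broadcasting binary op becomes
// explicit shape computation feeding an xla_hlo.dynamic_broadcast_in_dim of
// each operand. The sketch below is distilled from those CHECK lines and is
// not itself part of the patch; the value names and the rank-1-lhs with
// rank-2-rhs combination are illustrative only.
//
//   %lhs_shape = shape.shape_of %lhs
//   %rhs_shape = shape.shape_of %rhs
//   %result_shape = "shape.broadcast"(%lhs_shape, %rhs_shape)
//   %result_extents = "shape.to_extent_tensor"(%result_shape)
//   // broadcast_dimensions maps each operand dimension to a result dimension
//   // (trailing alignment, so the rank-1 lhs maps to result dimension 1).
//   %lhs_bcast = "xla_hlo.dynamic_broadcast_in_dim"(%lhs, %result_extents)
//       {broadcast_dimensions = dense<1> : tensor<1xi64>}
//   %rhs_bcast = "xla_hlo.dynamic_broadcast_in_dim"(%rhs, %result_extents)
//       {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
//   %result = xla_hlo.add %lhs_bcast, %rhs_bcast
// ----------------------------------------------------------------------------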
@@ -47,7 +54,7 @@ func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32> // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: xla_hlo.constant - // CHECK: "xla_hlo.multiply"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK: xla_chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> return %0#0 : tensor<8x8x8x8xf32> } @@ -68,18 +75,18 @@ func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, // CHECK-DAG: %[[BATCH_VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} // CHECK: %[[FACTOR:.*]] = xla_hlo.constant dense<1.00195694> - // CHECK: %[[CORRECTED_VAR:.*]] = "xla_hlo.multiply"(%[[BATCH_VAR]], %[[FACTOR]]) + // CHECK: %[[CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]] // CHECK-DAG: %[[ALPHA:.*]] = xla_hlo.constant dense<0.199999988> // CHECK-DAG: %[[BETA:.*]] = xla_hlo.constant dense<8.000000e-01> - // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = "xla_hlo.multiply"(%[[ALPHA]], %arg3) - // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = "xla_hlo.multiply"(%[[BETA]], %[[BATCH_MEAN]]) - // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_hlo.add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg3 + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]] + // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] - // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = "xla_hlo.multiply"(%[[ALPHA]], %arg4) - // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = "xla_hlo.multiply"(%[[BETA]], %[[CORRECTED_VAR]]) - // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_hlo.add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg4 + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] + // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]] return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> @@ -127,11 +134,12 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = 
"xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -142,10 +150,10 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -185,11 +193,12 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> 
tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -200,10 +209,11 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -270,11 +280,12 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -285,10 +296,11 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = 
dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -355,11 +367,12 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -370,10 +383,11 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> 
tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -405,207 +419,41 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // CHECK-LABEL: func @biasAdd_NHWC func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_NCHW func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_dynamic func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor return %0 : tensor } //===----------------------------------------------------------------------===// -// Binary op legalizations. -// Most of these expand from the same pattern. Full semantics are -// verified for tf.Add and pattern application only for the rest. 
+// DiagPart //===----------------------------------------------------------------------===// -// CHECK-LABEL: func @add -func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> - // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> - %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %1: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_add -// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check -// patterns unambiguous and more interesting (once broadcastable trait is -// fixed upstream). -func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @broadcast_multi_dim_add -// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream -// broadcastable bug is fixed (helps make the CHECK matching unambiguous) -func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] - // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] - // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} - // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> - return %0: tensor<4x4x4x4xi32> -} - -// CHECK-LABEL: func @add_dynamic -func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: xla_hlo.add %4, %5 : tensor - %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @div 
-func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @shift_left -func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> - %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: func @div_unranked -func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { - // CHECK: tf.Div - %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @maximum -func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @minimum -func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @mul -func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @real_div -func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> - %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @sub -func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @shift_right -func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: func @shift_right_unsigned -func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { - // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> - return %0 : tensor<4xui8> -} - -// CHECK-LABEL: func @broadcast_shift_right_unsigned -func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { - // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> - return %0 : tensor<2x4xui8> -} - -// CHECK-LABEL: func @and -func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @and_unranked -func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { - // CHECK: tf.LogicalAnd - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) 
-> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @or -func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @bitwise_or -func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0: tensor<4xi32> -} - -// CHECK-LABEL: func @bitwise_and -func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0: tensor<4xi32> -} - -// CHECK-LABEL: func @pow -func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: xla_hlo.power - %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0: tensor<2xf32> -} - // CHECK-LABEL: func @diag_part // CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { @@ -625,6 +473,10 @@ func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { return %0: tensor<4x3xf32> } +//===----------------------------------------------------------------------===// +// Einsum. +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @einsum func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { // CHECK: xla_hlo.einsum @@ -639,22 +491,26 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { return %0: tensor<2x2xf32> } +//===----------------------------------------------------------------------===// +// FloorDiv and FloorMod. 
+//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @floordiv_broadcast_i32 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"(%arg0, [[ZEROS1]]) {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ZEROS3]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = "xla_hlo.divide"([[NEG]], [[ABS3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -664,19 +520,19 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te // CHECK-LABEL: func @floordiv_reverse_broadcast_i32 func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"(%arg0, [[ZEROS1]]) {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = 
"EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ZEROS3]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_hlo.divide [[NEG]], [[ABS3]] + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -685,7 +541,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-LABEL: func @floordiv_f32 func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %[[DIV:.*]] = xla_hlo.divide %arg0, %arg0 + // CHECK-NEXT: %[[DIV:.*]] = xla_chlo.broadcast_divide %arg0, %arg0 // CHECK-NEXT: %[[FLOOR:.*]] = "xla_hlo.floor"(%[[DIV]]) // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> @@ -696,7 +552,7 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: xla_hlo.convert - // CHECK-NEXT: xla_hlo.divide + // CHECK-NEXT: xla_chlo.broadcast_divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: return @@ -706,7 +562,7 @@ func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-LABEL: func @floordiv_f16_broadcast func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: xla_hlo.divide + // CHECK-NEXT: xla_chlo.broadcast_divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: return %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> @@ -729,15 +585,15 @@ func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x // CHECK-LABEL: func @floormod_broadcast_numerator func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"([[REM]], [[ZL]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZR:%.+]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[REM:%.+]], [[ZR]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = 
"xla_hlo.compare"([[CMP2]], [[CMP3]]) {comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_hlo.and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_hlo.add %arg1, [[REM]] + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR:%.+]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM:%.+]], [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -746,15 +602,15 @@ func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) // CHECK-LABEL: func @floormod_broadcast_denominator func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"([[REM]], [[ZL]]) {comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZR:%.+]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[REM:%.+]], [[ZR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = "xla_hlo.compare"([[CMP2]], [[CMP3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_hlo.and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"(%arg1, [[REM]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR:%.+]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM:%.+]], [[ZR]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -775,6 +631,10 @@ func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x return %0: tensor<*xi32> } +//===----------------------------------------------------------------------===// +// BroadcastTo. 
+//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @broadcast_to func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> @@ -787,155 +647,6 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { return %0 : tensor<16x16x16x16xf32> } -//===----------------------------------------------------------------------===// -// Equality op legalizations. -// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are -// verified for tf.Equal and pattern application only for tf.NotEqual -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @equal -func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @equal_dynamic -func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @equal_broadcast -func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error -func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @equal_incompatible_shape_broadcastable -func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) 
-> tensor - return %0: tensor -} - -// CHECK-LABEL: func @equal_incompatible_shape_dynamic -func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic -func @equal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @equal_unranked -func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { - // CHECK: "tf.Equal" - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @notequal -func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -//===----------------------------------------------------------------------===// -// Compare op legalizations. -// These expand from the same pattern. Full semantics are checked for -// tf.Greater. Others just check that the pattern applied. -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @greater -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_greater -func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @greater_dynamic -func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - 
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg1) : (tensor, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @greater_uranked -func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { - // CHECK: "tf.Greater" - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @greater_equal -func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} - %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @less -func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} - %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @less_equal -func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} - %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - - //===----------------------------------------------------------------------===// // Complex op legalizations. //===----------------------------------------------------------------------===// @@ -1224,12 +935,12 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: ten // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> - // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<*xi1> + // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<64x64xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<*xi1> + // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<64x64xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<*xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<64x64xi1> // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<64x64xbf16> // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) @@ -1245,11 +956,11 @@ func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor, %arg2 // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> - // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<*xi1> + // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<24x48xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<*xi1> - // 
CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : tensor<*xi1> + // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<24x48xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<24x48xi1> // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) @@ -1396,7 +1107,8 @@ func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: // CHECK-LABEL:one_hot func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { // CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32> - // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%arg0, %[[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK: %[[BCAST_ARG0:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> + // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> // CHECK: %[[ON_VALUE:.*]] = "xla_hlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> // CHECK: %[[OFF_VALUE:.*]] = "xla_hlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> // CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> @@ -1561,7 +1273,7 @@ func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (te // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -1569,7 +1281,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked func @relu_unranked(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1597,8 +1309,8 @@ func @relu6_unranked(%arg0: tensor) -> tensor { func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { // CHECK-DAG: %[[ZERO_SCALAR:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> - // CHECK-DAG: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO_SCALAR]]) {comparison_direction = "GT"} : (tensor, tensor) -> tensor<*xi1> - // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<*xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-DAG: 
%[[PRED:.*]] = xla_chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32> %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> @@ -1708,7 +1420,10 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]] // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. @@ -1720,8 +1435,11 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.divide"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // return %[[RESULT]] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]] + // CHECK: return %[[RESULT]] %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %0: tensor<2x3xf32> @@ -1730,7 +1448,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Verify intermediate and final shape are correct with dynamic shapes. // CHECK-LABEL: func @dynamic_softmax func @dynamic_softmax(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.divide"({{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor) -> tensor + // CHECK: xla_hlo.divide {{.*}} : tensor %0 = "tf.Softmax"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1756,43 +1474,29 @@ func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { // CHECK: "xla_hlo.reduce" // CHECK: dimensions = dense<3> - // CHECK: "xla_hlo.divide"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: xla_hlo.divide {{.*}} %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> return %0: tensor<2x3x4x5xf16> } //===----------------------------------------------------------------------===// // LogSoftmax op legalizations. 
+// This just changes the tail of the regular Softmax legalization //===----------------------------------------------------------------------===// // CHECK-LABEL: func @simple_logsoftmax // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - - // Verify reduce op for max computation and its body. - // CHECK-DAG: %[[CASTED_INP:.*]] = "xla_hlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[NEG_INF:.*]] = xla_hlo.constant dense<0xFF800000> : tensor - // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) - // CHECK: xla_hlo.maximum - // CHECK: "xla_hlo.return" - // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> - // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) - - // Verify reduce op for summation and its body. - // CHECK-DAG: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]]) - // CHECK: xla_hlo.add - // CHECK: "xla_hlo.return" - // CHECK: {dimensions = dense<1> : tensor<1xi64>} + // CHECK: %{{.*}} = "xla_hlo.reduce"({{.*}}) + // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"({{.*}}) // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - - // CHECK: %[[RESULT:.*]] = "xla_hlo.subtract"(%[[SHIFTED_INP]], %[[LOG]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // return %[[RESULT]] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]] + // CHECK: return %[[RESULT]] %0 = "tf.LogSoftmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %0: tensor<2x3xf32> @@ -2643,10 +2347,10 @@ func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[INDEX2:.*]] = "xla_hlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[INDEX2]], %[[ZERO]]) + // CHECK-NEXT: %[[CMP:.*]] = xla_chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK-NEXT: %[[DIM:.*]] = xla_hlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = xla_hlo.add %[[DIM]], %[[INDEX2]] : tensor + // CHECK-NEXT: %[[WRAP:.*]] = xla_chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[INDEX3:.*]] = "xla_hlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[SLICED:.*]] = "xla_hlo.dynamic-slice" @@ -2775,7 +2479,7 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: 
"xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> // CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<8.000000e+00> : tensor - // CHECK: %[[MEAN:.*]] = "xla_hlo.divide"(%[[REDUCED]], %[[DIVISOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = xla_chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> @@ -3079,8 +2783,8 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota" - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[DELTA]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK: "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK: xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> return %3 : tensor<5xf32> } @@ -3092,12 +2796,12 @@ func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]] // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM_CAST]]) // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]] - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = xla_hlo.divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] + // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_chlo.broadcast_subtract [[NUM_F32]], [[ONE]] + // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_chlo.broadcast_subtract [[STOP]], [[START]] + // CHECK-DAG: [[STEP:%.*]] = xla_chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[STEP]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[LINSPACE:%.*]] = "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[LINSPACE:%.*]] = xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} // CHECK: return [[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> @@ -3392,13 +3096,13 @@ func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor) { // CHECK: %[[CONST:.*]] = xla_hlo.constant dense<1> // CHECK: %[[DIM_0:.*]] = 
"xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 0 - // CHECK: %[[MUL_0:.*]] = xla_hlo.multiply %[[CONST]], %[[DIM_0]] + // CHECK: %[[MUL_0:.*]] = xla_chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] // CHECK: %[[DIM_1:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 1 - // CHECK: %[[MUL_1:.*]] = xla_hlo.multiply %[[MUL_0]], %[[DIM_1]] + // CHECK: %[[MUL_1:.*]] = xla_chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] // CHECK: %[[DIM_2:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 2 - // CHECK: %[[MUL_2:.*]] = xla_hlo.multiply %[[MUL_1]], %[[DIM_2]] + // CHECK: %[[MUL_2:.*]] = xla_chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor // CHECK: return %[[MUL_2]] return %size : tensor @@ -3915,7 +3619,7 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[INDICES1:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[INDICES2:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[ONE:%.*]] = xla_hlo.constant dense<1> : tensor - // CHECK: [[NEW_IV:%.*]] = xla_hlo.add [[IV]], [[ONE]] + // CHECK: [[NEW_IV:%.*]] = xla_chlo.broadcast_add [[IV]], [[ONE]] // CHECK: [[NEW_TUPLE:%.*]] = "xla_hlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]]) // CHECK: "xla_hlo.return"([[NEW_TUPLE]]) // CHECK: }) : (tuple, tensor<4xi32>, tensor<4xi32>>) -> tuple, tensor<4xi32>, tensor<4xi32>> @@ -3984,7 +3688,7 @@ func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> // CHECK: "xla_hlo.return"([[ADD]]) // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = xla_hlo.constant dense<4.000000e+00> : tensor - // CHECK: [[DIV:%.+]] = "xla_hlo.divide"([[REDUCE]], [[COUNT]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> + // CHECK: [[DIV:%.+]] = xla_chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[CONV16:%.+]] = "xla_hlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> // CHECK: return [[CONV16]] %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> @@ -4124,177 +3828,11 @@ func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor // CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { -// CHECK: [[VAL_1:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100x100xi32> -// CHECK: [[VAL_2:%.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<100x100xi32> -// CHECK: [[VAL_3:%.*]] = "xla_hlo.compare"([[VAL_1]], [[VAL_2]]) {comparison_direction = "EQ"} : (tensor<100x100xi32>, tensor<100x100xi32>) -> tensor<100x100xi1> -// CHECK: [[VAL_4:%.*]] = "xla_hlo.convert"([[VAL_3]]) : (tensor<100x100xi1>) -> tensor<100x100xf32> -// CHECK: [[VAL_5:%.*]] = "xla_hlo.broadcast"([[VAL_4]]) {broadcast_sizes = 
dense<500> : tensor<1xi64>} : (tensor<100x100xf32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_6:%.*]] = "xla_hlo.slice"([[VAL_0]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_7:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_8:%.*]] = "xla_hlo.broadcast"([[VAL_7]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_9:%.*]] = "xla_hlo.broadcast"([[VAL_7]]) {broadcast_sizes = dense<[500, 75]> : tensor<2xi64>} : (tensor) -> tensor<500x75xf32> -// CHECK: [[VAL_10:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_11:%.*]] = "xla_hlo.tuple"([[VAL_10]], [[VAL_6]], [[VAL_8]], [[VAL_9]]) : (tensor, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_12:%.*]] = "xla_hlo.while"([[VAL_11]]) ( { -// CHECK: ^bb0([[VAL_13:%.*]]: tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_14:%.*]] = "xla_hlo.get_tuple_element"([[VAL_13]]) {index = 0 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor -// CHECK: [[VAL_15:%.*]] = xla_hlo.constant dense<75> : tensor -// CHECK: [[VAL_16:%.*]] = "xla_hlo.compare"([[VAL_14]], [[VAL_15]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor -// CHECK: "xla_hlo.return"([[VAL_16]]) : (tensor) -> () -// CHECK: }, { -// CHECK: ^bb0([[VAL_17:%.*]]: tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_18:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 0 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor -// CHECK: [[VAL_19:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 1 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_20:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 2 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_21:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 3 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_22:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_23:%.*]] = "xla_hlo.dynamic-slice"([[VAL_19]], [[VAL_22]], [[VAL_22]], [[VAL_18]]) {slice_sizes = dense<[500, 100, 1]> : tensor<3xi64>} : (tensor<500x100x75xf32>, tensor, tensor, tensor) -> tensor<500x100x1xf32> -// CHECK: [[VAL_24:%.*]] = "xla_hlo.reshape"([[VAL_23]]) : (tensor<500x100x1xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_25:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_26:%.*]] = xla_hlo.constant dense<1.000000e+00> : tensor -// CHECK: [[VAL_27:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_28:%.*]] = "xla_hlo.dynamic-slice"([[VAL_24]], [[VAL_27]], [[VAL_18]]) {slice_sizes = dense<[500, 1]> : tensor<2xi64>} : (tensor<500x100xf32>, tensor, tensor) -> tensor<500x1xf32> -// CHECK: [[VAL_29:%.*]] = "xla_hlo.reshape"([[VAL_28]]) : (tensor<500x1xf32>) -> tensor<500xf32> -// CHECK: [[VAL_30:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100xi32> -// CHECK: [[VAL_31:%.*]] = "xla_hlo.compare"([[VAL_30]], [[VAL_18]]) 
{broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<100xi32>, tensor) -> tensor<100xi1> -// CHECK: [[VAL_32:%.*]] = "xla_hlo.convert"([[VAL_31]]) : (tensor<100xi1>) -> tensor<100xf32> -// CHECK: [[VAL_33:%.*]] = "xla_hlo.multiply"([[VAL_24]], [[VAL_32]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<500x100xf32>, tensor<100xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_34:%.*]] = xla_hlo.multiply [[VAL_33]], [[VAL_33]] : tensor<500x100xf32> -// CHECK: [[VAL_35:%.*]] = "xla_hlo.reduce"([[VAL_34]], [[VAL_25]]) ( { -// CHECK: ^bb0([[VAL_36:%.*]]: tensor, [[VAL_37:%.*]]: tensor): -// CHECK: [[VAL_38:%.*]] = xla_hlo.add [[VAL_36]], [[VAL_37]] : tensor -// CHECK: "xla_hlo.return"([[VAL_38]]) : (tensor) -> () -// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<500x100xf32>, tensor) -> tensor<500xf32> -// CHECK: [[VAL_39:%.*]] = xla_hlo.multiply [[VAL_29]], [[VAL_29]] : tensor<500xf32> -// CHECK: [[VAL_40:%.*]] = xla_hlo.add [[VAL_39]], [[VAL_41:%.*]] : tensor<500xf32> -// CHECK: [[VAL_42:%.*]] = "xla_hlo.sqrt"([[VAL_40]]) : (tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_43:%.*]] = "xla_hlo.compare"([[VAL_41]], [[VAL_25]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500xf32>, tensor) -> tensor<500xi1> -// CHECK: [[VAL_44:%.*]] = "xla_hlo.compare"([[VAL_29]], [[VAL_25]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} : (tensor<500xf32>, tensor) -> tensor<500xi1> -// CHECK: [[VAL_45:%.*]] = "xla_hlo.broadcast"([[VAL_26]]) {broadcast_sizes = dense<500> : tensor<1xi64>} : (tensor) -> tensor<500xf32> -// CHECK: [[VAL_46:%.*]] = "xla_hlo.negate"([[VAL_45]]) : (tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_47:%.*]] = "xla_hlo.select"([[VAL_44]], [[VAL_45]], [[VAL_46]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_48:%.*]] = xla_hlo.multiply [[VAL_47]], [[VAL_42]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<500xf32> -// CHECK: [[VAL_49:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_29]], [[VAL_48]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_50:%.*]] = xla_hlo.subtract [[VAL_49]], [[VAL_29]] : tensor<500xf32> -// CHECK: [[VAL_51:%.*]] = xla_hlo.divide [[VAL_50]], [[VAL_49]] : tensor<500xf32> -// CHECK: [[VAL_52:%.*]] = "xla_hlo.broadcast"([[VAL_25]]) {broadcast_sizes = dense<500> : tensor<1xi64>} : (tensor) -> tensor<500xf32> -// CHECK: [[VAL_53:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_52]], [[VAL_51]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_54:%.*]] = xla_hlo.subtract [[VAL_29]], [[VAL_49]] : tensor<500xf32> -// CHECK: [[VAL_55:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_45]], [[VAL_54]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_56:%.*]] = "xla_hlo.compare"([[VAL_30]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<100xi32>, tensor) -> tensor<100xi1> -// CHECK: [[VAL_57:%.*]] = "xla_hlo.convert"([[VAL_56]]) : (tensor<100xi1>) -> tensor<100xf32> -// CHECK: [[VAL_58:%.*]] = "xla_hlo.broadcast"([[VAL_57]]) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<100xf32>) -> tensor<1x100xf32> -// CHECK: [[VAL_59:%.*]] = "xla_hlo.divide"([[VAL_33]], [[VAL_55]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500x100xf32>, tensor<500xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_60:%.*]] = 
"xla_hlo.add"([[VAL_58]], [[VAL_59]]) : (tensor<1x100xf32>, tensor<500x100xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_61:%.*]] = "xla_hlo.reshape"([[VAL_60]]) : (tensor<500x100xf32>) -> tensor<500x1x100xf32> -// CHECK: [[VAL_62:%.*]] = "xla_hlo.dot_general"([[VAL_61]], [[VAL_19]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x1x100xf32>, tensor<500x100x75xf32>) -> tensor<500x1x75xf32> -// CHECK: [[VAL_63:%.*]] = "xla_hlo.dot_general"([[VAL_61]], [[VAL_62]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x1x100xf32>, tensor<500x1x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_64:%.*]] = "xla_hlo.multiply"([[VAL_53]], [[VAL_63]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_65:%.*]] = xla_hlo.subtract [[VAL_19]], [[VAL_64]] : tensor<500x100x75xf32> -// CHECK: [[VAL_66:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100x1xi32> -// CHECK: [[VAL_67:%.*]] = "xla_hlo.compare"([[VAL_66]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} : (tensor<100x1xi32>, tensor) -> tensor<100x1xi1> -// CHECK: [[VAL_68:%.*]] = "xla_hlo.convert"([[VAL_67]]) : (tensor<100x1xi1>) -> tensor<100x1xf32> -// CHECK: [[VAL_69:%.*]] = "xla_hlo.compare"([[VAL_66]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<100x1xi32>, tensor) -> tensor<100x1xi1> -// CHECK: [[VAL_70:%.*]] = "xla_hlo.convert"([[VAL_69]]) : (tensor<100x1xi1>) -> tensor<100x1xf32> -// CHECK: [[VAL_71:%.*]] = "xla_hlo.broadcast"([[VAL_70]]) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<100x1xf32>) -> tensor<1x100x1xf32> -// CHECK: [[VAL_72:%.*]] = "xla_hlo.multiply"([[VAL_23]], [[VAL_68]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<500x100x1xf32>, tensor<100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_73:%.*]] = "xla_hlo.multiply"([[VAL_49]], [[VAL_71]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500xf32>, tensor<1x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_74:%.*]] = xla_hlo.add [[VAL_72]], [[VAL_73]] : tensor<500x100x1xf32> -// CHECK: [[VAL_75:%.*]] = "xla_hlo.broadcast_in_dim"([[VAL_74]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<500x100x1xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_76:%.*]] = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<500x100x75xi32> -// CHECK: [[VAL_77:%.*]] = "xla_hlo.compare"([[VAL_76]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500x100x75xi32>, tensor) -> tensor<500x100x75xi1> -// CHECK: [[VAL_78:%.*]] = "xla_hlo.select"([[VAL_77]], [[VAL_75]], [[VAL_65]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_79:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_80:%.*]] = "xla_hlo.broadcast"([[VAL_79]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor) -> 
tensor<500x100x75xf32> -// CHECK: [[VAL_81:%.*]] = "xla_hlo.add"([[VAL_80]], [[VAL_60]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<500x100x75xf32>, tensor<500x100xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_82:%.*]] = "xla_hlo.select"([[VAL_77]], [[VAL_81]], [[VAL_80]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_83:%.*]] = xla_hlo.add [[VAL_20]], [[VAL_82]] : tensor<500x100x75xf32> -// CHECK: [[VAL_84:%.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<500x75xi32> -// CHECK: [[VAL_85:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_86:%.*]] = "xla_hlo.broadcast"([[VAL_85]]) {broadcast_sizes = dense<[500, 75]> : tensor<2xi64>} : (tensor) -> tensor<500x75xf32> -// CHECK: [[VAL_87:%.*]] = "xla_hlo.compare"([[VAL_84]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500x75xi32>, tensor) -> tensor<500x75xi1> -// CHECK: [[VAL_88:%.*]] = "xla_hlo.add"([[VAL_86]], [[VAL_53]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500x75xf32>, tensor<500xf32>) -> tensor<500x75xf32> -// CHECK: [[VAL_89:%.*]] = "xla_hlo.select"([[VAL_87]], [[VAL_88]], [[VAL_86]]) : (tensor<500x75xi1>, tensor<500x75xf32>, tensor<500x75xf32>) -> tensor<500x75xf32> -// CHECK: [[VAL_90:%.*]] = xla_hlo.add [[VAL_21]], [[VAL_89]] : tensor<500x75xf32> -// CHECK: [[VAL_91:%.*]] = xla_hlo.constant dense<1> : tensor -// CHECK: [[VAL_92:%.*]] = xla_hlo.add [[VAL_18]], [[VAL_91]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor -// CHECK: [[VAL_93:%.*]] = "xla_hlo.tuple"([[VAL_92]], [[VAL_78]], [[VAL_83]], [[VAL_90]]) : (tensor, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: "xla_hlo.return"([[VAL_93]]) : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> () -// CHECK: }) : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_94:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95:%.*]]) {index = 1 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_96:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95]]) {index = 2 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_97:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95]]) {index = 3 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_98:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_99:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_100:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_101:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_0]], [[VAL_94]], [[VAL_100]], [[VAL_98]], [[VAL_99]]) : (tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor, tensor, tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_102:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_103:%.*]] = "xla_hlo.broadcast"([[VAL_102]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_104:%.*]] = "xla_hlo.slice"([[VAL_96]]) {limit_indices = dense<[500, 100, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> 
: tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_105:%.*]] = "xla_hlo.slice"([[VAL_97]]) {limit_indices = dense<[500, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<500x75xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_106:%.*]] = "xla_hlo.negate"([[VAL_105]]) : (tensor<500x1xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_107:%.*]] = "xla_hlo.multiply"([[VAL_106]], [[VAL_104]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<500x1xf32>, tensor<500x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_108:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_109:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_110:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_103]], [[VAL_107]], [[VAL_109]], [[VAL_109]], [[VAL_108]]) : (tensor<500x100x75xf32>, tensor<500x100x1xf32>, tensor, tensor, tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_111:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_112:%.*]] = "xla_hlo.tuple"([[VAL_111]], [[VAL_110]], [[VAL_96]], [[VAL_97]]) : (tensor, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_113:%.*]] = "xla_hlo.while"([[VAL_112]]) ( { -// CHECK: ^bb0([[VAL_114:%.*]]: tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_115:%.*]] = "xla_hlo.get_tuple_element"([[VAL_114]]) {index = 0 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor -// CHECK: [[VAL_116:%.*]] = xla_hlo.constant dense<74> : tensor -// CHECK: [[VAL_117:%.*]] = "xla_hlo.compare"([[VAL_115]], [[VAL_116]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor -// CHECK: "xla_hlo.return"([[VAL_117]]) : (tensor) -> () -// CHECK: }, { -// CHECK: ^bb0([[VAL_118:%.*]]: tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_119:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 0 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor -// CHECK: [[VAL_120:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 1 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_121:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 2 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_122:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 3 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_123:%.*]] = xla_hlo.constant dense<1> : tensor -// CHECK: [[VAL_124:%.*]] = xla_hlo.add [[VAL_119]], [[VAL_123]] : tensor -// CHECK: [[VAL_125:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_126:%.*]] = "xla_hlo.dynamic-slice"([[VAL_121]], [[VAL_125]], [[VAL_125]], [[VAL_124]]) {slice_sizes = dense<[500, 100, 1]> : tensor<3xi64>} : (tensor<500x100x75xf32>, tensor, tensor, tensor) -> tensor<500x100x1xf32> -// CHECK: [[VAL_127:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_128:%.*]] = "xla_hlo.dynamic-slice"([[VAL_122]], [[VAL_127]], [[VAL_124]]) {slice_sizes = dense<[500, 1]> : tensor<2xi64>} : (tensor<500x75xf32>, tensor, tensor) -> tensor<500x1xf32> -// CHECK: [[VAL_129:%.*]] = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<500x100x75xi32> 
-// CHECK: [[VAL_130:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_131:%.*]] = "xla_hlo.broadcast"([[VAL_130]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_132:%.*]] = "xla_hlo.compare"([[VAL_129]], [[VAL_124]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GE"} : (tensor<500x100x75xi32>, tensor) -> tensor<500x100x75xi1> -// CHECK: [[VAL_133:%.*]] = "xla_hlo.select"([[VAL_132]], [[VAL_131]], [[VAL_121]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_134:%.*]] = "xla_hlo.dot_general"([[VAL_133]], [[VAL_126]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x1xf32>) -> tensor<500x75x1xf32> -// CHECK: [[VAL_135:%.*]] = "xla_hlo.dot_general"([[VAL_120]], [[VAL_134]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x75x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_136:%.*]] = "xla_hlo.negate"([[VAL_128]]) : (tensor<500x1xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_137:%.*]] = xla_hlo.add [[VAL_126]], [[VAL_135]] : tensor<500x100x1xf32> -// CHECK: [[VAL_138:%.*]] = "xla_hlo.multiply"([[VAL_136]], [[VAL_137]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<500x1xf32>, tensor<500x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_139:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_140:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_120]], [[VAL_138]], [[VAL_139]], [[VAL_139]], [[VAL_124]]) : (tensor<500x100x75xf32>, tensor<500x100x1xf32>, tensor, tensor, tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_141:%.*]] = xla_hlo.constant dense<1> : tensor -// CHECK: [[VAL_142:%.*]] = xla_hlo.add [[VAL_119]], [[VAL_141]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor -// CHECK: [[VAL_143:%.*]] = "xla_hlo.tuple"([[VAL_142]], [[VAL_140]], [[VAL_121]], [[VAL_122]]) : (tensor, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: "xla_hlo.return"([[VAL_143]]) : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> () -// CHECK: }) : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_144:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145:%.*]]) {index = 1 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_146:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145]]) {index = 2 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_147:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145]]) {index = 3 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_148:%.*]] = 
"xla_hlo.slice"([[VAL_101]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<[0, 0, 75]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x0xf32> -// CHECK: [[VAL_149:%.*]] = "xla_hlo.dot_general"([[VAL_144]], [[VAL_148]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x0xf32>) -> tensor<500x75x0xf32> -// CHECK: [[VAL_150:%.*]] = "xla_hlo.dot_general"([[VAL_96]], [[VAL_149]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x75x0xf32>) -> tensor<500x100x0xf32> -// CHECK: [[VAL_151:%.*]] = xla_hlo.add [[VAL_148]], [[VAL_150]] : tensor<500x100x0xf32> -// CHECK: [[VAL_152:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_153:%.*]] = xla_hlo.constant dense<75> : tensor -// CHECK: [[VAL_154:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_155:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_101]], [[VAL_151]], [[VAL_154]], [[VAL_152]], [[VAL_153]]) : (tensor<500x100x75xf32>, tensor<500x100x0xf32>, tensor, tensor, tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_156:%.*]] = "xla_hlo.slice"([[VAL_5]]) {limit_indices = dense<[500, 100, 100]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x100xf32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_157:%.*]] = "xla_hlo.dot_general"([[VAL_156]], [[VAL_144]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x100xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_158:%.*]] = "xla_hlo.dot_general"([[VAL_157]], [[VAL_96]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_159:%.*]] = xla_hlo.add [[VAL_156]], [[VAL_158]] : tensor<500x100x100xf32> -// CHECK: [[VAL_160:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_161:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_162:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_5]], [[VAL_159]], [[VAL_161]], [[VAL_161]], [[VAL_160]]) : (tensor<500x100x100xf32>, tensor<500x100x100xf32>, tensor, tensor, tensor) -> tensor<500x100x100xf32> -// CHECK: [[VAL_163:%.*]] = "xla_hlo.slice"([[VAL_162]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x100xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_164:%.*]] = "xla_hlo.slice"([[VAL_155]]) {limit_indices = dense<[500, 75, 75]> : tensor<3xi64>, start_indices = dense<0> 
: tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x75x75xf32>
-// CHECK: return [[VAL_163]], [[VAL_164]] : tensor<500x100x75xf32>, tensor<500x75x75xf32>
+  // The tf.Qr lowering is a full algorithm that is not practical to verify with
+  // FileCheck. Just verify that it converted.
+  // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is
+  // really only applicable to certain legacy uses.
+  // CHECK-NOT: "tf.Qr"
  %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
  return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
index d25a84d0e25..9f27a204baf 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
@@ -1,4 +1,4 @@
-// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s
+// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
@@ -42,40 +42,6 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
  return %4 : tensor<4xi32>
}

-// Broadcasting is not currently supported.
-// TODO(suderman): Future pass should take all broadcasted binary ops and convert
-// them to separate broadcast and binary op.
-// CHECK-LABEL: func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
-func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
-  // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %0 = "xla_hlo.add"(%arg0, %arg1) {
-      name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} :
-      (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-
-  // CHECK-NEXT: %1 = "xla_hlo.multiply"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %1 = "xla_hlo.multiply"(%0, %arg1) {
-      name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} :
-      (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-
-  // CHECK-NEXT: %2 = "xla_hlo.subtract"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %2 = "xla_hlo.subtract"(%1, %arg1) {
-      name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} :
-      (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-
-  // CHECK-NEXT: %3 = "xla_hlo.divide"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %3 = "xla_hlo.divide"(%2, %arg1) {
-      name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} :
-      (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-
-  // CHECK-NEXT: %4 = "xla_hlo.remainder"(%3, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-  %4 = "xla_hlo.remainder"(%3, %arg1) {
-      broadcast_dimensions = dense<1> : tensor<1xi64>} :
-      (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32>
-
-  // CHECK-NEXT: return %4 : tensor<4x4xf32>
-  return %4 : tensor<4x4xf32>
-}
-
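The deleted @binary_ops_broadcast test exercised the implicit, attribute-driven broadcasting that no longer exists on the xla_hlo binary ops. For orientation, a minimal sketch of the strict form that replaces it, assuming illustrative values %mat : tensor<4x4xf32> and %vec : tensor<4xf32> that are not taken from this file: the broadcast is materialized first, and the binary op then sees operands of equal shape.

  // Materialize the broadcast explicitly, then apply a same-shape binary op.
  %bcast = "xla_hlo.broadcast_in_dim"(%vec) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4x4xf32>
  %sum = xla_hlo.add %mat, %bcast : tensor<4x4xf32>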
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
  // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir
index 35a5ae549d5..81376761467 100644
--- a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir
@@ -1,4 +1,4 @@
-// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s
+// RUN: xla-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
  return %5, %6 : tensor<2xf32>, tensor<2xf32>
}

-// CHECK-LABEL: @add_broadcast
-func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
-  %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
-  %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-
-  // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  %4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
-  %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
-  %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
-
-  // CHECK: return [[VAL0]], [[VAL1]]
-  return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
-}
-
// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
  %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@@ -60,21 +45,6 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
  return %5, %6 : tensor<2xf32>, tensor<2xf32>
}

-// CHECK-LABEL: @sub_broadcast
-func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
-  %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
-  %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-
-  // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.subtract"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.subtract"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  %4 = "xla_hlo.subtract"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
-  %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
-  %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
-
-  // CHECK: return [[VAL0]], [[VAL1]]
-  return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
-}
-
// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
  %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@@ -109,25 +79,6 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
  return %5, %6 : tensor<2xf32>, tensor<2xf32>
}

-// CHECK-LABEL: @mul_broadcast
-func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
-  %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
-  %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-
-  // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
-  // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.multiply"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
-  %4 = "xla_hlo.multiply"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
-  %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
-  %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
-
-  // CHECK: return %2, %5 : tensor<1x2xf32>, tensor<1x2xf32>
-  return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
-}
-
// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
  %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
@@ -186,45 +137,6 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
  return %5, %6 : tensor<2xf32>, tensor<2xf32>
}

// -----

-// CHECK-LABEL: @div_broadcast
-func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
-  %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
-  %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-
-  // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
-
-  // Compute the numerator's real component:
-  //   numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
-  // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.multiply"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
-
-  // Compute the real valued denominator as rhs * conj(rhs):
-  //   denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
-  // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
-  // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
-  // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
-
-  // Compute the numerator's imaginary component:
-  //   numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
-  // CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.multiply"(%arg1, %arg2)
-  // CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.multiply"(%arg0, [[VAL0]])
-  // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
-
-  // Divide the numerator by the real valued denominator.
-  // CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.divide"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.divide"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
-  %4 = "xla_hlo.divide"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
-
-  %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
-  %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
-
-  // CHECK: return [[VAL10]], [[VAL11]]
-  return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
-}
-
-// -----
-
// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
  %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
index a7f4a5b4474..55b55c7b4e2 100644
--- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
@@ -1,225 +1,5 @@
// RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s

-// CHECK-LABEL: @addBroadcastRhs
-func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
-  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
-  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<1x4xf32>
-  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
-  return %0 : tensor<1x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @addBroadcastLhs
-func @addBroadcastLhs(%arg0: tensor<4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
-  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
-  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x4xf32>
-  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32>
-  return %0 : tensor<1x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @addBroadcastEqual
-func @addBroadcastEqual(%arg0: tensor<4x1xf32>, %arg1: tensor<1x4xf32>) -> tensor<4x4xf32> {
-  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4xf32>
-  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<4x4xf32>
-  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<4x4xf32>
-  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32>
-  return %0 : tensor<4x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @addBroadcastMultidimension
-func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
-  // CHECK-NEXT: %[[BROADCAST0:.*]] =
"xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>, tensor<1x1x4xf32>) -> tensor<1x1x4xf32> - return %0 : tensor<1x1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastBothArgs -func @addBroadcastBothArgs(%arg0: tensor<1x2xf32>, %arg1: tensor<3x2x1xf32>) -> tensor<3x2x2xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x2x2xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x1xf32>) -> tensor<3x2x2xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<3x2x2xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>, tensor<3x2x1xf32>) -> tensor<3x2x2xf32> - return %0 : tensor<3x2x2xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastScalar -func @addBroadcastScalar(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @addWithoutBroadcast -func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @addUnranked -func @addUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<*xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @atan2BroadcastRhs -func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.atan2 %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.atan2"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @divBroadcastRhs -func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.divide %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @maxBroadcastRhs -func 
@maxBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.maximum %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.maximum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @minBroadcastRhs -func @minBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.minimum %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.minimum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @mulBroadcastRhs -func @mulBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.multiply %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @powBroadcastRhs -func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.power %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.power"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @remainderBroadcastRhs -func @remainderBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.remainder %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftLeftBroadcastRhs -func @shiftLeftBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_left %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_left"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftRightArithmeticBroadcastRhs -func @shiftRightArithmeticBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: 
%[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_arithmetic %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftRightLogicalBroadcastRhs -func @shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_logical %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_right_logical"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @subBroadcastRhs -func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.subtract %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @andBroadcastRhs -func @andBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.and %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @orBroadcastRhs -func @orBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.or %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @xorBroadcastRhs -func @xorBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.xor %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.xor"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - // CHECK-LABEL: @clampBroadcast // CHECK-SAME: (%[[MIN:.+]]: tensor, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor) func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor) -> tensor<4xf32> { @@ -229,63 +9,3 @@ func @clampBroadcast(%min: 
tensor, %value: tensor<4xf32>, %max: tensor %0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> } - -// ----- - -// CHECK-LABEL: @compareBroadcastRhs -func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xi1> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = "xla_hlo.compare"(%arg0, %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1> - return %0 : tensor<1x4xi1> -} - -// ----- - -// CHECK-LABEL: @dynamicCompareBroadcastRhs -func @dynamicCompareBroadcastRhs(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor - // CHECK-NEXT: %c1 = constant 1 : index - // CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor - // CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor - // CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index - // CHECK-NEXT: %[[DIM1:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: "xla_hlo.compare"(%[[BROADCAST0]], %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor, tensor) -> tensor - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor, tensor) -> tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @dynamicBroadcastAdd -func @dynamicBroadcastAdd(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor - // CHECK-NEXT: %c1 = constant 1 : index - // CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor - // CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor - // CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index - // CHECK-NEXT: %[[DIM1:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @dynamicBroadcastAddScalar -func @dynamicBroadcastAddScalar(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor - // CHECK-NEXT: %[[DIM1:.*]] = dim %arg0, 1 : tensor - // 
CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0]], %[[DIM1]]) : (index, index) -> tensor<2xindex> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - return %0 : tensor -} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 15fa91588a5..20b43e8633d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure // CHECK: HloModule func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token { @@ -96,34 +96,6 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor // ----- -// CHECK: HloModule -func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> { - // Same rank degenerate broadcast - // CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0) - // CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]]) - // CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]]) - // CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1) - // CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]]) - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> - - // Broadcast up rank - // CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2} - // CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2) - // CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]]) - %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - - // Broadcast up rank + degenerate broadcast - // CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2} - // CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]]) - // CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2} - // CHECK: ROOT - // CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]]) - %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - return %2 : tensor<2x3x4xi32> -} - -// ----- - // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 207a8f2eabc..af45f84b34d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o 
- | FileCheck %s +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s --dump-input-on-failure HloModule main @@ -20,29 +20,6 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } -// This test is more thorough than those of the the other binary ops to test -// their shared functionality. - -// CHECK-LABEL: func @test_add -%test_add (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] { - %Arg_0.1 = f32[4] parameter(0) - %Arg_1.2 = f32[4] parameter(1) - %Arg_2.3 = f32[] parameter(2) - %Arg_3.4 = f32[] parameter(3) - - // Add two tensors - // CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"} - %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) - - // Add two scalars - // CHECK-NEXT: xla_hlo.add %arg2, %arg3 - %add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4) - - // Add a tensor and scalar - // CHECK-NEXT: "xla_hlo.add"(%0, %1) - ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4) -} - // CHECK-LABEL: func @test_after_all // CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token %test_after_all (token0: token[], token1: token[] ) -> token[] { @@ -159,11 +136,11 @@ add { } -// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> { -%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] { +// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1> { +%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] { %Arg_0.1 = f32[3] parameter(0) %Arg_1.2 = f32[3] parameter(1) - %Arg_2.3 = f32[1] parameter(2) + %Arg_2.3 = f32[3] parameter(2) // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ @@ -172,7 +149,7 @@ add { %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. 
- // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1> + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } @@ -280,19 +257,19 @@ add { ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1} } -// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf64> { -%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] { +// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64> { +%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] { %Arg_0.1 = f32[4] parameter(0) - %Arg_1.2 = f32[] parameter(1) + %Arg_1.2 = f32[4] parameter(1) // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor) -> tensor - %convert.4 = f64[] convert(f32[] %Arg_1.2) + // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + %convert.4 = f64[4] convert(f32[4] %Arg_1.2) - // CHECK-NEXT: "xla_hlo.add"(%0, %1) - ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4) + // CHECK-NEXT: xla_hlo.add %0, %1 + ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4) } // CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { diff --git a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc index 0c9585a817f..e5a79616d5b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc @@ -163,8 +163,7 @@ struct HloBinaryElementwiseAdaptor { Value broadcasted_lhs, Value broadcasted_rhs, OpBuilder &builder) { return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs, - /*broadcast_dimensions=*/nullptr); + broadcasted_lhs, broadcasted_rhs); } }; @@ -183,9 +182,9 @@ struct HloCompareAdaptor { Type result_type, Value broadcasted_lhs, Value broadcasted_rhs, OpBuilder &builder) { - return builder.create( - from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, - /*broadcast_dimensions=*/nullptr, from_op.comparison_direction()); + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction()); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 10bac232b0f..8675d6c8a4b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -67,8 +67,9 @@ class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; LegalizeTF(const LegalizeTF &) {} - explicit LegalizeTF(bool allow_partial_conversion) { + explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) { allow_partial_conversion_ = allow_partial_conversion; + legalize_chlo_ = legalize_chlo; } /// Performs the lowering to XLA dialect. 
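+  /// When `legalize_chlo` is true (the default), any intermediate chlo ops
+  /// produced while lowering are themselves expanded to plain xla_hlo ops;
+  /// when false, the client-dialect chlo.broadcast_* forms are left in the
+  /// output (controlled by the `legalize-chlo` pass option declared below).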
@@ -79,6 +80,11 @@ class LegalizeTF : public PassWrapper { *this, "allow-partial-conversion", llvm::cl::desc("Allow operations that can't be legalized."), llvm::cl::init(false)}; + Option legalize_chlo_{ + *this, "legalize-chlo", + llvm::cl::desc( + "Also legalizes intermediate chlo ops to hlo (default true)"), + llvm::cl::init(true)}; }; /// Returns if the given TF data format string is the default format. @@ -362,6 +368,154 @@ static Value UpdateSliceInMinorDims(Location loc, Value v, Value update, return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder); } +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Gets the resulting type from a broadcast between two types for statically +// shaped types. This is to be used for legacy lowerings that both use non +// left-padded broadcasting and static shapes. Its use should not be permitted +// in new code. +// May return nullptr on invalid static broadcast dimensions. +// ABSL_DEPRECATED() +static RankedTensorType GetStaticBroadcastType( + RankedTensorType x, RankedTensorType y, + DenseIntElementsAttr broadcast_dimensions_attr) { + auto element_type = x.getElementType(); + auto shape_x = x.getShape(); + auto shape_y = y.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector out_shape(shape_x.size()); + for (int i = 0; i < shape_x.size(); i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + out_shape[i] = std::max(x_val, y_val); + } + return RankedTensorType::get(out_shape, element_type); + } + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; + + llvm::SmallVector broadcast_dimensions; + // Explicit broadcast dimensions. + for (const APInt &int_value : broadcast_dimensions_attr) { + broadcast_dimensions.push_back(int_value.getSExtValue()); + } + if (broadcast_dimensions.size() != shape_small.size()) { + return nullptr; + } + llvm::SmallVector out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. + for (auto index_pair : llvm::enumerate(broadcast_dimensions)) { + auto old_value = out_shape[index_pair.value()]; + auto new_value = shape_small[index_pair.index()]; + out_shape[index_pair.value()] = std::max(old_value, new_value); + } + return RankedTensorType::get(out_shape, element_type); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Applies static binary broadcasting to a binary elementwise op. +// This is a legacy helper to provide general broadcasting support in legacy, +// static shaped code that relies on non-left-padded broadcasting semantics. 
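+// For example (hypothetical shapes): given x : tensor<2x3x4xf32>,
+// y : tensor<3xf32> and broadcast_dims = dense<1> : tensor<1xi64>, the
+// computed result type is tensor<2x3x4xf32>, and y is expanded along
+// dimension 1 rather than along the trailing dimension that the
+// left-padded chlo.broadcast_* ops would assume.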
+template +static Value StaticBinaryBroadcast(Location loc, Value x, Value y, + DenseIntElementsAttr broadcast_dims, + OpBuilder &builder) { + auto x_type = x.getType().cast(); + auto y_type = y.getType().cast(); + auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); + if (!result_type) { + emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type + << " with broadcast_dims = " << broadcast_dims; + return nullptr; + } + auto larger_broadcast_dims = + GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); + if (x_type.getRank() < y_type.getRank()) { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, + larger_broadcast_dims); + } + } else { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, + larger_broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, broadcast_dims); + } + } + return builder.create(loc, x, y); +} + +// Gets a 1D tensor type suitable for expressing extents of the given tensor +// value type. If the value type is ranked, the result will be statically +// shaped. Otherwise, it will have a dynamic dimension. +static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { + Builder b(value_type.getContext()); + int64_t dim = value_type.hasRank() ? value_type.getRank() : -1; + return RankedTensorType::get({dim}, b.getIndexType()); +} + +// Broadcasts a 'lower_rank_value' to the shape of a 'higher_rank_value' +// by assuming that the shape of the lower-ranked value is a +// broadcast-compatible prefix of the higher-ranked one. +// Values must be RankedTensorType (this restriction derives from the +// broadcast_dimensions attribute on DynamicBroadcastInDim). +// +// Example: +// CommonPrefixBroadcast(tensor<4x3x256>, tensor<4x3>) will broadcast the +// lower-ranked value to [4, 3, 256] (i.e. the opposite of numpy-style +// implicit broadcasting). +static Value CommonPrefixBroadcast(Location loc, Value higher_rank_value, + Value lower_rank_value, OpBuilder &builder) { + Value higher_rank_shape = + builder.create(loc, higher_rank_value); + auto result_extents_type = + GetExtentsTensorTypeFor(higher_rank_value.getType().cast()); + Value result_extents = builder.create( + loc, result_extents_type, higher_rank_shape); + + auto lower_rank_type = lower_rank_value.getType().cast(); + auto lower_rank = lower_rank_type.getRank(); + auto prefix_dims = GetI64ElementsAttrForSeq(0, lower_rank, &builder); + return builder.create( + loc, higher_rank_value.getType(), lower_rank_value, result_extents, + prefix_dims); +} + +// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D +// value (broadcast_from) along that feature dimension. This is a shortcut +// for the cases where a 1D tensor must be broadcast along a specific feature +// dimension, which can vary based on data layout, etc. +// +// The extent of `broadcast_from` dim0 must be equal to the extent of the +// feature_dim of `broadcast_to`. +// +// Example: +// [1x2x3x4], [2], 1 -> [1x2x3x4] +// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for +// consistency. Possibly also rename for clarity.
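+// The target shape is recovered at runtime: shape.shape_of is applied to
+// `broadcast_to`, the result is converted to an extent tensor, and that is
+// fed to xla_hlo.dynamic_broadcast_in_dim, so the helper also works for
+// ranked tensors with dynamic dimensions.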
+static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, + Value broadcast_from, int64_t feature_dim, + OpBuilder &builder) { + auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); + auto to_type = broadcast_to.getType().cast(); + auto result_shape = builder.create(loc, broadcast_to); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create( + loc, result_extents_type, result_shape); + return builder.create( + loc, to_type, broadcast_from, result_extents, broadcast_dims); +} + // Creates a batch dot using xla_hlo::DotGeneralOp. Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs, bool transpose_rhs, int64_t num_batch_dims, @@ -407,8 +561,7 @@ static void BuildReduceBody(Type element_type, Region *body, Location loc = body->getLoc(); auto reducer = - builder->create(loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr); + builder->create(loc, block->getArgument(0), block->getArgument(1)); builder->create(loc, reducer.getResult()); } @@ -508,8 +661,7 @@ static void CreateWhile32(Location loc, int num_iterations, loc, builder->getI32IntegerAttr(num_iterations)); StringAttr compare_direction = StringAttr::get("LT", builder->getContext()); Value compare = builder->create( - loc, loop_iv, upper_limit, - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, loop_iv, upper_limit, compare_direction); builder->create(loc, compare); } @@ -539,9 +691,9 @@ static void CreateWhile32(Location loc, int num_iterations, // Increment the loop induction variable by one. auto one = builder->create(loc, builder->getI32IntegerAttr(1)); - auto no_broadcast_dims = GetI64ElementsAttr({}, builder); - auto plus_one = builder->create(loc, old_values[0], one, - no_broadcast_dims); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder); + auto plus_one = builder->create( + loc, old_values[0], one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. new_values.insert(new_values.begin(), plus_one); @@ -566,21 +718,6 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, GetFeatureDimension(format, input.getType().cast())); } -//===----------------------------------------------------------------------===// -// Bias op utilities. -//===----------------------------------------------------------------------===// - -// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd. -// Requires input to have ranked tensor. -static DenseIntElementsAttr getBiasFeatureDimension(Builder &b, - StringAttr format, - Value input) { - auto inputType = input.getType().cast(); - size_t featureDim = GetFeatureDimension(format, inputType); - RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64)); - return DenseIntElementsAttr::get(type, featureDim); -} - //===----------------------------------------------------------------------===// // MatMul op utilities. 
//===----------------------------------------------------------------------===// @@ -743,8 +880,7 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); Value compare = builder->create( - loc, block->getArgument(0), block->getArgument(2), - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, block->getArgument(0), block->getArgument(2), compare_direction); Value selected_input = builder->create( loc, input_type, compare, block->getArgument(0), block->getArgument(2)); @@ -860,8 +996,7 @@ static void BuildSortComparisonBody(llvm::ArrayRef element_types, StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); Value compare = builder->create( - loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, block->getArgument(0), block->getArgument(1), compare_direction); builder->create(loc, compare); } @@ -900,6 +1035,27 @@ NamedAttribute GetConvDimensionNumbersAttr( feature_dim, spatial_dims, builder->getContext())); } +// Converts a TF::BiasAddOp to HLO. +// This differs from a normal TF::AddOp with respect to how the data_format +// is handled, which can optionally require a general broadcast of the +// 'bias' term in a way that is not compatible with the standard left-padded +// broadcast semantics (i.e. NCHW will broadcast into dimension 1). +// The correct 'bias' broadcast will be synthesized manually. +class ConvertBiasAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::BiasAddOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto feature_dim = GetFeatureDimension( + op.data_formatAttr(), op.value().getType().cast()); + auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(), + feature_dim, rewriter); + rewriter.replaceOpWithNewOp(op, op.value(), bias_broadcast); + return success(); + } +}; + // Converts the TensorFlow conv op in template to the generic HLO conv op by // converting TensorFlow op attributes to HLO op attributes. 
// @@ -1161,7 +1317,6 @@ class ConvertDiagPartOp : public OpRewritePattern { rewriter.getI64IntegerAttr(1)); Value compare = rewriter.create( op.getLoc(), iota0, iota1, - /*broadcast_dimensions=*/nullptr, StringAttr::get("EQ", rewriter.getContext())); Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), 0, &rewriter); @@ -1274,33 +1429,35 @@ class ConvertFusedBatchNormGradBase non_feature_dims.push_back(i); } auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); - auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter); - auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); // scratch1 = rsqrt(var + epsilon) RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type); auto epsilon = rewriter.create( loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()})); - auto add_op = rewriter.create(loc, var, epsilon.getResult(), - no_broadcast_dims); + auto add_op = rewriter.create( + loc, var, epsilon.getResult(), scalar_broadcast_dims); + Value scratch1 = rewriter.create(loc, add_op); // scratch2 = sum(y_backprop * (x - mean)) - auto sub_op = rewriter.create(loc, act, mean, broadcast_dims); - auto weighted_grad = - rewriter.create(loc, grad, sub_op, no_broadcast_dims); + auto sub_op = rewriter.create( + loc, act, + Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); + auto weighted_grad = rewriter.create(loc, grad, sub_op); Value scratch2 = ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); // x_backprop = y_backprop * (scale * scratch1) auto scaled_grad = - rewriter.create(loc, op.scale(), scratch1, no_broadcast_dims); - x_backprop = - rewriter.create(loc, grad, scaled_grad, broadcast_dims); + rewriter.create(loc, op.scale(), scratch1); + x_backprop = rewriter.create( + loc, grad, + Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, + rewriter)); // scale_backprop = scratch2 * scratch1 - scale_backprop = - rewriter.create(loc, scratch1, scratch2, no_broadcast_dims); + scale_backprop = rewriter.create(loc, scratch1, scratch2); // offset_backprop = sum(y_backprop) offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); @@ -1396,7 +1553,7 @@ class ConvertFusedBatchNormV3Op auto factor_const_op = rewriter.create( op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); - Value corrected_variance = rewriter.create( + Value corrected_variance = rewriter.create( op.getLoc(), batch_variance.getType(), batch_variance, factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr()); @@ -1416,24 +1573,26 @@ class ConvertFusedBatchNormV3Op rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); // new_running_mean = alpha * old_mean + beta * batch_mean. - auto alpha_mul_old_mean = rewriter.create( + auto alpha_mul_old_mean = rewriter.create( op.getLoc(), op.mean().getType(), alpha, op.mean(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_mean = rewriter.create( + auto beta_mul_batch_mean = rewriter.create( op.getLoc(), batch_mean.getType(), beta, batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); - batch_mean = rewriter.create( + batch_mean = rewriter.create( op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); // new_running_variance = alpha * old_variance + beta * batch_variance. 
- auto alpha_mul_old_variance = rewriter.create( + auto alpha_mul_old_variance = rewriter.create( op.getLoc(), op.variance().getType(), alpha, op.variance(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_variance = rewriter.create( - op.getLoc(), corrected_variance.getType(), beta, corrected_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); - corrected_variance = rewriter.create( + auto beta_mul_batch_variance = + rewriter.create( + op.getLoc(), corrected_variance.getType(), beta, + corrected_variance, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + corrected_variance = rewriter.create( op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, /*broadcast_dimensions=*/DenseIntElementsAttr()); } @@ -1586,10 +1745,9 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Divide by the number of elements in the window. Value divisor = GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter); - auto batch_dims = - GetI64ElementsAttrForSeq(0, input_type.getRank(), &rewriter); - Value result = rewriter.create(op.getLoc(), result_type, reduce, - divisor, batch_dims); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + Value result = rewriter.create( + op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims); // Convert back if we enlarged the element type's bitwidth. if (input_element_type != sum_element_type) @@ -1759,16 +1917,14 @@ class ConvertSigmoidOp : public OpRewritePattern { op.getLoc(), type, scalar_one, GetI64ElementsAttr(type.getShape(), &rewriter)); - auto scaled_input = rewriter.create( - op.getLoc(), operand, constant_ones, DenseIntElementsAttr()); + auto scaled_input = + rewriter.create(op.getLoc(), operand, constant_ones); auto tanh_op = rewriter.create(op.getLoc(), operand.getType(), scaled_input); auto mul_op = - rewriter.create(op.getLoc(), tanh_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + rewriter.create(op.getLoc(), tanh_op, constant_ones); auto add_op = - rewriter.create(op.getLoc(), mul_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + rewriter.create(op.getLoc(), mul_op, constant_ones); rewriter.replaceOp(op, add_op.getResult()); return success(); @@ -1807,20 +1963,18 @@ class ConvertSoftmaxOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Value logits = op.logits(); - // Softmax converter requires ranked type because the XLA reduce ops used // while lowering requires dimensions attribute to reduce along. + // Note that the input and output shapes are equivalent, so we use 'logits' + // and its type for shape calculations. + Value logits = op.logits(); RankedTensorType type = logits.getType().dyn_cast(); if (!type) return failure(); - auto loc = op.getLoc(); int rank = type.getRank(); // Note that the TensorFlow Softmax op verifies that the input rank is - // greater than or equal to one so both of the following sequences are - // valid. - auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter); + // greater than or equal to one so the following sequence is valid.
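+    //
+    // In effect, the sequence below computes
+    //   softmax(logits)     = exp(logits - max(logits)) / sum(exp(logits - max(logits)))
+    //   log_softmax(logits) = (logits - max(logits)) - log(sum(exp(logits - max(logits))))
+    // where max and sum reduce over the last dimension and are broadcast
+    // back along the leading (batch) dimensions.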
auto reduce_dim = rewriter.create( loc, GetI64ElementsAttr({rank - 1}, &rewriter)); @@ -1833,8 +1987,10 @@ class ConvertSoftmaxOp : public OpRewritePattern { auto max_logits = rewriter.create(loc, logits, reduce_dim, /*keep_dims=*/rewriter.getBoolAttr(false)); - auto shifted_logits = - rewriter.create(loc, type, logits, max_logits, batch_dims); + auto max_logits_broadcast = + CommonPrefixBroadcast(loc, logits, max_logits, rewriter); + auto shifted_logits = rewriter.create(loc, type, logits, + max_logits_broadcast); // Exponentiate the inputs. Value exp = rewriter.create(loc, type, shifted_logits); @@ -1847,9 +2003,12 @@ class ConvertSoftmaxOp : public OpRewritePattern { if (use_log) { Value log = rewriter.create(loc, sum); - rewriter.replaceOpWithNewOp(op, shifted_logits, log, batch_dims); + auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter); + rewriter.replaceOpWithNewOp(op, shifted_logits, + log_broadcast); } else { - rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims); + auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter); + rewriter.replaceOpWithNewOp(op, exp, sum_broadcast); } return success(); } @@ -1896,7 +2055,7 @@ class ConvertSizeOp : public OpRewritePattern { auto dim = rewriter.create( op.getLoc(), result_type, input, rewriter.getIntegerAttr(rewriter.getIntegerType(32), i)); - size = rewriter.create( + size = rewriter.create( op.getLoc(), size->getResult(0), dim.getResult(), /*DenseIntElementsAttr=*/DenseIntElementsAttr()); } @@ -2582,10 +2741,10 @@ class ConvertRangeOp : public OpRewritePattern { auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, op.delta(), xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); @@ -2633,7 +2792,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { int64_t num = (*num_attr.begin()).getSExtValue(); // Calculate the scaling that needs to be applied to the iota. - auto step_numerator = rewriter.create( + auto step_numerator = rewriter.create( op.getLoc(), op.start().getType(), op.stop(), op.start(), xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); Value step_denominator = rewriter.create( @@ -2641,11 +2800,11 @@ class ConvertLinSpaceOp : public OpRewritePattern { if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), op.getLoc(), 1, &rewriter); - step_denominator = rewriter.create( + step_denominator = rewriter.create( op.getLoc(), step_denominator.getType(), step_denominator, one, xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); } - auto step = rewriter.create( + auto step = rewriter.create( op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, xla::getBroadcastDimensionsAttr(&rewriter, step_numerator, step_denominator)); @@ -2653,10 +2812,10 @@ class ConvertLinSpaceOp : public OpRewritePattern { // Scale the iota and add the offset. 
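+    // In other words, output[i] = start + i * (stop - start) / (num - 1)
+    // when num > 1, so the endpoints are exactly `start` and `stop`.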
auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, step, xla::getBroadcastDimensionsAttr(&rewriter, iota, step)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); @@ -2732,8 +2891,8 @@ class GenericConvertReductionOp : public OpRewritePattern { auto divisor = GetScalarConstOfType(reduce_element_type, loc, divisor_count, &rewriter); auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); - result = rewriter.create(loc, result, divisor.getResult(), - broadcast_dims); + result = rewriter.create( + loc, result, divisor.getResult(), broadcast_dims); } result = rewriter.create(loc, result, element_type); @@ -3118,7 +3277,6 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { auto reducer = rewriter.create( loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr, StringAttr::get("GE", rewriter.getContext())); rewriter.create(loc, reducer.getResult()); } @@ -3544,13 +3702,20 @@ class ConvertOneHotOp : public OpRewritePattern { output_dims.insert(output_dims.begin() + axis, depth); Location loc = op.getLoc(); + + // The iota result is the effective output shape of the computation, + // and indices must be broadcast into it. At this point, this computation + // would need to be reworked quite a bit to support dynamic shapes, so + // just using static broadcasting. auto index_type = RankedTensorType::get(output_dims, element_type); - Value compare = rewriter.create( - loc, op.indices(), - rewriter.create( - loc, index_type, - IntegerAttr::get(rewriter.getIntegerType(64), axis)), - GetI64ElementsAttr(broadcast_dims, &rewriter), + auto iota = rewriter.create( + loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); + auto broadcast_indices = rewriter.create( + loc, index_type, op.indices(), + GetI64ElementsAttr(broadcast_dims, &rewriter)); + + Value compare = rewriter.create( + loc, broadcast_indices, iota, StringAttr::get("EQ", rewriter.getContext())); Value on_value = rewriter.create( loc, op.getType(), op.on_value(), @@ -4396,7 +4561,6 @@ class ConvertQrOp : public OpRewritePattern { rewriter.getI64IntegerAttr(1)); Value compare = rewriter.create( op.getLoc(), iota0, iota1, - /*broadcast_dimensions=*/nullptr, StringAttr::get("EQ", rewriter.getContext())); Value identity_matrix = rewriter.create(op.getLoc(), compare, type.getElementType()); @@ -4430,8 +4594,7 @@ class ConvertQrOp : public OpRewritePattern { batch_dims.size(), precision_config, &rewriter); a_update = BatchDot(op.getLoc(), y, false, a_update, false, batch_dims.size(), precision_config, &rewriter); - a_panel = rewriter.create(op.getLoc(), a_panel, a_update, - /*broadcast_dimensions=*/nullptr); + a_panel = rewriter.create(op.getLoc(), a_panel, a_update); a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k}, &rewriter); @@ -4442,8 +4605,7 @@ class ConvertQrOp : public OpRewritePattern { batch_dims.size(), precision_config, &rewriter); q_update = BatchDot(op.getLoc(), q_update, false, y, true, batch_dims.size(), precision_config, &rewriter); - q_panel = rewriter.create(op.getLoc(), q_panel, q_update, - /*broadcast_dimensions=*/nullptr); + q_panel = rewriter.create(op.getLoc(), q_panel, q_update); q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter); } // full_matrices is false when only a partial result 
is needed. Slice to the @@ -4505,34 +4667,31 @@ class ConvertQrOp : public OpRewritePattern { Value iota = builder->create( loc, RankedTensorType::get({m}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value gtk = builder->create( + Value gtk = builder->create( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("GT", builder->getContext())); gtk = builder->create(loc, gtk, x_type.getElementType()); - Value x_after_k = builder->create( + Value x_after_k = builder->create( loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder)); - Value x_after_k_sq = builder->create( - loc, x_after_k, x_after_k, /*broadcast_dimensions=*/nullptr); + Value x_after_k_sq = builder->create(loc, x_after_k, x_after_k); // sigma = np.dot(x[k+1:], x[k+1:]) auto sigma = builder->create( loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder)); BuildReduceBody(x_type.getElementType(), &sigma.body(), builder); // mu = np.sqrt(x[k]*x[k] + sigma) - Value alpha_sq = builder->create(loc, alpha, alpha, - /*broadcast_dimensions=*/nullptr); + Value alpha_sq = builder->create(loc, alpha, alpha); Value mu = builder->create( - loc, builder->create(loc, alpha_sq, sigma.getResult(0), - /*broadcast_dimensions=*/nullptr)); + loc, builder->create(loc, alpha_sq, sigma.getResult(0))); - Value sigma_is_zero = builder->create( + Value sigma_is_zero = builder->create( loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); - Value alpha_is_negative = builder->create( + Value alpha_is_negative = builder->create( loc, alpha, zero, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); auto batch_size_one = builder->create( loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder)); - Value signed_mu = builder->create( + Value signed_mu = builder->create( loc, builder->create(loc, mu.getType(), alpha_is_negative, batch_size_one, @@ -4541,21 +4700,16 @@ class ConvertQrOp : public OpRewritePattern { *beta = builder->create(loc, alpha.getType(), sigma_is_zero, alpha, signed_mu); *tau = builder->create( - loc, - builder->create(loc, *beta, alpha, - /*broadcast_dimensions=*/nullptr), - *beta, - /*broadcast_dimensions=*/nullptr); + loc, builder->create(loc, *beta, alpha), *beta); Value zero_tau = builder->create( loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder)); *tau = builder->create(loc, alpha.getType(), sigma_is_zero, zero_tau, *tau); - Value divisor = builder->create(loc, alpha, *beta, - /*broadcast_dimensions=*/nullptr); + Value divisor = builder->create(loc, alpha, *beta); divisor = builder->create(loc, divisor.getType(), sigma_is_zero, batch_size_one, divisor); - Value eqk = builder->create( + Value eqk = builder->create( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); eqk = builder->create(loc, eqk, x_type.getElementType()); @@ -4568,10 +4722,12 @@ class ConvertQrOp : public OpRewritePattern { // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. + // Note that the add performs a degenerate broadcast.
+ *v = builder->create( loc, e_k, - builder->create(loc, x_after_k, divisor, - GetI64ElementsAttr(batch_dim_ids, builder)), + StaticBinaryBroadcast(loc, x_after_k, divisor, + GetI64ElementsAttr(batch_dim_ids, builder), + *builder), /*broadcast_dimensions=*/nullptr); } @@ -4645,10 +4801,10 @@ class ConvertQrOp : public OpRewritePattern { precision, builder); vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims, precision, builder); - auto tau_x_vva = builder->create( - loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder)); - a = builder->create(loc, a, tau_x_vva, - /*broadcast_dimensions=*/nullptr); + auto tau_x_vva = StaticBinaryBroadcast( + loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder), + *builder); + a = builder->create(loc, a, tau_x_vva); // It is more precise to populate column 'k' explicitly, rather than // computing it implicitly by applying the Householder transformation. @@ -4657,12 +4813,12 @@ class ConvertQrOp : public OpRewritePattern { auto iota = builder->create( loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value predecessor_mask = builder->create( + Value predecessor_mask = builder->create( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); predecessor_mask = builder->create(loc, predecessor_mask, a_type.getElementType()); - Value mask = builder->create( + Value mask = builder->create( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); mask = builder->create(loc, mask, a_type.getElementType()); @@ -4674,14 +4830,14 @@ class ConvertQrOp : public OpRewritePattern { mask, GetI64ElementsAttr(llvm::SmallVector(num_batch_dims, 1), builder)); - Value predecessor_masked_x = builder->create( + Value predecessor_masked_x = StaticBinaryBroadcast( loc, x, predecessor_mask, - GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder)); - Value masked_beta = builder->create( - loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder)); + GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder); + Value masked_beta = StaticBinaryBroadcast( + loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder), + *builder); Value new_x = - builder->create(loc, predecessor_masked_x, masked_beta, - /*broadcast_dimensions=*/nullptr); + builder->create(loc, predecessor_masked_x, masked_beta); // Update a[:,j] llvm::SmallVector dim_ids(num_dims); std::iota(dim_ids.begin(), dim_ids.end(), 0); @@ -4692,7 +4848,7 @@ class ConvertQrOp : public OpRewritePattern { loc, RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)), builder->getI64IntegerAttr(minor_dim + 1)); - Value xa_mask = builder->create( + Value xa_mask = builder->create( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); a = builder->create(loc, a_type, xa_mask, new_x, a); @@ -4708,11 +4864,11 @@ class ConvertQrOp : public OpRewritePattern { builder)); auto vs_update = builder->create( loc, vs.getType(), xa_mask, - builder->create( - loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder)), + StaticBinaryBroadcast( + loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder), + *builder), vs_zeros); - vs = builder->create(loc, vs, vs_update, - /*broadcast_dimensions=*/nullptr); + vs = builder->create(loc, vs, vs_update); // taus[j] = tau llvm::SmallVector tau_broadcast_dims(batch_dims.size()); @@ -4729,17 +4885,16 @@ class ConvertQrOp : public OpRewritePattern { loc, 
taus.getType(), taus_zeros, GetI64ElementsAttr(taus.getType().cast().getShape(), builder)); - Value taus_mask = builder->create( + Value taus_mask = builder->create( loc, iota_n, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); auto taus_update = builder->create( loc, taus.getType(), taus_mask, - builder->create( + StaticBinaryBroadcast( loc, taus_zeros, tau, - GetI64ElementsAttr(tau_broadcast_dims, builder)), + GetI64ElementsAttr(tau_broadcast_dims, builder), *builder), taus_zeros); - taus = builder->create(loc, taus, taus_update, - /*broadcast_dimensions=*/nullptr); + taus = builder->create(loc, taus, taus_update); new_values->assign({a, vs, taus}); }; @@ -4796,8 +4951,7 @@ class ConvertQrOp : public OpRewritePattern { j = builder->create( loc, j, GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1, - builder), - /*broadcast_dimensions=*/nullptr); + builder)); // vs has shape [..., m, 1] auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder); // beta has shape [..., 1] @@ -4816,7 +4970,7 @@ class ConvertQrOp : public OpRewritePattern { loc, vs.getType(), zero, GetI64ElementsAttr(vs.getType().cast().getShape(), builder)); - auto compare = builder->create( + auto compare = builder->create( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("GE", builder->getContext())); auto y = builder->create(loc, vs.getType(), compare, zero, vs); @@ -4831,13 +4985,12 @@ class ConvertQrOp : public OpRewritePattern { // z = -beta * (v + wyv) auto neg_beta = builder->create(loc, beta); - auto v_wyv = builder->create(loc, v, wyv, - /*broadcast_dimensions=*/nullptr); + auto v_wyv = builder->create(loc, v, wyv); auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); beta_broadcast_dims.push_back(n_index); - auto z = builder->create( + auto z = StaticBinaryBroadcast( loc, neg_beta, v_wyv, - GetI64ElementsAttr(beta_broadcast_dims, builder)); + GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter); w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder); new_values->assign({w, vs, taus}); @@ -4855,8 +5008,9 @@ class ConvertQrOp : public OpRewritePattern { auto neg_beta = rewriter->create(loc, beta); auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); beta_broadcast_dims.push_back(n_index); - auto bv = rewriter->create( - loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter)); + auto bv = StaticBinaryBroadcast( + loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter), + *rewriter); w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter); SmallVector while_output; @@ -4912,7 +5066,8 @@ void EmitLegalizationErrors(Operation *op, // Performs the lowering to XLA dialect. void LegalizeTF::runOnFunction() { - if (failed(legalizeTF(getFunction(), allow_partial_conversion_))) + if (failed( + legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_))) signalPassFailure(); } @@ -4923,7 +5078,8 @@ static PassRegistration pass( #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" -LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { +LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, + bool legalize_chlo) { MLIRContext *context = op->getContext(); // Add lowering patterns to the list. 
@@ -4936,19 +5092,19 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { TF::PopulateLoweringTFPatterns(context, &patterns); patterns.insert< ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, - ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2DOp, - ConvertConv3DOp, ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp, - ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp, - ConvertConv3DBackpropInputOp, ConvertCumsumOp, ConvertDiagPartOp, - ConvertEinsumOp, ConvertFusedBatchNormGradOp, - ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, - ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, - ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp, - ConvertAvgPoolOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, - ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, - ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, - ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp, - ConvertSoftmaxOp, + ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, + ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, + ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, + ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, + ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, + ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, + ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp, + ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp, + ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, + ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, + ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op, + ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, @@ -4959,10 +5115,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. - xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + if (legalize_chlo) { + xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + } ConversionTarget target(*context); - target.addIllegalDialect(); + if (legalize_chlo) { + target.addIllegalDialect(); + } else { + target.addLegalDialect(); + } target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); @@ -4988,8 +5150,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { } std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion) { - return std::make_unique(allow_partial_conversion); + bool allow_partial_conversion, bool legalize_chlo) { + return std::make_unique(allow_partial_conversion, legalize_chlo); } } // end namespace xla_hlo diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 959902692dc..33c92ee65d5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -73,21 +73,6 @@ def : Pattern< // HLO and XLA doesn't support Assertions. 
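+// The empty result list in the pattern below erases tf.Assert entirely.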
def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>; -//===----------------------------------------------------------------------===// -// Bias op patterns. -//===----------------------------------------------------------------------===// -def BiasAddFeatureDimension : NativeCodeCall< - "getBiasFeatureDimension($_builder, $0, $1)">; - -// $input needs to be a ranked tensor to identify index of the feature -// dimension depending on the data_format 'NHWC' or 'NCHW'. -// TODO(laurenzo): This should be converted to do explicit broadcasting since -// it can generate broadcast dimensions that are not compatible with the simple -// xla_chlo.add broadcast_dims. -def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format), - (HLO_AddOp $input, $bias, - (BiasAddFeatureDimension $data_format, $input))>; - //===----------------------------------------------------------------------===// // Binary op patterns. //===----------------------------------------------------------------------===// @@ -114,7 +99,8 @@ foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], def LowerRightShiftSigned : Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_ShiftRightArithmeticOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastShiftRightArithmeticOp $l, $r, + (BinBroadcastDimensions $l, $r)), [(SignedIntTensor $r)]>; // TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op @@ -126,10 +112,11 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>; // // return floor(div(x, y)) def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r))), + (HLO_FloorOp + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), [(IEEEFloatTensor $l)]>; -// Performs a substitution of FloorDir for integer tensors, which required +// Performs a substitution of FloorDiv for integer tensors, which required // additional correction for a negative numerator / denominator. Equivalent // pseudocode is shown below: // @@ -150,16 +137,16 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), // broadcast attributes. def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), (HLO_SelectOp - (HLO_CompareOp - (HLO_CompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)), + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (HLO_CompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)), + (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ), - (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r)), - (HLO_DivOp - (HLO_NegOp:$neg (HLO_AddOp (HLO_AbsOp $l), - (HLO_SubOp (HLO_AbsOp $r), + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastDivOp + (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), + (HLOClient_BroadcastSubOp (HLO_AbsOp $r), (HLO_ConstOp (ConstantSplat<"1"> $r)), (NullDenseIntElementsAttr)), (BinBroadcastDimensions $l, $r))), @@ -175,20 +162,20 @@ def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), // broadcast attributes. 
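+// For example, with l = -7 and r = 3: rem(-7, 3) = -1; the remainder is
+// non-zero and its sign differs from r's, so the result is 3 + (-1) = 2,
+// which equals -7 - 3 * floor(-7 / 3).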
def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), (HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_RemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastAndOp + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), (HLO_ConstOp:$l_zeros (ConstantSplat<"0"> $l)), (BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE), - (HLO_CompareOp - (HLO_CompareOp:$r_cmp $r, + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastCompareOp:$r_cmp $r, (HLO_ConstOp:$r_zeros (ConstantSplat<"0"> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (HLO_CompareOp:$rem_cmp $rem, $r_zeros, + (HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE), (NullDenseIntElementsAttr)), - (HLO_AddOp $r, + (HLOClient_BroadcastAddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; //===----------------------------------------------------------------------===// @@ -406,39 +393,36 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_ (HLO_SelectOp:$num_lower_or_m (HLO_CompareOp $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + HLO_COMPARISON_DIRECTION_LT ), $m_dim, $num_lower ), (HLO_SelectOp:$num_upper_or_n (HLO_CompareOp - $num_upper, $zero, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT ), $n_dim, $num_upper ), (HLO_SelectOp (HLO_AndOp - (HLO_CompareOp + (HLOClient_BroadcastCompareOp (HLO_NegOp (createConvertOp $op, $num_lower_or_m, $input) ), (HLO_SubOp:$offset - (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input), - (NullDenseIntElementsAttr) + (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input) ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE ), - (HLO_CompareOp + (HLOClient_BroadcastCompareOp $offset, (createConvertOp $op, $num_upper_or_n, $input ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE - ), - (BinBroadcastDimensions $offset, $input) + ) ), $input, (HLO_ConstOp (ConstantSplat<"0"> $input)) @@ -462,8 +446,9 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value), // TODO(hinsu): Lower unsigned and quantized types after supporting // them in GetScalarOfType. def : Pat<(TF_ReluOp AnyRankedTensor:$input), - (HLO_MaxOp (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, - (BinBroadcastDimensions $zero, $input)), + (HLOClient_BroadcastMaxOp + (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, + (BinBroadcastDimensions $zero, $input)), [(TF_SintOrFpTensor $input)]>; // TODO(hinsu): Lower unsigned and quantized types after supporting @@ -485,7 +470,7 @@ def : Pat<(TF_Relu6Op AnyRankedTensor:$input), // to create splat tensor of dynamic shape in HLO. 
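+// In scalar terms, the pattern below computes
+//   grad_out[i] = features[i] > 0 ? gradients[i] : 0.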
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), (HLO_SelectOp - (HLO_CompareOp $features, + (HLOClient_BroadcastCompareOp $features, (HLO_ConstOp (GetScalarOfType<0> $features)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT), $gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>; @@ -598,7 +583,6 @@ def : Pat<(TF_SignOp $x), (HLO_CompareOp $x, $x, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_NE ), (HLO_ConstOp (ConstantSplat<"0"> $x)), @@ -641,8 +625,6 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2), //===----------------------------------------------------------------------===// def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLO_MulOp - (HLO_MulOp $r, $l, (NullDenseIntElementsAttr)), - (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l, - (NullDenseIntElementsAttr)), - (NullDenseIntElementsAttr)), + (HLO_MulOp $r, $l), + (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l)), [(IEEEFloatTensor $l)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td index c0f6c2c3541..21e39db018b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td @@ -36,47 +36,36 @@ def IsSameSizePred : CPred< def IsSameSizeConstraint : Constraint; -def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r), (AndOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), (AddFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r), (SubFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), (MulFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), (DivFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), (RemFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), (AddIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SubIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), (MulIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SignedDivIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_IntTensor:$l, 
-def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SignedRemIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td index dcb0ab20e9e..e1ae5ef6abf 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td @@ -28,70 +28,62 @@ include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" // and imaginary components. foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs, $broadcast_dimensions), + HLO_ComplexTensor:$rhs), (HLO_ComplexOp - (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), - $broadcast_dimensions), - (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), - $broadcast_dimensions))>; + (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)), + (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>; // Complex multiplication results in a cross product multiplication between the // real and imaginary components such that: // result.real = lhs.real * rhs.real - lhs.imag * rhs.imag // result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs, $broadcast_dimensions), + HLO_ComplexTensor:$rhs), (HLO_ComplexOp (HLO_SubOp (HLO_MulOp (HLO_RealOp:$lhs_real $lhs), - (HLO_RealOp:$rhs_real $rhs), - $broadcast_dimensions), + (HLO_RealOp:$rhs_real $rhs)), (HLO_MulOp (HLO_ImagOp:$lhs_imag $lhs), - (HLO_ImagOp:$rhs_imag $rhs), - $broadcast_dimensions), - (NullDenseIntElementsAttr)), + (HLO_ImagOp:$rhs_imag $rhs))), (HLO_AddOp - (HLO_MulOp $lhs_real, $rhs_imag, $broadcast_dimensions), - (HLO_MulOp $lhs_imag, $rhs_real, $broadcast_dimensions), - (NullDenseIntElementsAttr)))>; + (HLO_MulOp $lhs_real, $rhs_imag), + (HLO_MulOp $lhs_imag, $rhs_real)))>; // Multiplication between a complex and real tensor can be distributed by // applying the real multiplicand to both the real and imaginary components. // // Note that the source pattern is not legal according to the HLO dialect but // instead handles intermediates generated by other patterns. -def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_MulOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), - (HLO_MulOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + (HLO_MulOp (HLO_RealOp $lhs), $rhs), + (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>; -def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs), (HLO_ComplexOp - (HLO_MulOp $lhs, (HLO_RealOp $rhs), $broadcast_dimensions), - (HLO_MulOp $lhs, (HLO_ImagOp $rhs), $broadcast_dimensions))>; + (HLO_MulOp $lhs, (HLO_RealOp $rhs)), + (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>;
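A quick numeric cross-check of the multiplication expansion above, verified against std::complex (standalone sketch, not part of this patch):

#include <cassert>
#include <complex>

// Checks the expansion the pattern emits:
//   result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
//   result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
int main() {
  std::complex<double> lhs(3.0, -2.0), rhs(1.5, 4.0);
  double re = lhs.real() * rhs.real() - lhs.imag() * rhs.imag();
  double im = lhs.imag() * rhs.real() + lhs.real() * rhs.imag();
  // (3 - 2i) * (1.5 + 4i) == 12.5 + 9i, matching the expansion.
  assert(std::abs(std::complex<double>(re, im) - lhs * rhs) < 1e-12);
  return 0;
}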
// Division is performed by normalizing the denominator by multiplying by the // conjugate of the rhs. // numerator = lhs * conj(rhs) // denominator = rhs * conj(rhs) -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs), (HLO_DivOp (HLO_MulOp:$num $lhs, (HLO_ComplexOp:$conj (HLO_RealOp $rhs), - (HLO_NegOp (HLO_ImagOp $rhs))), - $broadcast_dimensions), - (HLO_RealOp:$den (HLO_MulOp $rhs, $conj, $broadcast_dimensions)), - (BinBroadcastDimensions $num, $den))>; + (HLO_NegOp (HLO_ImagOp $rhs)))), + (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>; -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_DivOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), - (HLO_DivOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + (HLO_DivOp (HLO_RealOp $lhs), $rhs), + (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>; // Absolute value is evaluated as: @@ -100,11 +92,8 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), (HLO_ComplexOp (HLO_SqrtOp (HLO_AddOp - (HLO_MulOp (HLO_RealOp:$real $val), $real, - (NullDenseIntElementsAttr)), - (HLO_MulOp (HLO_ImagOp:$imag $val), $imag, - (NullDenseIntElementsAttr)), - (NullDenseIntElementsAttr))), + (HLO_MulOp (HLO_RealOp:$real $val), $real), + (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))), (HLO_ConstOp (ConstantSplat<"0"> $real)))>; // Exponential can be lowered to an exponential on the real component and a @@ -117,5 +106,4 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), (HLO_ExpOp (HLO_RealOp $val)), (HLO_ComplexOp (HLO_CosOp (HLO_ImagOp:$imag $val)), - (HLO_SinOp $imag)), - (NullDenseIntElementsAttr))>; + (HLO_SinOp $imag)))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index 7b4262825f8..c56f5adc12d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -28,259 +28,6 @@ namespace xla_hlo { namespace { -// Returns a 1-d i64 elements attribute populated with numbers from start to -// end, exclusive. -static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, - Builder *builder) { - int size = end - start; - - SmallVector<int64_t, 4> vals; - vals.resize(size); - std::iota(vals.begin(), vals.end(), start); - - TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, vals); -} - -// Helper function for OpRewritePattern classes to materialize broadcasts on -// LHS and RHS arguments to a binary op. -// -// Returns true and sets out_lhs and out_rhs to BroadcastInDimOps if successful; -// returns false otherwise. -template <typename SrcOp> -bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, - // replacing the original LHS and RHS args in the source op with the results - // of the broadcasts. - // - // If the higher dimensional argument does not actually need the broadcast, - // a canonicalization pass should be able to remove that op later.
- Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto op_ranked_type = op.getType().template dyn_cast<RankedTensorType>(); - auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>(); - auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>(); - if (!op_ranked_type || !lhs_ranked_type || !rhs_ranked_type) { - // Unranked, can't determine at this point how to perform the broadcast. - return false; - } - - // Dynamic result shape, can't use BroadcastInDimOp. - assert(op_ranked_type.hasStaticShape() && - "dynamic shape requires DynamicBroadcastInDim"); - - auto lhs_rank = lhs_ranked_type.getRank(); - auto rhs_rank = rhs_ranked_type.getRank(); - ArrayRef<int64_t> op_shape = op_ranked_type.getShape(); - - // BroadcastInDimOp must have the same element type for operands and results, - // so preserve the original output shape and the original input element type. - // For example, `SrcOp (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`: - // broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32> - // broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32> - // SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - if (lhs_ranked_type.getShape() != op_ranked_type.getShape()) { - auto type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, lhs_rank, rewriter); - if (lhs_rank < rhs_rank) { - attr = op.broadcast_dimensions().getValue(); - } - - lhs = - rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, lhs, attr); - } - - if (rhs_ranked_type.getShape() != op_ranked_type.getShape()) { - auto type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); - DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, rhs_rank, rewriter); - if (rhs_rank < lhs_rank) { - attr = op.broadcast_dimensions().getValue(); - } - - rhs = - rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, rhs, attr); - } - - *out_lhs = lhs; - *out_rhs = rhs; - return true; -} - -// Helper template to generate code for computing the result shape of a -// broadcasted operation. This ultimately should be subsumed by functions -// from the shape dialect. -// Assumes that large and small are the operand values of `op` and that they -// have a ranked tensor type with rank(large) >= rank(small). -template <typename SrcOp> -std::vector<Value> ComputeBroadcastedShape(SrcOp op, Value small, Value large, - PatternRewriter *rewriter) { - auto loc = op.getLoc(); - auto larger_ranked_type = large.getType().cast<RankedTensorType>(); - auto output_rank = larger_ranked_type.getRank(); - - constexpr int kExpandShape = -1; - - std::vector<Value> shape_values; - shape_values.reserve(output_rank); - std::vector<int> indexes(output_rank, kExpandShape); - DenseIntElementsAttr broadcast_dimensions = - op.broadcast_dimensions().getValue(); - // Compute a mapping from output dimensions to their corresponding input - // dimensions in the smaller ranked operand. - for (auto pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { - indexes.at(pair.value().getLimitedValue()) = pair.index(); - } - - // Compute the broadcasted shape of the result using numpy style broadcasting - // semantics. The result shape at a position is the shape of the larger - // operand at that position if no dimension of the smaller operand is - // mapped to it. - // If both operands contribute to an output dimension, their shape has to - // either be the same in that dimension or it can be 1, in which case the - // shape of the other operand is used. - for (int i = 0; i < output_rank; ++i) { - if (indexes[i] == kExpandShape) { - // The smaller shape gets expanded to the larger one in this case.
- shape_values.push_back(rewriter->create<DimOp>(loc, large, i)); - continue; - } - // Compute the result shape depending on whether the rank of smaller is 1. - // This does not check that the broadcast operation actually is correct. - // In particular, we do not check that both shapes are the same if the - // smaller ranked shape is not 1. - ConstantOp one = rewriter->create<ConstantOp>( - loc, rewriter->getIntegerAttr(rewriter->getIndexType(), 1)); - DimOp lrg_dim = rewriter->create<DimOp>(loc, large, i); - DimOp sml_dim = rewriter->create<DimOp>(loc, small, indexes[i]); - CmpIOp compare = - rewriter->create<CmpIOp>(loc, CmpIPredicate::eq, lrg_dim, one); - shape_values.push_back( - rewriter->create<SelectOp>(loc, compare, lrg_dim, sml_dim)); - } - - return shape_values; -}
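For intuition, here is a static-shape analogue of the broadcasting rule described in the comments above, as a standalone C++ sketch (hypothetical BroadcastShape helper, not part of this patch); the deleted helper emitted dim/cmpi/select ops to perform the equivalent computation on dynamic shapes at runtime:

#include <cassert>
#include <cstdint>
#include <vector>

// Result dim i takes the larger operand's extent unless that extent is 1
// and a smaller-operand dimension maps to i, in which case the smaller
// operand's extent is used. Output dims with no mapping keep the larger
// operand's extent. Like the IR version, this assumes the shapes are
// broadcast-compatible and does not verify them.
std::vector<int64_t> BroadcastShape(const std::vector<int64_t> &large,
                                    const std::vector<int64_t> &small,
                                    const std::vector<int> &broadcast_dims) {
  std::vector<int64_t> out(large);
  for (size_t i = 0; i < broadcast_dims.size(); ++i) {
    if (out[broadcast_dims[i]] == 1) out[broadcast_dims[i]] = small[i];
  }
  return out;
}

int main() {
  // large = [2, 4, 1], small = [4, 5], broadcast_dimensions = [1, 2]:
  // dim 0 is unmapped (2), dim 1 keeps 4, dim 2 expands 1 -> 5.
  assert((BroadcastShape({2, 4, 1}, {4, 5}, {1, 2}) ==
          std::vector<int64_t>{2, 4, 5}));
  return 0;
}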
- -// Helper function for OpRewritePattern classes to materialize dynamic -// broadcasts on LHS and RHS arguments to a binary op. -// -// Returns true and sets out_lhs and out_rhs for materialized dynamic broadcasts -// for LHS and RHS arguments, else returns false. -template <typename SrcOp> -bool CreateDynamicBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - if (!op.broadcast_dimensions().hasValue()) { - // Note: the op may still have an implicit broadcast on it, such as - // for (tensor<1xf32>, tensor<4xf32>). - return false; - } - - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, - // replacing the original LHS and RHS args in the source op with the results - // of the broadcasts. - Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>(); - auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>(); - if (!lhs_ranked_type || !rhs_ranked_type) { - // Unranked, can't determine at this point how to perform the broadcast. - return false; - } - - auto lhs_rank = lhs_ranked_type.getRank(); - auto rhs_rank = rhs_ranked_type.getRank(); - - // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg. - // Use the original op.broadcast_dimensions for the lower rank arg. - auto higher_rank_broadcast_dims = - GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter); - DenseIntElementsAttr lhs_broadcast_dims; - DenseIntElementsAttr rhs_broadcast_dims; - std::vector<Value> shape_elements; - if (lhs_rank > rhs_rank) { - lhs_broadcast_dims = higher_rank_broadcast_dims; - rhs_broadcast_dims = op.broadcast_dimensions().getValue(); - shape_elements = ComputeBroadcastedShape(op, rhs, lhs, rewriter); - } else if (lhs_rank < rhs_rank) { - lhs_broadcast_dims = op.broadcast_dimensions().getValue(); - rhs_broadcast_dims = higher_rank_broadcast_dims; - shape_elements = ComputeBroadcastedShape(op, lhs, rhs, rewriter); - } else { - // This shouldn't happen for legal ops. If the broadcast_dimensions - // attribute is set, the ranks should be different. - // TODO(scotttodd): Add a custom verification for ops and assert here. - return false; - } - - // DynamicBroadcastInDimOp preserves the element type but produces a tensor - // with dynamic shape. The rank of the output is the length of the - // output shape argument. - SmallVector<int64_t, 4> op_shape(shape_elements.size(), - RankedTensorType::kDynamicSize); - auto lhs_type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - auto rhs_type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); - - // We need a way to turn a list of scalars into a vector. While Standard - // dialect does not have one, use the XLA_HLO variant. - int shape_size = shape_elements.size(); - Type shape_element_type = shape_elements.front().getType(); - Value shape_value = rewriter->create<ScalarsToDimensionTensorOp>( - op.getLoc(), RankedTensorType::get({shape_size}, shape_element_type), - shape_elements); - - *out_lhs = rewriter->createOrFold<DynamicBroadcastInDimOp>( - op.getLoc(), lhs_type, lhs, shape_value, lhs_broadcast_dims); - *out_rhs = rewriter->createOrFold<DynamicBroadcastInDimOp>( - op.getLoc(), rhs_type, rhs, shape_value, rhs_broadcast_dims); - return true; -} - -template <typename SrcOp> -bool CreateBroadcastForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - auto op_ranked_type = op.getType().template dyn_cast<RankedTensorType>(); - if (!op_ranked_type) return false; - - if (op_ranked_type.hasStaticShape()) { - if (!CreateStaticBroadcastsForBinaryOp(op, rewriter, out_lhs, out_rhs)) { - return false; - } - } else { - if (!CreateDynamicBroadcastsForBinaryOp(op, rewriter, out_lhs, out_rhs)) { - return false; - } - } - return true; -} - -template <typename SrcOp> -struct BinaryOpWithBroadcastConvert : public OpRewritePattern<SrcOp> { - explicit BinaryOpWithBroadcastConvert(MLIRContext *context) - : OpRewritePattern<SrcOp>(context) {} - - LogicalResult matchAndRewrite(SrcOp op, - PatternRewriter &rewriter) const override { - Value new_lhs; - Value new_rhs; - - if (!CreateBroadcastForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) - return failure(); - - // Replace the original op with a new one that uses the new args. - // New args are broadcasts, so no dims are needed on the replacement op. - rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), new_lhs, new_rhs, - /*broadcast_dims=*/nullptr); - return success(); - } -}; - // Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays // must be the same shape. Alternatively, as a restricted form of broadcasting, // min and/or max can be a scalar of type T." @@ -322,63 +69,10 @@ struct ClampWithBroadcastConvert : public OpRewritePattern<ClampOp> { } }; -// Specialized class for CompareOp, as it has an additional builder argument. -struct CompareWithBroadcastConvert : public OpRewritePattern<CompareOp> { - explicit CompareWithBroadcastConvert(MLIRContext *context) - : OpRewritePattern<CompareOp>(context) {} - - LogicalResult matchAndRewrite(CompareOp op, - PatternRewriter &rewriter) const override { - Value new_lhs; - Value new_rhs; - - if (!CreateBroadcastForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) - return failure(); - - rewriter.replaceOpWithNewOp<CompareOp>(op, op.getType(), new_lhs, new_rhs, - /*broadcast_dims=*/nullptr, - op.comparison_direction()); - return success(); - } -}; - } // namespace void SetupMaterializeBroadcastsLegality(MLIRContext *context, ConversionTarget *conversionTarget) { -#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \ conversionTarget->addDynamicallyLegalOp<OpType>([](OpType op) { \ if (op.broadcast_dimensions().hasValue()) return false; \ auto l = op.lhs().getType().cast<ShapedType>(); \ auto r = op.rhs().getType().cast<ShapedType>(); \ if (!l.hasRank() || !r.hasRank()) return false; \ return l.getShape() == r.getShape(); \ }); - - // Binary elementwise ops.
- ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(DivOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MaxOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MinOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MulOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(PowOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(RemOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftLeftOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightArithmeticOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightLogicalOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(SubOp); - - // Binary logical elementwise ops. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AndOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OrOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(XorOp); - - // CompareOp. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(CompareOp); - -#undef ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST - conversionTarget->addDynamicallyLegalOp<ClampOp>([](ClampOp op) { return op.max().getType() == op.operand().getType() && op.min().getType() == op.operand().getType(); @@ -387,30 +81,10 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context, void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - // Binary elementwise ops. - patterns->insert<BinaryOpWithBroadcastConvert<AddOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<Atan2Op>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<DivOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<MaxOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<MinOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<MulOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<PowOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<RemOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<ShiftLeftOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightArithmeticOp>>( - context); - patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightLogicalOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<SubOp>>(context); - - // Binary logical elementwise ops. - patterns->insert<BinaryOpWithBroadcastConvert<AndOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<OrOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<XorOp>>(context); - - // ClampOp. It can have a restricted form of broadcasting. + // ClampOp. This op has a special case where it accepts either same-shaped + // inputs or scalars (a restricted form of broadcasting). This makes the + // broadcast explicit. patterns->insert<ClampWithBroadcastConvert>(context); - // CompareOp. Note the specialized class instead of using the template. - patterns->insert<CompareWithBroadcastConvert>(context); } } // namespace xla_hlo diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index b148eac4286..a1dd6c5ce1e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -36,7 +36,7 @@ namespace xla_hlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass( - bool allow_partial_conversion = false); + bool allow_partial_conversion = false, bool legalize_chlo = true); /// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the /// specified device type. @@ -50,7 +50,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass(); /// dialect using the conversion patterns registered by the HLO dialect. When /// allow_partial_conversion is false, emits an error if there is any operation /// that can't be legalized. -LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false); +LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false, + bool legalize_chlo = true); /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
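Usage sketch for the legalize_chlo flag introduced above (hypothetical driver function, not part of this patch): with legalize_chlo=true (the default) the intermediate xla_chlo.broadcast_* ops are expanded down to xla_hlo; with false they are kept so a downstream consumer (e.g. IREE) can lower them itself.

#include "mlir/IR/Function.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

// Hypothetical driver: legalize one function from TF to HLO, keeping the
// xla_chlo client ops instead of expanding their broadcasts.
mlir::LogicalResult RunTFLegalization(mlir::FuncOp func) {
  return mlir::xla_hlo::legalizeTF(func.getOperation(),
                                   /*allow_partial_conversion=*/false,
                                   /*legalize_chlo=*/false);
}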
diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc index d53aaee3701..98eb404e4d4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc @@ -135,8 +135,8 @@ class UnfuseBatchNormInferencePattern if (!epsilon) { return failure(); } - Value stddev = rewriter.create<AddOp>( - bn_op.getLoc(), bn_op.variance(), epsilon, /*broadcast_dims=*/nullptr); + Value stddev = rewriter.create<AddOp>(bn_op.getLoc(), + bn_op.variance(), epsilon); stddev = rewriter.create<SqrtOp>(bn_op.getLoc(), stddev); // Broadcast all terms. @@ -160,13 +160,13 @@ class UnfuseBatchNormInferencePattern // Compute: // scale * (input - mean) / stddev + offset Value result = rewriter.create<SubOp>( - bn_op.getLoc(), bn_op.operand(), broadcast_mean, nullptr); + bn_op.getLoc(), bn_op.operand(), broadcast_mean); result = rewriter.create<MulOp>(bn_op.getLoc(), result, - broadcast_scale, nullptr); + broadcast_scale); result = rewriter.create<DivOp>(bn_op.getLoc(), result, - broadcast_stddev, nullptr); - rewriter.replaceOpWithNewOp<AddOp>(bn_op, result, broadcast_offset, - nullptr); + broadcast_stddev); + rewriter.replaceOpWithNewOp<AddOp>(bn_op, result, + broadcast_offset); return success(); }
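The rewrite above unfuses batch-norm inference into primitive element-wise HLO ops. A scalar reference of the arithmetic it emits (standalone sketch, not part of this patch):

#include <cassert>
#include <cmath>

// stddev = sqrt(variance + epsilon)
// result = scale * (input - mean) / stddev + offset
double UnfusedBatchNorm(double input, double mean, double variance,
                        double scale, double offset, double epsilon) {
  double stddev = std::sqrt(variance + epsilon);
  return scale * (input - mean) / stddev + offset;
}

int main() {
  // variance + epsilon == 4.0, so stddev == 2.0 and
  // 0.5 * (3.0 - 1.0) / 2.0 + 0.25 == 0.75.
  assert(std::abs(UnfusedBatchNorm(3.0, 1.0, 3.9999, 0.5, 0.25, 0.0001) -
                  0.75) < 1e-9);
  return 0;
}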