diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir index 65285021fd4..107a668c0a7 100644 --- a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: xla-opt -test-xla-chlo-legalize-to-hlo -split-input-file -verify-diagnostics %s -o - | FileCheck %s // Check the non-broadcast case for each registered op, then just check a // representative op for detailed broadcast semantics. @@ -14,18 +14,14 @@ func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-SAME: %[[ARG1:.+]]: tensor func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) - // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]] - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]] + // CHECK: return %[[RESULT]] : tensor %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor return %0 : tensor } @@ -35,18 +31,14 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor // CHECK-SAME: %[[ARG1:.+]]: tensor func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> tensor> { - // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor> + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> + // CHECK: return %[[RESULT]] : tensor> %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> return %0 : tensor> } @@ -58,16 +50,12 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor - // CHECK: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK: return %[[FINAL_RESULT]] : tensor + // CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK: return %[[RESULT]] : tensor %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 3d270a52f48..2153258993a 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -1,7 +1,7 @@ // Note that binary elementwise tests are run with chlo legalization enabled // (unlike the rest), since this is the primary use case for such ops and // verification of shapes and broadcasts is desired. -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -canonicalize %s | FileCheck %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" %s | FileCheck %s //===----------------------------------------------------------------------===// // Binary op legalizations. @@ -24,8 +24,13 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { // patterns unambiguous and more interesting (once broadcastable trait is // fixed upstream). func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1 + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } @@ -34,26 +39,26 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x // TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream // broadcastable bug is fixed (helps make the CHECK matching unambiguous) func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} - // CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1 + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0: tensor<4x4x4x4xi32> } // CHECK-LABEL: func @add_dynamic func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[CSTR_LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[CSTR_RHS_SHAPE:.+]] = shape.shape_of %arg1 - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[CSTR_LHS_SHAPE]], %[[CSTR_RHS_SHAPE]] - // CHECK-NEXT: shape.assuming %[[WITNESS:.+]] - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 - // CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-NEXT: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %4, %5 : tensor %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor } @@ -75,21 +80,21 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-LABEL: func @div_unranked func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { - // CHECK-NEXT: tf.Div + // CHECK: tf.Div %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor return %0: tensor } // CHECK-LABEL: func @maximum func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // CHECK-LABEL: func @minimum func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -195,25 +200,26 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-LABEL: func @equal_dynamic func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[LHS_SHAPE]], %[[RHS_SHAPE]] - // CHECK-NEXT: shape.assuming %[[WITNESS]] -> (tensor) { - // CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0 - // CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE]]) - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0: tensor } // CHECK-LABEL: func @equal_broadcast func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "EQ"} + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0: tensor<1x2xi1> } @@ -275,25 +281,26 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-LABEL: func @broadcast_greater func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "GT"} + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0: tensor<1x2xi1> } // CHECK-LABEL: func @greater_dynamic func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[LHS_SHAPE]], %[[RHS_SHAPE]] - // CHECK-NEXT: shape.assuming %[[WITNESS]] - // CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE1:.+]] = shape.shape_of %arg1 - // CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE1]]) - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor } diff --git a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc index 97afa9617c4..e5a79616d5b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc @@ -112,19 +112,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp // Compute result shape. auto loc = op.getLoc(); - - // Insert a constraint on the shapes being broadcastable and insert all - // future code into an assuming block reliant on the constraint. - Value lhs_shape = rewriter.create(loc, lhs); - Value rhs_shape = rewriter.create(loc, rhs); - auto broadcastable_cstr = - rewriter.create(loc, lhs_shape, rhs_shape); - auto assuming_op = rewriter.create( - loc, ArrayRef{result_type}, broadcastable_cstr.result()); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.createBlock(&assuming_op.doRegion()); - int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); Value result_extents = xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, @@ -153,10 +140,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp rewriter.getI64TensorAttr(rhs_broadcast_dimensions)); // And generate the final non-broadcasted binary op. - Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs, - broadcasted_rhs, rewriter); - rewriter.create(loc, final_result); - rewriter.replaceOp(op, {assuming_op.getResult(0)}); + rewriter.replaceOp(op, {Adaptor::CreateOp(op, result_type, broadcasted_lhs, + broadcasted_rhs, rewriter)}); return success(); } };