# Patch summary (leading text before the first `diff --git` header is not part of any hunk).
#
# Purpose: change the TF -> MHLO legalization of tf.MatrixBandPart so that the band
# mask is computed in the element type of `num_lower`/`num_upper` instead of the
# element type of `input`, and enable the MLIR bridge for the matching Python test.
#
# Files touched:
#   tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
#     - @matrix_band_part / @matrix_band_part_2: every CHECK becomes CHECK-DAG.
#     - The expected iota / subtract / broadcast_compare operand types change from
#       ...xbf16 to ...xi64, and the two "mhlo.convert" expectations are removed.
#     - @matrix_band_part_2 now expects an "mhlo.broadcast_in_dim" of the 24x48 mask
#       to 12x24x48 before the final "mhlo.select".
#   tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
#     - Get2DTensorType gains a `num_lower` parameter and reads the element type
#       from it rather than from `input`.
#   tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
#     - GetDimensionSizeFromEnd builds its attribute with
#       getElementTypeOrSelf($1.getType()) instead of getI64IntegerAttr.
#     - createIotaOp forwards a third argument ($2) to Get2DTensorType.
#     - The MatrixBandPart pattern matches AnyStaticShapeTensor instead of
#       AnyRankedTensor, binds m_dim to dimension-from-end "1" and n_dim to "0"
#       (previously "0" and "1"), drops both createConvertOp wrappers, and emits
#       TF_SelectV2Op instead of HLO_SelectOp as the root result.
#   tensorflow/compiler/tests/BUILD
#     - Adds `enable_mlir_bridge = True` to the tf_xla_py_test whose srcs is
#       matrix_band_part_test.py.
#
# NOTE(review): the unchanged context comment above Get2DTensorType still says the
#   result's element type "will match the element type of `input`"; after this
#   patch it follows `num_lower`. The comment in legalize_tf.cc should be updated
#   in the same change.
# NOTE(review): CHECK-DAG does not enforce relative order between consecutive
#   CHECK-DAG directives, so converting every directive (including
#   `return %[[R]]`) removes all ordering constraints inside each function.
#   Consider keeping a plain CHECK on the final select/return -- confirm intent.
# NOTE(review): the @matrix_band_part_2 checks use %[[F]] and %[[D]] but no
#   definition of either is visible among that function's directives here;
#   presumably they rely on captures from @matrix_band_part -- verify.
# NOTE(review): the pattern no longer references createConvertOp; its definition
#   is only partly visible here, so whether it is now unused needs checking.
# NOTE(review): this copy of the patch looks like it lost its line breaks and
#   some angle-bracket arguments (e.g. bare `tensor`, `cast()`, `create(`);
#   presumably an extraction artifact -- it will not apply as-is.
#
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index d3594c30431..f56d2b2d473 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1224,32 +1224,30 @@ func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf // CHECK-LABEL: matrix_band_part // CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor, %[[UPPER:.*]]: tensor) func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { - // CHECK: %[[M:.*]] = mhlo.constant dense<64> : tensor - // CHECK: %[[N:.*]] = mhlo.constant dense<64> : tensor + // CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor + // CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor - // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor, tensor, tensor) -> tensor + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK-DAG: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor, tensor, tensor) -> tensor - // CHECK: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor, tensor, tensor) -> tensor + // CHECK-DAG: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor, tensor, tensor) -> tensor + // CHECK-DAG: %[[F:.*]] = "mhlo.negate"(%[[B]]) : (tensor) -> tensor - // CHECK: %[[E:.*]] = "mhlo.convert"(%[[B]]) 
: (tensor) -> tensor - // CHECK: %[[F:.*]] = "mhlo.negate"(%[[E]]) : (tensor) -> tensor + // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64> + // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64> + // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64> + // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xi64>) -> tensor<64x64xi1> - // CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> - // CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> - // CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> - // CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<64x64xi1> + // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<64x64xi64>, tensor) -> tensor<64x64xi1> - // CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<64x64xi1> + // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1> - // CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1> + // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16> - // CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16> - // CHECK: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) - // CHECK: return %[[R]] + // CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) + // CHECK-DAG: return %[[R]] %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> return %0 : tensor<64x64xbf16> } @@ -1257,19 +1255,20 @@ func @matrix_band_part(%arg0: 
tensor<64x64xbf16>, %arg1: tensor, %arg2: ten // CHECK-LABEL: matrix_band_part_2 // CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor, %[[UPPER:.*]]: tensor) func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<12x24x48xbf16> { - // CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xbf16> - // CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> - // CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> + // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64> + // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64> + // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64> - // CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<24x48xi1> + // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xi64>) -> tensor<24x48xi1> - // CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<24x48xi1> - // CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1> + // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<24x48xi64>, tensor) -> tensor<24x48xi1> + // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1> - // CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> - // CHECK: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) - // CHECK: return %[[R]] + // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> + + // CHECK-DAG: %[[K:.*]] = "mhlo.broadcast_in_dim"(%[[J]]) {broadcast_dimensions = dense<[1, 
2]> : tensor<2xi64>} : (tensor<24x48xi1>) -> tensor<12x24x48xi1> + // CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[K]], %[[INPUT]], %[[ZERO2]]) + // CHECK-DAG: return %[[R]] %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor, tensor) -> tensor<12x24x48xbf16> return %0 : tensor<12x24x48xbf16> } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index ff22e74f1c4..1703dadeb4c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -786,11 +786,11 @@ static int GetDimensionSizeFromEnd(Value input, int dim_from_end) { // dimension and last dimension, respectively). The element type of the // outputted RankedTensorType will match the element type of `input`. // Requires that `input` is a tensor. -static RankedTensorType Get2DTensorType(Value input) { +static RankedTensorType Get2DTensorType(Value input, Value num_lower) { // `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last. 
int dim_0 = GetDimensionSizeFromEnd(input, 1); int dim_1 = GetDimensionSizeFromEnd(input, 0); - auto element_type = input.getType().cast().getElementType(); + auto element_type = num_lower.getType().cast().getElementType(); return RankedTensorType::get({dim_0, dim_1}, element_type); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 5baef3b4afd..d8baef14e62 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -365,7 +365,8 @@ class getIntegerAttr: NativeCodeCall< "$_builder.getI64IntegerAttr(" # x # ")">; class GetDimensionSizeFromEnd: NativeCodeCall< - "$_builder.getI64IntegerAttr(GetDimensionSizeFromEnd($0, " # dimFromEnd # "))" + "$_builder.getIntegerAttr(getElementTypeOrSelf($1.getType()), " + " GetDimensionSizeFromEnd($0, " # dimFromEnd # "))" >; // TODO(b/149615308): Enable IotaOp usage as a child operation in a pattern @@ -373,7 +374,7 @@ class GetDimensionSizeFromEnd: NativeCodeCall< // cannot be inferred. class createIotaOp: NativeCodeCall< "$_builder.create($0.getOwner()->getLoc(), " - "Get2DTensorType($1), $_builder.getI64IntegerAttr(" # dim # "))">; + "Get2DTensorType($1, $2), $_builder.getI64IntegerAttr(" # dim # "))">; // This op needs to be created in C++ because the generated Convert Op has no // way to specify shape information as an input. In the MatrixBandPart op @@ -396,9 +397,10 @@ def createConvertOp: NativeCodeCall< // return (indicator ? input : zero_matrix) // // TODO(b/149961547): Support dynamic shaped `input` in MatrixBandPartOp. 
-def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_upper), - [(HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"0"> $input)), - (HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"1"> $input)), +def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower, + $num_upper), + [(HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"1"> $input, $num_lower)), + (HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"0"> $input, $num_upper)), (HLO_SelectOp:$num_lower_or_m (HLO_CompareOp $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)), @@ -414,22 +416,17 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_ $n_dim, $num_upper ), - (HLO_SelectOp + (TF_SelectV2Op (HLO_AndOp (HLOClient_BroadcastCompareOp - (HLO_NegOp - (createConvertOp $op, $num_lower_or_m, $input) - ), + (HLO_NegOp $num_lower_or_m), (HLO_SubOp:$offset - (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input) + (createIotaOp<"1"> $op, $input, $num_lower), + (createIotaOp<"0"> $op, $input, $num_lower) ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE ), - (HLOClient_BroadcastCompareOp - $offset, - (createConvertOp - $op, $num_upper_or_n, $input - ), + (HLOClient_BroadcastCompareOp $offset, $num_upper_or_n, (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE ) ), diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index a0a8457ee99..058891721db 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -848,6 +848,7 @@ tf_xla_py_test( size = "medium", timeout = "long", srcs = ["matrix_band_part_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip