diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 1b91c0dbe61..93e85d9ced0 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -720,14 +720,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", let arguments = ( ins TFL_VariadicTensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$values, + [F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$values, I32Attr:$axis, TFL_AFAttr:$fused_activation_function ); let results = (outs TFL_TensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$output + [F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$output ); let hasOptions = 1; diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 3a2a0a8b9d2..a576e559c00 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1068,24 +1068,33 @@ func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) - // CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> } -func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { +func @concat_v2_with_3_tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> return %1 : tensor<2x3xi32> -// CHECK-LABEL: concatv2With3Tensors +// CHECK-LABEL: concat_v2_with_3_tensors // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> } -func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { +func @concat_v2_i64_axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> return %1 : tensor<2x3xi32> -// CHECK-LABEL: concatv2I64Axis +// CHECK-LABEL: concat_v2_i64_axis // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> } +func @concat_v2_with_bool_type(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: concat_v2_with_bool_type +// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor, tensor) -> tensor +} + func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { %0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index 3f759228e04..01f7f9fcc48 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -58,7 +58,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || - input_type == kTfLiteInt32 || input_type == kTfLiteInt64); + input_type == kTfLiteInt32 || input_type == kTfLiteInt64 || + input_type == kTfLiteBool); // Output dimensions will match input dimensions, except 'axis', which // will be the sum of inputs @@ -172,6 +173,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt16: TF_LITE_CONCATENATION(int16_t); break; + case kTfLiteBool: + TF_LITE_CONCATENATION(bool); + break; default: context->ReportError(context, "Type '%s' is not supported currently.", TfLiteTypeGetName(output->type)); diff --git a/tensorflow/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc index 4e362598aae..5a36895d847 100644 --- a/tensorflow/lite/kernels/concatenation_test.cc +++ b/tensorflow/lite/kernels/concatenation_test.cc @@ -90,6 +90,15 @@ class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { } }; +class BoolConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + TEST(ConcatenationOpTest, ThreeDimensionalOneInput) { ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1, /*num_inputs=*/1); @@ -447,5 +456,27 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) { ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); } +TEST(ConcatenationOpTest, BoolTypeOneInput) { + BoolConcatenationOpModel m0({TensorType_BOOL, {2, 1, 2}}, /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {true, false, false, true}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({true, false, false, true})); +} + +TEST(ConcatenationOpTest, BoolTypeTwoInputs) { + BoolConcatenationOpModel m0( + {{TensorType_BOOL, {2, 1, 2}}, {TensorType_BOOL, {2, 3, 2}}}, + /*axis=*/1, /*num_inputs=*/2, TensorType_BOOL); + m0.SetInput(0, {false, false, false, false}); + m0.SetInput(1, {true, true, true, true, true, true, true, true, true, true, + true, true}); + m0.Invoke(); + EXPECT_THAT( + m0.GetOutput(), + ElementsAreArray({false, false, true, true, true, true, true, true, false, + false, true, true, true, true, true, true})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/testing/op_tests/concat.py b/tensorflow/lite/testing/op_tests/concat.py index 3341f3c5d22..30d7b91fd77 100644 --- a/tensorflow/lite/testing/op_tests/concat.py +++ b/tensorflow/lite/testing/op_tests/concat.py @@ -59,6 +59,14 @@ def make_concat_tests(options): "fully_quantize": [False], "quant_16x8": [False], "dynamic_range_quantize": [True], + }, { + "base_shape": [[1, 3, 4, 3]], + "num_tensors": [6], + "axis": [1], + "type": [tf.bool], + "fully_quantize": [False], + "quant_16x8": [False], + "dynamic_range_quantize": [True], }] def get_shape(parameters, delta):