Add bool type support in TFLite concatenation op kernel

PiperOrigin-RevId: 332814276
Change-Id: I70f1c2f57abc000f08e290a5106688bd419bc753
Authored by Jaesung Chung on 2020-09-21 03:11:55 -07:00; committed by TensorFlower Gardener
parent 830b72bc81
commit d146fd6acc
5 changed files with 59 additions and 7 deletions

View File

@@ -720,14 +720,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
let arguments = (
ins TFL_VariadicTensorOf<
-[F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$values,
+[F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$values,
I32Attr:$axis,
TFL_AFAttr:$fused_activation_function
);
let results = (outs
TFL_TensorOf<
-[F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$output
+[F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$output
);
let hasOptions = 1;
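Note: I1 is the 1-bit integer element type, which is how boolean tensors are modeled in the TFLite MLIR dialect; it corresponds to the tensor<...xi1> types exercised in the legalization test below.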

View File

@@ -1068,24 +1068,33 @@ func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
-func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
+func @concat_v2_with_3_tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
-// CHECK-LABEL: concatv2With3Tensors
+// CHECK-LABEL: concat_v2_with_3_tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
-func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
+func @concat_v2_i64_axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
-// CHECK-LABEL: concatv2I64Axis
+// CHECK-LABEL: concat_v2_i64_axis
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
+func @concat_v2_with_bool_type(%arg0: tensor<?x1xi1>, %arg1: tensor<?x1xi1>) -> tensor<?x2xi1> {
+%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
+%1 = "tf.ConcatV2"(%arg0, %arg1, %0) : (tensor<?x1xi1>, tensor<?x1xi1>, tensor<i32>) -> tensor<?x2xi1>
+return %1 : tensor<?x2xi1>
+// CHECK-LABEL: concat_v2_with_bool_type
+// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<?x1xi1>, tensor<?x1xi1>) -> tensor<?x2xi1>
+}
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>

View File

@@ -58,7 +58,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context,
input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
-input_type == kTfLiteInt32 || input_type == kTfLiteInt64);
+input_type == kTfLiteInt32 || input_type == kTfLiteInt64 ||
+input_type == kTfLiteBool);
// Output dimensions will match input dimensions, except 'axis', which
// will be the sum of inputs
@@ -172,6 +173,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16:
TF_LITE_CONCATENATION(int16_t);
break;
+case kTfLiteBool:
+TF_LITE_CONCATENATION(bool);
+break;
default:
context->ReportError(context, "Type '%s' is not supported currently.",
TfLiteTypeGetName(output->type));
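For orientation, here is a minimal standalone C++ sketch (not TFLite code, and not the commit's implementation) of what row-major concatenation along an axis means for bool data; the kernel itself reaches the shared concatenation implementation through the TF_LITE_CONCATENATION macro shown above. The ConcatBool helper name and its parameter layout are illustrative assumptions.

// Illustrative sketch only: concatenate row-major bool tensors along `axis`.
// All inputs must share every dimension except `axis`, and `axis` is assumed
// to be non-negative (i.e. already resolved from any negative value).
#include <cstddef>
#include <vector>

std::vector<bool> ConcatBool(const std::vector<std::vector<bool>>& inputs,
                             const std::vector<std::vector<int>>& shapes,
                             int axis) {
  // Number of "outer" slices: product of the dimensions before `axis`.
  std::size_t outer_size = 1;
  for (int d = 0; d < axis; ++d) outer_size *= shapes[0][d];

  // Elements each input contributes per outer slice: its extent along `axis`
  // times the product of the dimensions after `axis`.
  std::vector<std::size_t> copy_size(inputs.size(), 1);
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    for (std::size_t d = static_cast<std::size_t>(axis); d < shapes[i].size(); ++d) {
      copy_size[i] *= shapes[i][d];
    }
  }

  // For every outer slice, append each input's contiguous block in turn.
  std::vector<bool> output;
  for (std::size_t o = 0; o < outer_size; ++o) {
    for (std::size_t i = 0; i < inputs.size(); ++i) {
      output.insert(output.end(), inputs[i].begin() + o * copy_size[i],
                    inputs[i].begin() + (o + 1) * copy_size[i]);
    }
  }
  return output;
}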

View File

@@ -90,6 +90,15 @@ class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
}
};
+class BoolConcatenationOpModel : public BaseConcatenationOpModel {
+public:
+using BaseConcatenationOpModel::BaseConcatenationOpModel;
+void SetInput(int index, std::initializer_list<bool> data) {
+PopulateTensor(index, data);
+}
+std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+};
TEST(ConcatenationOpTest, ThreeDimensionalOneInput) {
ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
/*num_inputs=*/1);
@@ -447,5 +456,27 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) {
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
}
+TEST(ConcatenationOpTest, BoolTypeOneInput) {
+BoolConcatenationOpModel m0({TensorType_BOOL, {2, 1, 2}}, /*axis=*/1,
+/*num_inputs=*/1);
+m0.SetInput(0, {true, false, false, true});
+m0.Invoke();
+EXPECT_THAT(m0.GetOutput(), ElementsAreArray({true, false, false, true}));
+}
+TEST(ConcatenationOpTest, BoolTypeTwoInputs) {
+BoolConcatenationOpModel m0(
+{{TensorType_BOOL, {2, 1, 2}}, {TensorType_BOOL, {2, 3, 2}}},
+/*axis=*/1, /*num_inputs=*/2, TensorType_BOOL);
+m0.SetInput(0, {false, false, false, false});
+m0.SetInput(1, {true, true, true, true, true, true, true, true, true, true,
+true, true});
+m0.Invoke();
+EXPECT_THAT(
+m0.GetOutput(),
+ElementsAreArray({false, false, true, true, true, true, true, true, false,
+false, true, true, true, true, true, true}));
+}
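A brief aside on the expected values, using only the shapes and inputs in this test: concatenating the {2, 1, 2} all-false tensor and the {2, 3, 2} all-true tensor along axis 1 yields a {2, 4, 2} result, so each of the two outer batches contributes the first input's 1x2 slice (two false values) followed by the second input's 3x2 slice (six true values) in row-major order, giving the alternating two-false/six-true pattern asserted above. The standalone ConcatBool sketch added after the kernel hunk reproduces the same layout with outer_size = 2 and per-input copy sizes of 2 and 6.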
} // namespace
} // namespace tflite

View File

@@ -59,6 +59,14 @@ def make_concat_tests(options):
"fully_quantize": [False],
"quant_16x8": [False],
"dynamic_range_quantize": [True],
+}, {
+"base_shape": [[1, 3, 4, 3]],
+"num_tensors": [6],
+"axis": [1],
+"type": [tf.bool],
+"fully_quantize": [False],
+"quant_16x8": [False],
+"dynamic_range_quantize": [True],
}]
def get_shape(parameters, delta):