Add bool type support in TFLite concatenation op kernel
PiperOrigin-RevId: 332814276 Change-Id: I70f1c2f57abc000f08e290a5106688bd419bc753
This commit is contained in:
parent
830b72bc81
commit
d146fd6acc
@ -720,14 +720,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
|
||||
|
||||
let arguments = (
|
||||
ins TFL_VariadicTensorOf<
|
||||
[F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$values,
|
||||
[F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$values,
|
||||
I32Attr:$axis,
|
||||
TFL_AFAttr:$fused_activation_function
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<
|
||||
[F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$output
|
||||
[F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
@ -1068,24 +1068,33 @@ func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -
|
||||
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||
}
|
||||
|
||||
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
func @concat_v2_with_3_tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: concatv2With3Tensors
|
||||
// CHECK-LABEL: concat_v2_with_3_tensors
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
func @concat_v2_i64_axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: concatv2I64Axis
|
||||
// CHECK-LABEL: concat_v2_i64_axis
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @concat_v2_with_bool_type(%arg0: tensor<?x1xi1>, %arg1: tensor<?x1xi1>) -> tensor<?x2xi1> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %0) : (tensor<?x1xi1>, tensor<?x1xi1>, tensor<i32>) -> tensor<?x2xi1>
|
||||
return %1 : tensor<?x2xi1>
|
||||
|
||||
// CHECK-LABEL: concat_v2_with_bool_type
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<?x1xi1>, tensor<?x1xi1>) -> tensor<?x2xi1>
|
||||
}
|
||||
|
||||
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
@ -58,7 +58,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context,
|
||||
input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
|
||||
input_type == kTfLiteInt32 || input_type == kTfLiteInt64);
|
||||
input_type == kTfLiteInt32 || input_type == kTfLiteInt64 ||
|
||||
input_type == kTfLiteBool);
|
||||
|
||||
// Output dimensions will match input dimensions, except 'axis', which
|
||||
// will be the sum of inputs
|
||||
@ -172,6 +173,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt16:
|
||||
TF_LITE_CONCATENATION(int16_t);
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
TF_LITE_CONCATENATION(bool);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context, "Type '%s' is not supported currently.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
|
@ -90,6 +90,15 @@ class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
|
||||
}
|
||||
};
|
||||
|
||||
class BoolConcatenationOpModel : public BaseConcatenationOpModel {
|
||||
public:
|
||||
using BaseConcatenationOpModel::BaseConcatenationOpModel;
|
||||
void SetInput(int index, std::initializer_list<bool> data) {
|
||||
PopulateTensor(index, data);
|
||||
}
|
||||
std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
|
||||
};
|
||||
|
||||
TEST(ConcatenationOpTest, ThreeDimensionalOneInput) {
|
||||
ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
|
||||
/*num_inputs=*/1);
|
||||
@ -447,5 +456,27 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) {
|
||||
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
|
||||
}
|
||||
|
||||
TEST(ConcatenationOpTest, BoolTypeOneInput) {
|
||||
BoolConcatenationOpModel m0({TensorType_BOOL, {2, 1, 2}}, /*axis=*/1,
|
||||
/*num_inputs=*/1);
|
||||
m0.SetInput(0, {true, false, false, true});
|
||||
m0.Invoke();
|
||||
EXPECT_THAT(m0.GetOutput(), ElementsAreArray({true, false, false, true}));
|
||||
}
|
||||
|
||||
TEST(ConcatenationOpTest, BoolTypeTwoInputs) {
|
||||
BoolConcatenationOpModel m0(
|
||||
{{TensorType_BOOL, {2, 1, 2}}, {TensorType_BOOL, {2, 3, 2}}},
|
||||
/*axis=*/1, /*num_inputs=*/2, TensorType_BOOL);
|
||||
m0.SetInput(0, {false, false, false, false});
|
||||
m0.SetInput(1, {true, true, true, true, true, true, true, true, true, true,
|
||||
true, true});
|
||||
m0.Invoke();
|
||||
EXPECT_THAT(
|
||||
m0.GetOutput(),
|
||||
ElementsAreArray({false, false, true, true, true, true, true, true, false,
|
||||
false, true, true, true, true, true, true}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -59,6 +59,14 @@ def make_concat_tests(options):
|
||||
"fully_quantize": [False],
|
||||
"quant_16x8": [False],
|
||||
"dynamic_range_quantize": [True],
|
||||
}, {
|
||||
"base_shape": [[1, 3, 4, 3]],
|
||||
"num_tensors": [6],
|
||||
"axis": [1],
|
||||
"type": [tf.bool],
|
||||
"fully_quantize": [False],
|
||||
"quant_16x8": [False],
|
||||
"dynamic_range_quantize": [True],
|
||||
}]
|
||||
|
||||
def get_shape(parameters, delta):
|
||||
|
Loading…
Reference in New Issue
Block a user