Add int8 support for topk_v2.
PiperOrigin-RevId: 232784488
This commit is contained in:
parent
1d179685a7
commit
9e9f25fbe8
@ -242,7 +242,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
|
||||
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
|
||||
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
|
||||
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
|
||||
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_LOG, Register_LOG());
|
||||
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
|
||||
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
|
||||
|
@ -207,6 +207,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
|
||||
output_values->data.uint8);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32,
|
||||
output_values->data.int8);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
|
||||
output_values->data.i32);
|
||||
|
@ -46,6 +46,10 @@ class TopKV2OpModel : public SingleOpModel {
|
||||
PopulateTensor<uint8_t>(input_, data);
|
||||
}
|
||||
|
||||
void SetInputInt8(std::initializer_list<int8_t> data) {
|
||||
PopulateTensor<int8_t>(input_, data);
|
||||
}
|
||||
|
||||
void SetInputInt32(std::initializer_list<int32_t> data) {
|
||||
PopulateTensor<int32_t>(input_, data);
|
||||
}
|
||||
@ -66,6 +70,10 @@ class TopKV2OpModel : public SingleOpModel {
|
||||
return ExtractVector<uint8_t>(output_values_);
|
||||
}
|
||||
|
||||
std::vector<int8_t> GetValuesInt8() {
|
||||
return ExtractVector<int8_t>(output_values_);
|
||||
}
|
||||
|
||||
std::vector<int32_t> GetValuesInt32() {
|
||||
return ExtractVector<int32_t>(output_values_);
|
||||
}
|
||||
@ -128,6 +136,14 @@ TEST(TopKV2OpTest, TypeUint8) {
|
||||
EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250}));
|
||||
}
|
||||
|
||||
TEST(TopKV2OpTest, TypeInt8) {
|
||||
TopKV2OpModel m({2, 3}, TensorType_INT8, 2);
|
||||
m.SetInputInt8({1, 2, 3, -126, 125, -24});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 1, 2}));
|
||||
EXPECT_THAT(m.GetValuesInt8(), ElementsAreArray({3, 2, 125, -24}));
|
||||
}
|
||||
|
||||
// Check that int32_t works.
|
||||
TEST(TopKV2OpTest, TypeInt32) {
|
||||
TopKV2OpModel m({2, 3}, TensorType_INT32, 2);
|
||||
|
@ -1333,6 +1333,11 @@ class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user