diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 948bf779d6b..27b34c8f31f 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -242,7 +242,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); - AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_LOG, Register_LOG()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); diff --git a/tensorflow/lite/kernels/topk_v2.cc b/tensorflow/lite/kernels/topk_v2.cc index 444b01e7b2e..64973d7b860 100644 --- a/tensorflow/lite/kernels/topk_v2.cc +++ b/tensorflow/lite/kernels/topk_v2.cc @@ -207,6 +207,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32, output_values->data.uint8); break; + case kTfLiteInt8: + TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32, + output_values->data.int8); + break; case kTfLiteInt32: TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32, output_values->data.i32); diff --git a/tensorflow/lite/kernels/topk_v2_test.cc b/tensorflow/lite/kernels/topk_v2_test.cc index 108b8123666..0097ae2f9ae 100644 --- a/tensorflow/lite/kernels/topk_v2_test.cc +++ b/tensorflow/lite/kernels/topk_v2_test.cc @@ -46,6 +46,10 @@ class TopKV2OpModel : public SingleOpModel { PopulateTensor(input_, data); } + void SetInputInt8(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInputInt32(std::initializer_list data) { PopulateTensor(input_, data); } @@ -66,6 +70,10 @@ class TopKV2OpModel : public SingleOpModel { return ExtractVector(output_values_); } + std::vector GetValuesInt8() { + return ExtractVector(output_values_); + } + std::vector GetValuesInt32() { return ExtractVector(output_values_); } @@ -128,6 +136,14 @@ TEST(TopKV2OpTest, TypeUint8) { EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250})); } +TEST(TopKV2OpTest, TypeInt8) { + TopKV2OpModel m({2, 3}, TensorType_INT8, 2); + m.SetInputInt8({1, 2, 3, -126, 125, -24}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 1, 2})); + EXPECT_THAT(m.GetValuesInt8(), ElementsAreArray({3, 2, 125, -24})); +} + // Check that int32_t works. TEST(TopKV2OpTest, TypeInt32) { TopKV2OpModel m({2, 3}, TensorType_INT32, 2); diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 58c4a8589b0..50242531e78 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -1333,6 +1333,11 @@ class TopK_V2 : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } };