Add int8 support for topk_v2.

PiperOrigin-RevId: 232784488
This commit is contained in:
Shashi Shekhar 2019-02-06 17:54:58 -08:00 committed by TensorFlower Gardener
parent 1d179685a7
commit 9e9f25fbe8
4 changed files with 28 additions and 1 deletions

View File

@ -242,7 +242,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_LOG, Register_LOG());
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
AddBuiltin(BuiltinOperator_CAST, Register_CAST());

View File

@ -207,6 +207,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
output_values->data.uint8);
break;
case kTfLiteInt8:
TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32,
output_values->data.int8);
break;
case kTfLiteInt32:
TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
output_values->data.i32);

View File

@ -46,6 +46,10 @@ class TopKV2OpModel : public SingleOpModel {
PopulateTensor<uint8_t>(input_, data);
}
void SetInputInt8(std::initializer_list<int8_t> data) {
PopulateTensor<int8_t>(input_, data);
}
void SetInputInt32(std::initializer_list<int32_t> data) {
PopulateTensor<int32_t>(input_, data);
}
@ -66,6 +70,10 @@ class TopKV2OpModel : public SingleOpModel {
return ExtractVector<uint8_t>(output_values_);
}
std::vector<int8_t> GetValuesInt8() {
return ExtractVector<int8_t>(output_values_);
}
std::vector<int32_t> GetValuesInt32() {
return ExtractVector<int32_t>(output_values_);
}
@ -128,6 +136,14 @@ TEST(TopKV2OpTest, TypeUint8) {
EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250}));
}
TEST(TopKV2OpTest, TypeInt8) {
TopKV2OpModel m({2, 3}, TensorType_INT8, 2);
m.SetInputInt8({1, 2, 3, -126, 125, -24});
m.Invoke();
EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 1, 2}));
EXPECT_THAT(m.GetValuesInt8(), ElementsAreArray({3, 2, 125, -24}));
}
// Check that int32_t works.
TEST(TopKV2OpTest, TypeInt32) {
TopKV2OpModel m({2, 3}, TensorType_INT32, 2);

View File

@ -1333,6 +1333,11 @@ class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
TocoOperator* op) const override {}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};