Add Int8 support for arg_min_max.

PiperOrigin-RevId: 232728637
This commit is contained in:
Shashi Shekhar 2019-02-06 12:57:53 -08:00 committed by TensorFlower Gardener
parent 0c4f5dfea4
commit f11085ebf2
4 changed files with 48 additions and 3 deletions

View File

@ -80,13 +80,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
switch (input->type) {
case kTfLiteFloat32:
case kTfLiteUInt8:
case kTfLiteInt8:
case kTfLiteInt32:
break;
default:
context->ReportError(
context,
"Unkonwn input type: %d, only float32 and int types are supported",
"Unknown input type: %d, only float32 and int types are supported",
input->type);
return kTfLiteError;
}
@ -135,6 +136,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
case kTfLiteUInt8:
TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
break;
case kTfLiteInt8:
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
break;
case kTfLiteInt32:
TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
break;
@ -150,6 +154,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
case kTfLiteUInt8:
TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
break;
case kTfLiteInt8:
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t);
break;
case kTfLiteInt32:
TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
break;

View File

@ -86,6 +86,28 @@ TEST(ArgMaxOpTest, GetMaxArgFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
// Verifies ArgMax over a uint8 input: builds a {1,1,1,4} model with int32
// axis and int32 output, feeds {1, 9, 7, 3}, and expects index 1 (the
// position of the maximum value 9) when reducing along axis 3.
TEST(ArgMaxOpTest, GetMaxArgUInt8) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_UINT8, TensorType_INT32,
TensorType_INT32);
model.PopulateTensor<uint8_t>(model.input(), {1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
// The reduced axis is removed from the output shape: {1,1,1,4} -> {1,1,1}.
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
// Verifies ArgMax over an int8 input (the type added by this change):
// feeds {-1, -9, 7, 3} and expects index 2 (the position of the maximum
// value 7). Negative values exercise the signed range that uint8 cannot.
TEST(ArgMaxOpTest, GetMaxArgInt8) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT8, TensorType_INT32,
TensorType_INT32);
model.PopulateTensor<int8_t>(model.input(), {-1, -9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
// The reduced axis is removed from the output shape: {1,1,1,4} -> {1,1,1}.
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
TEST(ArgMaxOpTest, GetMaxArgInt) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT32);

View File

@ -252,8 +252,12 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());

View File

@ -1354,6 +1354,12 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
}
// Selects the serialized op version for ARG_MAX: int8 inputs require the
// newly introduced version 2 kernel; all other input types keep version 1.
int GetVersion(const OperatorSignature& op_signature) const override {
// Version is decided solely by the dtype of the first (data) input.
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};
@ -1375,6 +1381,12 @@ class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
}
// Selects the serialized op version for ARG_MIN: int8 inputs require the
// newly introduced version 2 kernel; all other input types keep version 1.
// Mirrors ArgMax::GetVersion so the two ops stay in lockstep.
int GetVersion(const OperatorSignature& op_signature) const override {
// Version is decided solely by the dtype of the first (data) input.
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};