Add Int8 support for arg_min_max.
PiperOrigin-RevId: 232728637
This commit is contained in:
parent
0c4f5dfea4
commit
f11085ebf2
@ -80,13 +80,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteInt32:
|
||||
break;
|
||||
|
||||
default:
|
||||
context->ReportError(
|
||||
context,
|
||||
"Unkonwn input type: %d, only float32 and int types are supported",
|
||||
"Unknown input type: %d, only float32 and int types are supported",
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
@ -135,6 +136,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
|
||||
break;
|
||||
@ -150,6 +154,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
|
||||
break;
|
||||
|
@ -86,6 +86,28 @@ TEST(ArgMaxOpTest, GetMaxArgFloat) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgUInt8) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_UINT8, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
model.PopulateTensor<uint8_t>(model.input(), {1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgInt8) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT8, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
model.PopulateTensor<int8_t>(model.input(), {-1, -9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgInt) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
|
@ -252,8 +252,12 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
|
||||
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
|
||||
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
|
||||
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
|
||||
AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
|
||||
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
|
||||
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
|
||||
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
|
||||
|
@ -1354,6 +1354,12 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
@ -1375,6 +1381,12 @@ class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user