axis test cases for argminmax
This commit is contained in:
parent
f09580c836
commit
c3606c810a
@ -143,6 +143,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
|
||||
TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Only float32, uint8, int8 and int32 are "
|
||||
"supported currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} break;
|
||||
@ -161,10 +165,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
|
||||
TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Only float32, uint8, int8 and int32 are "
|
||||
"supported currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Only int32 and int64 are supported currently, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
@ -177,10 +188,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int32_t);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Only float32, uint8, int8 and int32 are "
|
||||
"supported currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} break;
|
||||
@ -192,14 +210,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Only float32, uint8, int8 and int32 are "
|
||||
"supported currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Only int32 and int64 are supported currently, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
@ -27,9 +27,9 @@ template <typename T>
|
||||
class ArgBaseOpModel : public SingleOpModel {
|
||||
public:
|
||||
ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
|
||||
TensorType output_type, TensorType index_output_type) {
|
||||
TensorType axis_type, TensorType output_type) {
|
||||
input_ = AddInput(input_type);
|
||||
axis_ = AddInput(TensorType_INT32);
|
||||
axis_ = AddInput(axis_type);
|
||||
output_ = AddOutput(output_type);
|
||||
}
|
||||
|
||||
@ -49,13 +49,11 @@ template <typename T>
|
||||
class ArgMaxOpModel : public ArgBaseOpModel<T> {
|
||||
public:
|
||||
ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
|
||||
TensorType output_type, TensorType index_output_type)
|
||||
: ArgBaseOpModel<T>(input_shape, input_type, output_type,
|
||||
index_output_type) {
|
||||
TensorType axis_type, TensorType output_type)
|
||||
: ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
|
||||
ArgBaseOpModel<T>::SetBuiltinOp(
|
||||
BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
|
||||
CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type)
|
||||
.Union());
|
||||
CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
|
||||
ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
|
||||
}
|
||||
};
|
||||
@ -64,13 +62,11 @@ template <typename T>
|
||||
class ArgMinOpModel : public ArgBaseOpModel<T> {
|
||||
public:
|
||||
ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
|
||||
TensorType output_type, TensorType index_output_type)
|
||||
: ArgBaseOpModel<T>(input_shape, input_type, output_type,
|
||||
index_output_type) {
|
||||
TensorType axis_type, TensorType output_type)
|
||||
: ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
|
||||
ArgBaseOpModel<T>::SetBuiltinOp(
|
||||
BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
|
||||
CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type)
|
||||
.Union());
|
||||
CreateArgMinOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
|
||||
ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
|
||||
}
|
||||
};
|
||||
@ -142,7 +138,7 @@ TEST(ArgMaxOpTest, GetMaxArgNegativeAxis) {
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgOutput64) {
|
||||
ArgMaxOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
|
||||
ArgMaxOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT64);
|
||||
model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
@ -152,6 +148,38 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgAxis64) {
|
||||
// Input Int32, Axis Int64, Output Int64
|
||||
ArgMaxOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
|
||||
TensorType_INT64, TensorType_INT64);
|
||||
model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model1.PopulateTensor<int64_t>(model1.axis(), {3});
|
||||
model1.Invoke();
|
||||
|
||||
EXPECT_THAT(model1.GetOutput(), ElementsAreArray({0, 1}));
|
||||
EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
|
||||
// Input Int8, Axis Int64, Output Int32
|
||||
ArgMaxOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
|
||||
TensorType_INT32);
|
||||
model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model2.PopulateTensor<int64_t>(model2.axis(), {3});
|
||||
model2.Invoke();
|
||||
|
||||
EXPECT_THAT(model2.GetOutput(), ElementsAreArray({0, 1}));
|
||||
EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
|
||||
// Input Int8, Axis Int64, Output Int64
|
||||
ArgMaxOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
|
||||
TensorType_INT64);
|
||||
model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model3.PopulateTensor<int64_t>(model3.axis(), {3});
|
||||
model3.Invoke();
|
||||
|
||||
EXPECT_THAT(model3.GetOutput(), ElementsAreArray({0, 1}));
|
||||
EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgFloat) {
|
||||
ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
|
||||
TensorType_INT32, TensorType_INT32);
|
||||
@ -197,7 +225,7 @@ TEST(ArgMinOpTest, GetMinArgNegativeAxis) {
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgOutput64) {
|
||||
ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
|
||||
ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT64);
|
||||
model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
@ -207,6 +235,38 @@ TEST(ArgMinOpTest, GetMinArgOutput64) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgAxis64) {
|
||||
// Input Int32, Axis Int64, Output Int64
|
||||
ArgMinOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
|
||||
TensorType_INT64, TensorType_INT64);
|
||||
model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model1.PopulateTensor<int64_t>(model1.axis(), {3});
|
||||
model1.Invoke();
|
||||
|
||||
EXPECT_THAT(model1.GetOutput(), ElementsAreArray({1, 0}));
|
||||
EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
|
||||
// Input Int8, Axis Int64, Output Int32
|
||||
ArgMinOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
|
||||
TensorType_INT32);
|
||||
model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model2.PopulateTensor<int64_t>(model2.axis(), {3});
|
||||
model2.Invoke();
|
||||
|
||||
EXPECT_THAT(model2.GetOutput(), ElementsAreArray({1, 0}));
|
||||
EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
|
||||
// Input Int8, Axis Int64, Output Int64
|
||||
ArgMinOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
|
||||
TensorType_INT64);
|
||||
model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model3.PopulateTensor<int64_t>(model3.axis(), {3});
|
||||
model3.Invoke();
|
||||
|
||||
EXPECT_THAT(model3.GetOutput(), ElementsAreArray({1, 0}));
|
||||
EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user