axis test cases for argminmax

Pariksheet 2019-03-28 18:12:47 +05:30
parent f09580c836
commit c3606c810a
2 changed files with 102 additions and 14 deletions

@@ -143,6 +143,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
             TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
             break;
           default:
+            context->ReportError(context,
+                                 "Only float32, uint8, int8 and int32 are "
+                                 "supported currently, got %s.",
+                                 TfLiteTypeGetName(input->type));
             return kTfLiteError;
         }
       } break;
@@ -161,10 +165,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
             TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
             break;
           default:
+            context->ReportError(context,
+                                 "Only float32, uint8, int8 and int32 are "
+                                 "supported currently, got %s.",
+                                 TfLiteTypeGetName(input->type));
             return kTfLiteError;
         }
       } break;
       default:
+        context->ReportError(
+            context, "Only int32 and int64 are supported currently, got %s.",
+            TfLiteTypeGetName(output->type));
         return kTfLiteError;
     }
   } else {
@@ -177,10 +188,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
           case kTfLiteUInt8:
             TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
             break;
+          case kTfLiteInt8:
+            TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int32_t);
+            break;
           case kTfLiteInt32:
             TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
             break;
           default:
+            context->ReportError(context,
+                                 "Only float32, uint8, int8 and int32 are "
+                                 "supported currently, got %s.",
+                                 TfLiteTypeGetName(input->type));
             return kTfLiteError;
         }
       } break;
@@ -192,14 +210,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
           case kTfLiteUInt8:
             TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
             break;
+          case kTfLiteInt8:
+            TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t);
+            break;
           case kTfLiteInt32:
             TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
             break;
           default:
+            context->ReportError(context,
+                                 "Only float32, uint8, int8 and int32 are "
+                                 "supported currently, got %s.",
+                                 TfLiteTypeGetName(input->type));
             return kTfLiteError;
         }
       } break;
       default:
+        context->ReportError(
+            context, "Only int32 and int64 are supported currently, got %s.",
+            TfLiteTypeGetName(output->type));
         return kTfLiteError;
     }
   }
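
Each TF_LITE_ARG_MIN_MAX(input_type, axis_type, output_type) line in the hunks above binds one concrete (input, axis, output) type combination to the shared arg-min/arg-max reduction, which is why the kernel has to enumerate every supported combination explicitly and keep the messages in the default: branches in step. As a standalone illustration of the reduction each instantiation performs (a sketch only, not the TFLite kernel; ArgMinMaxInner and its signature are invented here), the computation over the innermost axis looks roughly like this:

// Illustration only (not the TFLite kernel): a minimal arg-min/arg-max over
// the innermost axis, generic in the element type T, the axis tensor's
// integer type AxisT and the emitted index type OutputT.
#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T, typename AxisT, typename OutputT>
std::vector<OutputT> ArgMinMaxInner(const std::vector<T>& data,
                                    std::size_t inner_size, AxisT /*axis*/,
                                    bool is_arg_max) {
  std::vector<OutputT> result;
  // Walk the flattened data one innermost row at a time and record the
  // position of the extremum in each row.
  for (std::size_t start = 0; start + inner_size <= data.size();
       start += inner_size) {
    std::size_t best = 0;
    for (std::size_t i = 1; i < inner_size; ++i) {
      const bool better = is_arg_max ? data[start + i] > data[start + best]
                                     : data[start + i] < data[start + best];
      if (better) best = i;
    }
    result.push_back(static_cast<OutputT>(best));
  }
  return result;
}

// ArgMinMaxInner<int8_t, int64_t, int64_t>({10, 2, 7, 8, 1, 9, 7, 3}, 4,
//                                          int64_t{3}, /*is_arg_max=*/true)
// returns {0, 1}, the combination exercised by the
// TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t) case above and by the
// GetMaxArgAxis64 test further down.

The remaining hunks come from the ArgMinMax unit tests, which thread the new axis type through the test models and check these combinations end to end.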

@@ -27,9 +27,9 @@ template <typename T>
 class ArgBaseOpModel : public SingleOpModel {
  public:
   ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
-                 TensorType output_type, TensorType index_output_type) {
+                 TensorType axis_type, TensorType output_type) {
     input_ = AddInput(input_type);
-    axis_ = AddInput(TensorType_INT32);
+    axis_ = AddInput(axis_type);
     output_ = AddOutput(output_type);
   }
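
Only the ArgBaseOpModel constructor appears in this hunk; the accessors the tests rely on (input(), axis(), GetOutput(), GetOutputShape()) sit outside the diff. For orientation, a sketch of how the full helper presumably reads after this change, assuming SingleOpModel's usual ExtractVector<T>() and GetTensorShape() utilities (the accessor bodies are assumptions, not part of this commit):

// Sketch only: the constructor matches the hunk above; the accessors are
// assumed and are what the tests below call on the model objects.
template <typename T>
class ArgBaseOpModel : public SingleOpModel {
 public:
  ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
                 TensorType axis_type, TensorType output_type) {
    input_ = AddInput(input_type);
    axis_ = AddInput(axis_type);
    output_ = AddOutput(output_type);
  }

  int input() { return input_; }
  int axis() { return axis_; }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 protected:
  int input_;
  int axis_;
  int output_;
};
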
@@ -49,13 +49,11 @@ template <typename T>
 class ArgMaxOpModel : public ArgBaseOpModel<T> {
  public:
   ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
-                TensorType output_type, TensorType index_output_type)
-      : ArgBaseOpModel<T>(input_shape, input_type, output_type,
-                          index_output_type) {
+                TensorType axis_type, TensorType output_type)
+      : ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
     ArgBaseOpModel<T>::SetBuiltinOp(
         BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
-        CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type)
-            .Union());
+        CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
     ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
   }
 };
@@ -64,13 +62,11 @@ template <typename T>
 class ArgMinOpModel : public ArgBaseOpModel<T> {
  public:
   ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
-                TensorType output_type, TensorType index_output_type)
-      : ArgBaseOpModel<T>(input_shape, input_type, output_type,
-                          index_output_type) {
+                TensorType axis_type, TensorType output_type)
+      : ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
     ArgBaseOpModel<T>::SetBuiltinOp(
         BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
-        CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type)
-            .Union());
+        CreateArgMinOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
     ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
   }
 };
@@ -142,7 +138,7 @@ TEST(ArgMaxOpTest, GetMaxArgNegativeAxis) {
 }
 
 TEST(ArgMaxOpTest, GetMaxArgOutput64) {
-  ArgMaxOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+  ArgMaxOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
                                TensorType_INT64);
   model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
   model.PopulateTensor<int>(model.axis(), {3});
@@ -152,6 +148,38 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
 }
 
+TEST(ArgMaxOpTest, GetMaxArgAxis64) {
+  // Input Int32, Axis Int64, Output Int64
+  ArgMaxOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
+                                TensorType_INT64, TensorType_INT64);
+  model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model1.PopulateTensor<int64_t>(model1.axis(), {3});
+  model1.Invoke();
+  EXPECT_THAT(model1.GetOutput(), ElementsAreArray({0, 1}));
+  EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+
+  // Input Int8, Axis Int64, Output Int32
+  ArgMaxOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
+                                TensorType_INT32);
+  model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model2.PopulateTensor<int64_t>(model2.axis(), {3});
+  model2.Invoke();
+  EXPECT_THAT(model2.GetOutput(), ElementsAreArray({0, 1}));
+  EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+
+  // Input Int8, Axis Int64, Output Int64
+  ArgMaxOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
+                                TensorType_INT64);
+  model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model3.PopulateTensor<int64_t>(model3.axis(), {3});
+  model3.Invoke();
+  EXPECT_THAT(model3.GetOutput(), ElementsAreArray({0, 1}));
+  EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+}
+
 TEST(ArgMinOpTest, GetMinArgFloat) {
   ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
                                TensorType_INT32, TensorType_INT32);
@@ -197,7 +225,7 @@ TEST(ArgMinOpTest, GetMinArgNegativeAxis) {
 }
 
 TEST(ArgMinOpTest, GetMinArgOutput64) {
-  ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+  ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
                                TensorType_INT64);
   model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
   model.PopulateTensor<int>(model.axis(), {3});
@@ -207,6 +235,38 @@ TEST(ArgMinOpTest, GetMinArgOutput64) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
 }
 
+TEST(ArgMinOpTest, GetMinArgAxis64) {
+  // Input Int32, Axis Int64, Output Int64
+  ArgMinOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
+                                TensorType_INT64, TensorType_INT64);
+  model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model1.PopulateTensor<int64_t>(model1.axis(), {3});
+  model1.Invoke();
+  EXPECT_THAT(model1.GetOutput(), ElementsAreArray({1, 0}));
+  EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+
+  // Input Int8, Axis Int64, Output Int32
+  ArgMinOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
+                                TensorType_INT32);
+  model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model2.PopulateTensor<int64_t>(model2.axis(), {3});
+  model2.Invoke();
+  EXPECT_THAT(model2.GetOutput(), ElementsAreArray({1, 0}));
+  EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+
+  // Input Int8, Axis Int64, Output Int64
+  ArgMinOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
+                                TensorType_INT64);
+  model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model3.PopulateTensor<int64_t>(model3.axis(), {3});
+  model3.Invoke();
+  EXPECT_THAT(model3.GetOutput(), ElementsAreArray({1, 0}));
+  EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+}
+
 } // namespace
 } // namespace tflite