axis test cases for argminmax
parent f09580c836
commit c3606c810a
@@ -143,6 +143,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
             TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
             break;
           default:
+            context->ReportError(context,
+                                 "Only float32, uint8, int8 and int32 are "
+                                 "supported currently, got %s.",
+                                 TfLiteTypeGetName(input->type));
             return kTfLiteError;
         }
       } break;
@@ -161,10 +165,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
             TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
             break;
           default:
+            context->ReportError(context,
+                                 "Only float32, uint8, int8 and int32 are "
+                                 "supported currently, got %s.",
+                                 TfLiteTypeGetName(input->type));
             return kTfLiteError;
         }
       } break;
       default:
+        context->ReportError(
+            context, "Only int32 and int64 are supported currently, got %s.",
+            TfLiteTypeGetName(output->type));
         return kTfLiteError;
     }
   } else {
@@ -177,10 +188,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
           case kTfLiteUInt8:
             TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
             break;
+          case kTfLiteInt8:
+            TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int32_t);
+            break;
           case kTfLiteInt32:
             TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
             break;
           default:
+            context->ReportError(context,
+                                 "Only float32, uint8, int8 and int32 are "
+                                 "supported currently, got %s.",
+                                 TfLiteTypeGetName(input->type));
             return kTfLiteError;
         }
       } break;
@@ -192,14 +210,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
           case kTfLiteUInt8:
             TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
             break;
+          case kTfLiteInt8:
+            TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t);
+            break;
           case kTfLiteInt32:
             TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
             break;
           default:
+            context->ReportError(context,
+                                 "Only float32, uint8, int8 and int32 are "
+                                 "supported currently, got %s.",
+                                 TfLiteTypeGetName(input->type));
             return kTfLiteError;
         }
       } break;
       default:
+        context->ReportError(
+            context, "Only int32 and int64 are supported currently, got %s.",
+            TfLiteTypeGetName(output->type));
         return kTfLiteError;
     }
   }
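Taken together, the four hunks above tighten the kernel's Eval dispatch: for each axis type (int32 vs. int64) it switches on the output index type and then on the input type, and the change adds the missing kTfLiteInt8 cases plus an explicit error message in each default branch instead of a bare return kTfLiteError. The remaining hunks below update the corresponding unit tests. As a minimal standalone sketch of the same idea (a hypothetical ArgMax helper, not TFLite code), the output index type is a template parameter, just as TF_LITE_ARG_MIN_MAX instantiates the kernel per (input type, axis type, output type) combination:

    // Illustrative sketch only; ArgMax and main() here are hypothetical and
    // not part of TFLite. The kernel above reaches the same effect by
    // expanding TF_LITE_ARG_MIN_MAX inside the nested switches.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Index of the maximum element; the caller picks the index type, which is
    // what the int32/int64 output cases in the kernel switch correspond to.
    template <typename T, typename IndexT>
    IndexT ArgMax(const std::vector<T>& values) {
      IndexT best = 0;
      for (IndexT i = 1; i < static_cast<IndexT>(values.size()); ++i) {
        if (values[i] > values[best]) best = i;
      }
      return best;
    }

    int main() {
      std::vector<int8_t> row = {10, 2, 7, 8};  // one row of the test data
      std::cout << ArgMax<int8_t, int32_t>(row) << "\n";  // int32 index: 0
      std::cout << ArgMax<int8_t, int64_t>(row) << "\n";  // int64 index: 0
      return 0;
    }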
@@ -27,9 +27,9 @@ template <typename T>
 class ArgBaseOpModel : public SingleOpModel {
  public:
   ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
-                 TensorType output_type, TensorType index_output_type) {
+                 TensorType axis_type, TensorType output_type) {
     input_ = AddInput(input_type);
-    axis_ = AddInput(TensorType_INT32);
+    axis_ = AddInput(axis_type);
     output_ = AddOutput(output_type);
   }
 
@@ -49,13 +49,11 @@ template <typename T>
 class ArgMaxOpModel : public ArgBaseOpModel<T> {
  public:
   ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
-                TensorType output_type, TensorType index_output_type)
-      : ArgBaseOpModel<T>(input_shape, input_type, output_type,
-                          index_output_type) {
+                TensorType axis_type, TensorType output_type)
+      : ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
     ArgBaseOpModel<T>::SetBuiltinOp(
         BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
-        CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type)
-            .Union());
+        CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
     ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
   }
 };
@@ -64,13 +62,11 @@ template <typename T>
 class ArgMinOpModel : public ArgBaseOpModel<T> {
  public:
   ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
-                TensorType output_type, TensorType index_output_type)
-      : ArgBaseOpModel<T>(input_shape, input_type, output_type,
-                          index_output_type) {
+                TensorType axis_type, TensorType output_type)
+      : ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
     ArgBaseOpModel<T>::SetBuiltinOp(
         BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
-        CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type)
-            .Union());
+        CreateArgMinOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
     ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
   }
 };
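After this reordering, the third constructor argument selects the axis tensor's type and the fourth the output (index) type, which is also what is written into the ArgMax/ArgMin builtin options. A usage sketch mirroring the new tests further down (it only compiles inside the TFLite test harness, which provides ArgMaxOpModel, the TensorType_* constants, and the accessors used here):

    // int8 input, int64 axis tensor, int32 index output.
    ArgMaxOpModel<int32_t> m({1, 1, 2, 4}, /*input_type=*/TensorType_INT8,
                             /*axis_type=*/TensorType_INT64,
                             /*output_type=*/TensorType_INT32);
    m.PopulateTensor<int8_t>(m.input(), {10, 2, 7, 8, 1, 9, 7, 3});
    m.PopulateTensor<int64_t>(m.axis(), {3});
    m.Invoke();
    // Arg-max along the last axis: row {10, 2, 7, 8} -> 0, row {1, 9, 7, 3} -> 1.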
@@ -142,7 +138,7 @@ TEST(ArgMaxOpTest, GetMaxArgNegativeAxis) {
 }
 
 TEST(ArgMaxOpTest, GetMaxArgOutput64) {
-  ArgMaxOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+  ArgMaxOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
                                TensorType_INT64);
   model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
   model.PopulateTensor<int>(model.axis(), {3});
@@ -152,6 +148,38 @@
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
 }
 
+TEST(ArgMaxOpTest, GetMaxArgAxis64) {
+  // Input Int32, Axis Int64, Output Int64
+  ArgMaxOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
+                                TensorType_INT64, TensorType_INT64);
+  model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model1.PopulateTensor<int64_t>(model1.axis(), {3});
+  model1.Invoke();
+
+  EXPECT_THAT(model1.GetOutput(), ElementsAreArray({0, 1}));
+  EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+
+  // Input Int8, Axis Int64, Output Int32
+  ArgMaxOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
+                                TensorType_INT32);
+  model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model2.PopulateTensor<int64_t>(model2.axis(), {3});
+  model2.Invoke();
+
+  EXPECT_THAT(model2.GetOutput(), ElementsAreArray({0, 1}));
+  EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+
+  // Input Int8, Axis Int64, Output Int64
+  ArgMaxOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
+                                TensorType_INT64);
+  model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model3.PopulateTensor<int64_t>(model3.axis(), {3});
+  model3.Invoke();
+
+  EXPECT_THAT(model3.GetOutput(), ElementsAreArray({0, 1}));
+  EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+}
+
 TEST(ArgMinOpTest, GetMinArgFloat) {
   ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
                                TensorType_INT32, TensorType_INT32);
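For reference, the expected values in the new tests follow from the shared input {10, 2, 7, 8, 1, 9, 7, 3} reshaped to {1, 1, 2, 4}: along the last axis the maximum of {10, 2, 7, 8} is at index 0 and the maximum of {1, 9, 7, 3} at index 1, hence {0, 1}; the ArgMin tests below expect {1, 0} for the same reason. A tiny standalone check, independent of TFLite:

    // Verifies the expected arg-max/arg-min indices for the two test rows.
    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <vector>

    int main() {
      std::vector<std::vector<int>> rows = {{10, 2, 7, 8}, {1, 9, 7, 3}};
      for (const auto& r : rows) {
        std::cout << std::distance(r.begin(),
                                   std::max_element(r.begin(), r.end()))
                  << " ";
      }
      std::cout << "\n";  // prints "0 1" (arg-max per row)
      for (const auto& r : rows) {
        std::cout << std::distance(r.begin(),
                                   std::min_element(r.begin(), r.end()))
                  << " ";
      }
      std::cout << "\n";  // prints "1 0" (arg-min per row)
      return 0;
    }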
@@ -197,7 +225,7 @@ TEST(ArgMinOpTest, GetMinArgNegativeAxis) {
 }
 
 TEST(ArgMinOpTest, GetMinArgOutput64) {
-  ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+  ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
                                TensorType_INT64);
   model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
   model.PopulateTensor<int>(model.axis(), {3});
@@ -207,6 +235,38 @@
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
 }
 
+TEST(ArgMinOpTest, GetMinArgAxis64) {
+  // Input Int32, Axis Int64, Output Int64
+  ArgMinOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
+                                TensorType_INT64, TensorType_INT64);
+  model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model1.PopulateTensor<int64_t>(model1.axis(), {3});
+  model1.Invoke();
+
+  EXPECT_THAT(model1.GetOutput(), ElementsAreArray({1, 0}));
+  EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+
+  // Input Int8, Axis Int64, Output Int32
+  ArgMinOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
+                                TensorType_INT32);
+  model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model2.PopulateTensor<int64_t>(model2.axis(), {3});
+  model2.Invoke();
+
+  EXPECT_THAT(model2.GetOutput(), ElementsAreArray({1, 0}));
+  EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+
+  // Input Int8, Axis Int64, Output Int64
+  ArgMinOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
+                                TensorType_INT64);
+  model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+  model3.PopulateTensor<int64_t>(model3.axis(), {3});
+  model3.Invoke();
+
+  EXPECT_THAT(model3.GetOutput(), ElementsAreArray({1, 0}));
+  EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
+}
+
 } // namespace
 } // namespace tflite