Add int16x8 support for RESIZE_NEAREST_NEIGHBOR operator
This commit is contained in:
parent
00c6f88dd1
commit
860898de6f
@ -331,6 +331,8 @@ ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+AlignCorners.*/0,30
|
|||||||
ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+HalfPixelCenters.*/0,30
|
ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+HalfPixelCenters.*/0,30
|
||||||
// Only models with constant size tensor are accelerated
|
// Only models with constant size tensor are accelerated
|
||||||
ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest/.+/0,29
|
ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest/.+/0,29
|
||||||
|
// 16-bit tests are not supported
|
||||||
|
-ResizeNearestNeighborOpTest/.*Int16.*
|
||||||
|
|
||||||
# select_test
|
# select_test
|
||||||
-SelectOpTest/SelectBool
|
-SelectOpTest/SelectBool
|
||||||
|
@ -91,6 +91,17 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3) {
|
|||||||
output_shape, output_data);
|
output_shape, output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ResizeNearestNeighborReference, Test2x2To3x3Int16) {
|
||||||
|
RuntimeShape input_shape = {1, 2, 2, 1};
|
||||||
|
std::vector<int16_t> input_data = {1, 2, 3, 4};
|
||||||
|
std::vector<int32> output_size_data = {3, 3};
|
||||||
|
RuntimeShape output_shape = {1, 3, 3, 1};
|
||||||
|
std::vector<int16_t> output_data = {1, 1, 2, 1, 1, 2, 3, 3, 4};
|
||||||
|
|
||||||
|
TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data,
|
||||||
|
output_shape, output_data);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ResizeNearestNeighborReference, Test2x2To3x3_AlignCorners) {
|
TEST(ResizeNearestNeighborReference, Test2x2To3x3_AlignCorners) {
|
||||||
RuntimeShape input_shape = {1, 2, 2, 1};
|
RuntimeShape input_shape = {1, 2, 2, 1};
|
||||||
std::vector<uint8> input_data = {1, 2, 3, 4};
|
std::vector<uint8> input_data = {1, 2, 3, 4};
|
||||||
|
@ -123,7 +123,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
|
AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
|
||||||
Register_RESIZE_NEAREST_NEIGHBOR(),
|
Register_RESIZE_NEAREST_NEIGHBOR(),
|
||||||
/* min_version = */ 1,
|
/* min_version = */ 1,
|
||||||
/* max_version = */ 3);
|
/* max_version = */ 4);
|
||||||
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
|
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
|
||||||
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(),
|
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(),
|
||||||
/* min_version = */ 1,
|
/* min_version = */ 1,
|
||||||
|
@ -121,6 +121,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
||||||
GetTensorShape(size), GetTensorData<int32>(size),
|
GetTensorShape(size), GetTensorData<int32>(size),
|
||||||
GetTensorShape(output), GetTensorData<int8_t>(output));
|
GetTensorShape(output), GetTensorData<int8_t>(output));
|
||||||
|
} else if (output->type == kTfLiteInt16) {
|
||||||
|
reference_ops::ResizeNearestNeighbor(
|
||||||
|
op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
|
||||||
|
GetTensorShape(size), GetTensorData<int32>(size),
|
||||||
|
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||||
} else {
|
} else {
|
||||||
TF_LITE_KERNEL_LOG(context,
|
TF_LITE_KERNEL_LOG(context,
|
||||||
"Output type is %s, requires float, uint8 or int8.",
|
"Output type is %s, requires float, uint8 or int8.",
|
||||||
|
@ -106,6 +106,14 @@ TEST_P(ResizeNearestNeighborOpTest, HorizontalResizeInt8) {
|
|||||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||||
ElementsAreArray(ArrayFloatNear({-3, -3, 6})));
|
ElementsAreArray(ArrayFloatNear({-3, -3, 6})));
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, HorizontalResizeInt16) {
|
||||||
|
ResizeNearestNeighborOpModel m({TensorType_INT16, {1, 1, 2, 1}}, {1, 3},
|
||||||
|
GetParam());
|
||||||
|
m.SetInput<int16_t>({-3, 6});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int16_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({-3, -3, 6})));
|
||||||
|
}
|
||||||
TEST_P(ResizeNearestNeighborOpTest, VerticalResize) {
|
TEST_P(ResizeNearestNeighborOpTest, VerticalResize) {
|
||||||
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1},
|
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1},
|
||||||
GetParam());
|
GetParam());
|
||||||
@ -130,6 +138,14 @@ TEST_P(ResizeNearestNeighborOpTest, VerticalResizeInt8) {
|
|||||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, -9})));
|
ElementsAreArray(ArrayFloatNear({3, 3, -9})));
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, VerticalResizeInt16) {
|
||||||
|
ResizeNearestNeighborOpModel m({TensorType_INT16, {1, 2, 1, 1}}, {3, 1},
|
||||||
|
GetParam());
|
||||||
|
m.SetInput<int16_t>({3, -9});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int16_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 3, -9})));
|
||||||
|
}
|
||||||
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResize) {
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResize) {
|
||||||
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3},
|
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3},
|
||||||
GetParam());
|
GetParam());
|
||||||
@ -172,6 +188,20 @@ TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt8) {
|
|||||||
9, 9, 12, //
|
9, 9, 12, //
|
||||||
})));
|
})));
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt16) {
|
||||||
|
ResizeNearestNeighborOpModel m({TensorType_INT16, {1, 2, 2, 1}}, {3, 3},
|
||||||
|
GetParam());
|
||||||
|
m.SetInput<int16_t>({
|
||||||
|
3, -6, //
|
||||||
|
9, 12 //
|
||||||
|
});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 3, -6, //
|
||||||
|
3, 3, -6, //
|
||||||
|
9, 9, 12, //
|
||||||
|
})));
|
||||||
|
}
|
||||||
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
|
||||||
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3},
|
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3},
|
||||||
GetParam());
|
GetParam());
|
||||||
@ -284,6 +314,25 @@ TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
|
|||||||
12, 12, 16, //
|
12, 12, 16, //
|
||||||
})));
|
})));
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt16) {
|
||||||
|
ResizeNearestNeighborOpModel m({TensorType_INT16, {2, 2, 2, 1}}, {3, 3},
|
||||||
|
GetParam());
|
||||||
|
m.SetInput<int16_t>({
|
||||||
|
3, 6, //
|
||||||
|
9, -12, //
|
||||||
|
-4, 10, //
|
||||||
|
12, 16 //
|
||||||
|
});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 3, 6, //
|
||||||
|
3, 3, 6, //
|
||||||
|
9, 9, -12, //
|
||||||
|
-4, -4, 10, //
|
||||||
|
-4, -4, 10, //
|
||||||
|
12, 12, 16, //
|
||||||
|
})));
|
||||||
|
}
|
||||||
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) {
|
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) {
|
||||||
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3},
|
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3},
|
||||||
GetParam());
|
GetParam());
|
||||||
@ -342,6 +391,20 @@ TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) {
|
|||||||
10, 12, 10, 12, -14, 16, //
|
10, 12, 10, 12, -14, 16, //
|
||||||
})));
|
})));
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt16) {
|
||||||
|
ResizeNearestNeighborOpModel m({TensorType_INT16, {1, 2, 2, 2}}, {3, 3},
|
||||||
|
GetParam());
|
||||||
|
m.SetInput<int16_t>({
|
||||||
|
3, 4, -6, 10, //
|
||||||
|
10, 12, -14, 16, //
|
||||||
|
});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 4, 3, 4, -6, 10, //
|
||||||
|
3, 4, 3, 4, -6, 10, //
|
||||||
|
10, 12, 10, 12, -14, 16, //
|
||||||
|
})));
|
||||||
|
}
|
||||||
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTest,
|
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTest,
|
||||||
ResizeNearestNeighborOpTest,
|
ResizeNearestNeighborOpTest,
|
||||||
testing::Values(TestType::kConst, TestType::kDynamic));
|
testing::Values(TestType::kConst, TestType::kDynamic));
|
||||||
|
@ -145,6 +145,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
|
|||||||
{{OperatorType::kResizeNearestNeighbor, 1}, "1.13.1"},
|
{{OperatorType::kResizeNearestNeighbor, 1}, "1.13.1"},
|
||||||
{{OperatorType::kResizeNearestNeighbor, 2}, "1.14.0"},
|
{{OperatorType::kResizeNearestNeighbor, 2}, "1.14.0"},
|
||||||
{{OperatorType::kResizeNearestNeighbor, 3}, kPendingReleaseOpVersion},
|
{{OperatorType::kResizeNearestNeighbor, 3}, kPendingReleaseOpVersion},
|
||||||
|
{{OperatorType::kResizeNearestNeighbor, 4}, kPendingReleaseOpVersion},
|
||||||
{{OperatorType::kSqueeze, 1}, "1.6.0"},
|
{{OperatorType::kSqueeze, 1}, "1.6.0"},
|
||||||
{{OperatorType::kSplit, 1}, "1.5.0"},
|
{{OperatorType::kSplit, 1}, "1.5.0"},
|
||||||
{{OperatorType::kSplit, 2}, "1.14.0"},
|
{{OperatorType::kSplit, 2}, "1.14.0"},
|
||||||
|
@ -868,13 +868,18 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
property.version = 1;
|
property.version = 1;
|
||||||
break;
|
break;
|
||||||
case BuiltinOperator_RESIZE_BILINEAR:
|
case BuiltinOperator_RESIZE_BILINEAR:
|
||||||
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
|
|
||||||
property.inputs = {{0, {}}};
|
property.inputs = {{0, {}}};
|
||||||
property.outputs = {{0, {}}};
|
property.outputs = {{0, {}}};
|
||||||
property.restrict_same_input_output_scale = true;
|
property.restrict_same_input_output_scale = true;
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
property.quantizable_int16 = false;
|
property.quantizable_int16 = false;
|
||||||
break;
|
break;
|
||||||
|
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
|
||||||
|
property.inputs = {{0, {}}};
|
||||||
|
property.outputs = {{0, {}}};
|
||||||
|
property.restrict_same_input_output_scale = true;
|
||||||
|
property.version = 2;
|
||||||
|
break;
|
||||||
case BuiltinOperator_SHAPE:
|
case BuiltinOperator_SHAPE:
|
||||||
property.inputs = {{0, {}}};
|
property.inputs = {{0, {}}};
|
||||||
// Shape has no quantizable output.
|
// Shape has no quantizable output.
|
||||||
|
@ -396,8 +396,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
|||||||
}
|
}
|
||||||
return 1;
|
return 1;
|
||||||
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
|
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
|
||||||
if (op_sig.options.resize.half_pixel_centers ||
|
if (op_sig.input_types.at(0) == TensorType_INT16) {
|
||||||
op_sig.options.resize.align_corners) {
|
return 4;
|
||||||
|
} else if (op_sig.options.resize.half_pixel_centers ||
|
||||||
|
op_sig.options.resize.align_corners) {
|
||||||
return 3;
|
return 3;
|
||||||
} else if (op_sig.input_types.at(0) == TensorType_INT8) {
|
} else if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -697,5 +697,13 @@ TEST(OpVersionTest, VersioningResizeNearestNeighborTest) {
|
|||||||
|
|
||||||
fake_op_sig.options.resize.align_corners = true;
|
fake_op_sig.options.resize.align_corners = true;
|
||||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||||
|
|
||||||
|
// int16 input is version 4.
|
||||||
|
fake_op_sig = {
|
||||||
|
.op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
|
||||||
|
.input_types = std::vector<TensorType>{TensorType_INT16, TensorType_INT32},
|
||||||
|
.output_types = std::vector<TensorType>{TensorType_INT16},
|
||||||
|
};
|
||||||
|
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
|
||||||
}
|
}
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -186,6 +186,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
|||||||
{{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 1}, "1.13.1"},
|
{{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 1}, "1.13.1"},
|
||||||
{{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 2}, "1.14.0"},
|
{{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 2}, "1.14.0"},
|
||||||
{{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 3}, "2.3.0"},
|
{{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 3}, "2.3.0"},
|
||||||
|
{{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 4}, kPendingReleaseVersion},
|
||||||
{{BuiltinOperator_RNN, 1}, "1.5.0"},
|
{{BuiltinOperator_RNN, 1}, "1.5.0"},
|
||||||
{{BuiltinOperator_RNN, 2}, "1.14.0"},
|
{{BuiltinOperator_RNN, 2}, "1.14.0"},
|
||||||
{{BuiltinOperator_RNN, 3}, "2.3.0"},
|
{{BuiltinOperator_RNN, 3}, "2.3.0"},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user