Add OpenGL support for half_pixel_centers in Resize
PiperOrigin-RevId: 294571167 Change-Id: I4b9602d8f4b10d575994545e803f4ca6b0a3de55
This commit is contained in:
parent
21a7c41fd1
commit
da57baa3d4
@ -40,7 +40,7 @@ TFLite on GPU supports the following ops in 16-bit and 32-bit float precision:
|
|||||||
* `RELU v1`
|
* `RELU v1`
|
||||||
* `RELU6 v1`
|
* `RELU6 v1`
|
||||||
* `RESHAPE v1`
|
* `RESHAPE v1`
|
||||||
* `RESIZE_BILINEAR v1`
|
* `RESIZE_BILINEAR v1-3`
|
||||||
* `SOFTMAX v1`
|
* `SOFTMAX v1`
|
||||||
* `STRIDED_SLICE v1`
|
* `STRIDED_SLICE v1`
|
||||||
* `SUB v1`
|
* `SUB v1`
|
||||||
|
@ -1701,13 +1701,15 @@ class Resize2DOperationParser : public TFLiteOperationParser {
|
|||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
||||||
|
|
||||||
RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
|
RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
|
||||||
bool align_corners;
|
bool align_corners;
|
||||||
RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners));
|
RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners));
|
||||||
|
bool half_pixel_centers;
|
||||||
|
RETURN_IF_ERROR(GetHalfPixelCentersValue(tflite_node, &half_pixel_centers));
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1723,6 +1725,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
|
|||||||
|
|
||||||
Resize2DAttributes attr;
|
Resize2DAttributes attr;
|
||||||
RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
|
RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
|
||||||
|
RETURN_IF_ERROR(
|
||||||
|
GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
|
||||||
attr.type = sampling_type_;
|
attr.type = sampling_type_;
|
||||||
attr.new_shape.CopyAllDefinedAxis(
|
attr.new_shape.CopyAllDefinedAxis(
|
||||||
graph->FindOutputs(node->id)[0]->tensor.shape);
|
graph->FindOutputs(node->id)[0]->tensor.shape);
|
||||||
@ -1758,6 +1762,25 @@ class Resize2DOperationParser : public TFLiteOperationParser {
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
|
||||||
|
bool* half_pixel_centers) {
|
||||||
|
if (sampling_type_ == SamplingType::BILINEAR) {
|
||||||
|
const auto* tf_options = reinterpret_cast<TfLiteResizeBilinearParams*>(
|
||||||
|
tflite_node->builtin_data);
|
||||||
|
if (!tf_options) {
|
||||||
|
return InternalError("Missing tflite params for ResizeBilinear op");
|
||||||
|
}
|
||||||
|
if (tf_options->align_corners && tf_options->half_pixel_centers) {
|
||||||
|
return InternalError(
|
||||||
|
"If half_pixel_centers is True, align_corners must be False.");
|
||||||
|
}
|
||||||
|
*half_pixel_centers = tf_options->half_pixel_centers;
|
||||||
|
} else {
|
||||||
|
*half_pixel_centers = false;
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context,
|
Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node) {
|
const TfLiteNode* tflite_node) {
|
||||||
const auto* input = context->tensors + tflite_node->inputs->data[0];
|
const auto* input = context->tensors + tflite_node->inputs->data[0];
|
||||||
|
@ -370,6 +370,9 @@ struct Resize2DAttributes {
|
|||||||
// If true, the centers of the 4 corner pixels of the input and output tensors
|
// If true, the centers of the 4 corner pixels of the input and output tensors
|
||||||
// are aligned, preserving the values at the corner pixels. Defaults to false.
|
// are aligned, preserving the values at the corner pixels. Defaults to false.
|
||||||
bool align_corners = false;
|
bool align_corners = false;
|
||||||
|
// half_pixel_centers assumes pixels are of half the actual dimensions, and
|
||||||
|
// yields more accurate resizes. Only applicable to BILINEAR sampling.
|
||||||
|
bool half_pixel_centers = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO(b/147771327): rename to Resize3D
|
// TODO(b/147771327): rename to Resize3D
|
||||||
|
@ -80,14 +80,20 @@ class Resize : public NodeShader {
|
|||||||
|
|
||||||
std::string source;
|
std::string source;
|
||||||
if (attr.type == SamplingType::BILINEAR) {
|
if (attr.type == SamplingType::BILINEAR) {
|
||||||
source = R"(
|
if (attr.half_pixel_centers) {
|
||||||
vec2 coord = vec2(gid.xy) * $scale_factor$;
|
source = "vec2 coord = (vec2(gid.xy) + 0.5) * $scale_factor$ - 0.5;";
|
||||||
|
} else {
|
||||||
|
source = "vec2 coord = vec2(gid.xy) * $scale_factor$;";
|
||||||
|
}
|
||||||
|
source += R"(
|
||||||
|
vec2 coord_floor = floor(coord);
|
||||||
|
ivec2 icoord_floor = ivec2(coord_floor);
|
||||||
ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1);
|
ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1);
|
||||||
ivec4 st;
|
ivec4 st;
|
||||||
st.xy = ivec2(coord);
|
st.xy = max(icoord_floor, ivec2(0, 0));
|
||||||
st.zw = min(st.xy + ivec2(1, 1), borders);
|
st.zw = min(icoord_floor + ivec2(1, 1), borders);
|
||||||
|
|
||||||
vec2 t = coord - vec2(st.xy); //interpolating factors
|
vec2 t = coord - coord_floor; //interpolating factors
|
||||||
|
|
||||||
vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$;
|
vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$;
|
||||||
vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$;
|
vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$;
|
||||||
|
@ -103,6 +103,58 @@ TEST(ResizeTest, Bilinear2x2x1To4x4x1) {
|
|||||||
7.0, 8.0, 8.0, 6.0, 7.0, 8.0, 8.0}));
|
7.0, 8.0, 8.0, 6.0, 7.0, 8.0, 8.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ResizeTest, Bilinear2x2x1To3x3x1WithoutHalfPixel) {
|
||||||
|
TensorRef<BHWC> input;
|
||||||
|
input.type = DataType::FLOAT32;
|
||||||
|
input.ref = 0;
|
||||||
|
input.shape = BHWC(1, 2, 2, 1);
|
||||||
|
|
||||||
|
TensorRef<BHWC> output;
|
||||||
|
output.type = DataType::FLOAT32;
|
||||||
|
output.ref = 1;
|
||||||
|
output.shape = BHWC(1, 3, 3, 1);
|
||||||
|
|
||||||
|
Resize2DAttributes attr;
|
||||||
|
attr.align_corners = false;
|
||||||
|
attr.half_pixel_centers = false;
|
||||||
|
attr.new_shape = HW(3, 3);
|
||||||
|
attr.type = SamplingType::BILINEAR;
|
||||||
|
|
||||||
|
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input},
|
||||||
|
{output});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {1.0, 1.666666, 2.0, 2.333333, 3.0,
|
||||||
|
3.333333, 3.0, 3.666666, 4.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ResizeTest, Bilinear2x2x1To3x3x1WithHalfPixel) {
|
||||||
|
TensorRef<BHWC> input;
|
||||||
|
input.type = DataType::FLOAT32;
|
||||||
|
input.ref = 0;
|
||||||
|
input.shape = BHWC(1, 2, 2, 1);
|
||||||
|
|
||||||
|
TensorRef<BHWC> output;
|
||||||
|
output.type = DataType::FLOAT32;
|
||||||
|
output.ref = 1;
|
||||||
|
output.shape = BHWC(1, 3, 3, 1);
|
||||||
|
|
||||||
|
Resize2DAttributes attr;
|
||||||
|
attr.align_corners = false;
|
||||||
|
attr.half_pixel_centers = true;
|
||||||
|
attr.new_shape = HW(3, 3);
|
||||||
|
attr.type = SamplingType::BILINEAR;
|
||||||
|
|
||||||
|
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input},
|
||||||
|
{output});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6),
|
||||||
|
{1.0, 1.5, 2.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ResizeTest, Nearest1x2x1To2x4x1) {
|
TEST(ResizeTest, Nearest1x2x1To2x4x1) {
|
||||||
TensorRef<BHWC> input;
|
TensorRef<BHWC> input;
|
||||||
input.type = DataType::FLOAT32;
|
input.type = DataType::FLOAT32;
|
||||||
|
Loading…
Reference in New Issue
Block a user