Read the half_pixel_centers attribute from TfLiteResizeNearestNeighborParams in model_builder.

Added support for half_pixel_centers in ResizeNearest on the GPU backends.

PiperOrigin-RevId: 324134037
Change-Id: I52ee5b02f50f20a5e409bd01c90f1d1d332fa50a
This commit is contained in:
Raman Sarokin 2020-07-30 19:40:06 -07:00 committed by TensorFlower Gardener
parent 3223ec93a0
commit a904462bf4
9 changed files with 276 additions and 31 deletions

View File

@ -36,8 +36,7 @@ Resize& Resize::operator=(Resize&& operation) {
}
std::string Resize::GetResizeCode(const OperationDef& op_def,
SamplingType sampling_type,
bool half_pixel_centers) {
const Resize2DAttributes& attr) {
auto src_desc = op_def.src_tensors[0];
if (op_def.IsBatchSupported()) {
src_desc.SetStateVar("BatchedWidth", "true");
@ -69,16 +68,34 @@ std::string Resize::GetResizeCode(const OperationDef& op_def,
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
"|| Z >= args.dst_tensor.Slices()) return;\n";
}
if (sampling_type == SamplingType::NEAREST) {
c += " int2 coord = (int2)(X * args.scale_factor_x, Y * "
"args.scale_factor_y);\n";
if (attr.type == SamplingType::NEAREST) {
std::string fxc;
std::string fyc;
if (attr.half_pixel_centers) {
fxc = "(X + 0.5f) * args.scale_factor_x";
fyc = "(Y + 0.5f) * args.scale_factor_y";
} else {
fxc = "X * args.scale_factor_x";
fyc = "Y * args.scale_factor_y";
}
if (attr.align_corners) {
fxc += " + 0.5f";
fyc += " + 0.5f";
}
c += " int2 coord;\n";
c += " coord.x = (int)(" + fxc + ");\n";
c += " coord.y = (int)(" + fyc + ");\n";
c += " coord.x = max(0, coord.x);\n";
c += " coord.y = max(0, coord.y);\n";
c += " coord.x = min(coord.x, args.border_x);\n";
c += " coord.y = min(coord.y, args.border_y);\n";
if (op_def.IsBatchSupported()) {
c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
c += " X = X * args.src_tensor.Batch() + B;\n";
}
c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, Z);\n";
} else {
if (half_pixel_centers) {
if (attr.half_pixel_centers) {
c += " float2 f_coords = ((float2)(X, Y) + 0.5f) * "
"(float2)(args.scale_factor_x, args.scale_factor_y) - "
"0.5f;\n";
@ -111,8 +128,7 @@ std::string Resize::GetResizeCode(const OperationDef& op_def,
}
absl::Status Resize::Compile(const CreationContext& creation_context) {
std::string code =
GetResizeCode(definition_, attr_.type, attr_.half_pixel_centers);
std::string code = GetResizeCode(definition_, attr_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
@ -160,7 +176,7 @@ Resize3D& Resize3D::operator=(Resize3D&& operation) {
}
std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
SamplingType sampling_type) {
const Resize3DAttributes& attr) {
auto src_desc = op_def.src_tensors[0];
if (op_def.IsBatchSupported()) {
src_desc.SetStateVar("BatchedWidth", "true");
@ -196,10 +212,34 @@ std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
"|| Z >= args.dst_tensor.Depth()) return;\n";
}
if (sampling_type == SamplingType::NEAREST) {
c += " int4 coord = (int4)(X * args.scale_factor_x, Y * "
"args.scale_factor_y, Z * "
"args.scale_factor_z, 0);\n";
if (attr.type == SamplingType::NEAREST) {
std::string fxc;
std::string fyc;
std::string fzc;
if (attr.half_pixel_centers) {
fxc = "(X + 0.5f) * args.scale_factor_x";
fyc = "(Y + 0.5f) * args.scale_factor_y";
fzc = "(Z + 0.5f) * args.scale_factor_z";
} else {
fxc = "X * args.scale_factor_x";
fyc = "Y * args.scale_factor_y";
fzc = "Z * args.scale_factor_z";
}
if (attr.align_corners) {
fxc += " + 0.5f";
fyc += " + 0.5f";
fzc += " + 0.5f";
}
c += " int4 coord;\n";
c += " coord.x = (int)(" + fxc + ");\n";
c += " coord.y = (int)(" + fyc + ");\n";
c += " coord.z = (int)(" + fzc + ");\n";
c += " coord.x = max(0, coord.x);\n";
c += " coord.y = max(0, coord.y);\n";
c += " coord.z = max(0, coord.z);\n";
c += " coord.x = min(coord.x, args.border_x);\n";
c += " coord.y = min(coord.y, args.border_y);\n";
c += " coord.z = min(coord.z, args.border_z);\n";
if (op_def.IsBatchSupported()) {
c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
c += " X = X * args.src_tensor.Batch() + B;\n";
@ -249,7 +289,7 @@ std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
}
absl::Status Resize3D::Compile(const CreationContext& creation_context) {
std::string code = GetResize3DCode(definition_, attr_.type);
std::string code = GetResize3DCode(definition_, attr_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));

View File

@ -45,8 +45,7 @@ class Resize : public GPUOperation {
: GPUOperation(definition), attr_(attr) {}
std::string GetResizeCode(const OperationDef& op_def,
SamplingType sampling_type,
bool half_pixel_centers);
const Resize2DAttributes& attr);
Resize2DAttributes attr_;
};
@ -74,7 +73,7 @@ class Resize3D : public GPUOperation {
: GPUOperation(definition), attr_(attr) {}
std::string GetResize3DCode(const OperationDef& op_def,
SamplingType sampling_type);
const Resize3DAttributes& attr);
Resize3DAttributes attr_;
};

View File

@ -161,6 +161,7 @@ TEST_F(OpenCLOperationTest, ResizeNearest) {
Resize2DAttributes attr;
attr.align_corners = false;
attr.half_pixel_centers = false;
attr.new_shape = HW(2, 4);
attr.type = SamplingType::NEAREST;
@ -183,6 +184,66 @@ TEST_F(OpenCLOperationTest, ResizeNearest) {
}
}
// NEAREST resize of a 2x2 single-channel tensor to 3x3 with align_corners
// enabled and half_pixel_centers disabled. The operation is executed once per
// supported storage type and calculation precision.
TEST_F(OpenCLOperationTest, ResizeNearestAlignCorners) {
  TensorFloat32 input;
  input.shape = BHWC(1, 2, 2, 1);
  input.data = {3.0f, 6.0f, 9.0f, 12.0f};

  Resize2DAttributes resize_attr;
  resize_attr.type = SamplingType::NEAREST;
  resize_attr.new_shape = HW(3, 3);
  resize_attr.align_corners = true;
  resize_attr.half_pixel_centers = false;

  for (auto storage_type : env_.GetSupportedStorages()) {
    for (auto calc_precision : env_.GetSupportedPrecisions()) {
      // Any precision other than F32 is compared with a looser tolerance.
      float tolerance = 1e-2f;
      if (calc_precision == CalculationsPrecision::F32) {
        tolerance = 1e-5f;
      }

      // Source and destination descriptors share data type, storage and
      // layout.
      auto element_type = DeduceDataTypeFromPrecision(calc_precision);
      OperationDef def;
      def.precision = calc_precision;
      def.src_tensors.push_back({element_type, storage_type, Layout::HWC});
      def.dst_tensors.push_back({element_type, storage_type, Layout::HWC});

      TensorFloat32 output;
      Resize resize_op = CreateResize(def, resize_attr);
      ASSERT_OK(ExecuteGPUOperation(input, creation_context_, &resize_op,
                                    BHWC(1, 3, 3, 1), &output));
      // Expected 3x3 result, listed row by row.
      EXPECT_THAT(output.data,
                  Pointwise(FloatNear(tolerance),
                            {3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f, 9.0f, 12.0f,
                             12.0f}));
    }
  }
}
// NEAREST resize of a 2x2 single-channel tensor to 3x3 with half_pixel_centers
// enabled and align_corners disabled. The operation is executed once per
// supported storage type and calculation precision.
TEST_F(OpenCLOperationTest, ResizeNearestHalfPixelCenters) {
  TensorFloat32 input;
  input.shape = BHWC(1, 2, 2, 1);
  input.data = {3.0f, 6.0f, 9.0f, 12.0f};

  Resize2DAttributes resize_attr;
  resize_attr.type = SamplingType::NEAREST;
  resize_attr.new_shape = HW(3, 3);
  resize_attr.half_pixel_centers = true;
  resize_attr.align_corners = false;

  for (auto storage_type : env_.GetSupportedStorages()) {
    for (auto calc_precision : env_.GetSupportedPrecisions()) {
      // Any precision other than F32 is compared with a looser tolerance.
      float tolerance = 1e-2f;
      if (calc_precision == CalculationsPrecision::F32) {
        tolerance = 1e-5f;
      }

      // Source and destination descriptors share data type, storage and
      // layout.
      auto element_type = DeduceDataTypeFromPrecision(calc_precision);
      OperationDef def;
      def.precision = calc_precision;
      def.src_tensors.push_back({element_type, storage_type, Layout::HWC});
      def.dst_tensors.push_back({element_type, storage_type, Layout::HWC});

      TensorFloat32 output;
      Resize resize_op = CreateResize(def, resize_attr);
      ASSERT_OK(ExecuteGPUOperation(input, creation_context_, &resize_op,
                                    BHWC(1, 3, 3, 1), &output));
      // Expected 3x3 result, listed row by row.
      EXPECT_THAT(output.data,
                  Pointwise(FloatNear(tolerance),
                            {3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f, 9.0f, 12.0f,
                             12.0f}));
    }
  }
}
} // namespace
} // namespace cl
} // namespace gpu

View File

@ -1638,7 +1638,9 @@ class Resize2DOperationParser : public TFLiteOperationParser {
}
*half_pixel_centers = tf_options->half_pixel_centers;
} else {
*half_pixel_centers = false;
const TfLiteResizeNearestNeighborParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
*half_pixel_centers = tf_options->half_pixel_centers;
}
return absl::OkStatus();
}

View File

@ -388,8 +388,7 @@ struct Resize2DAttributes {
// If true, the centers of the 4 corner pixels of the input and output tensors
// are aligned, preserving the values at the corner pixels. Defaults to false.
bool align_corners = false;
// half_pixel_centers assumes pixels are of half the actual dimensions, and
// yields more accurate resizes. Applies to both BILINEAR and NEAREST sampling.
bool half_pixel_centers = false;
};
@ -402,8 +401,7 @@ struct Resize3DAttributes {
// If true, the centers of the 8 corner pixels of the input and output tensors
// are aligned, preserving the values at the corner pixels. Defaults to false.
bool align_corners = false;
// half_pixel_centers assumes pixels are of half the actual dimensions, and
// yields more accurate resizes. Not limited to BILINEAR sampling: NEAREST
// sampling honors it as well.
bool half_pixel_centers = false;
};

View File

@ -97,8 +97,27 @@ class Resize : public NodeShader {
value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y);)";
} else if (attr.type == SamplingType::NEAREST) {
source = R"(
ivec2 coord = ivec2(vec2(gid.xy) * $scale_factor$);
std::string fxc;
std::string fyc;
if (attr.half_pixel_centers) {
fxc = "(float(gid.x) + 0.5) * $scale_factor.x$";
fyc = "(float(gid.y) + 0.5) * $scale_factor.y$";
} else {
fxc = "float(gid.x) * $scale_factor.x$";
fyc = "float(gid.y) * $scale_factor.y$";
}
if (attr.align_corners) {
fxc += " + 0.5";
fyc += " + 0.5";
}
source += " ivec2 coord;\n";
source += " coord.x = int(" + fxc + ");\n";
source += " coord.y = int(" + fyc + ");\n";
source += " coord.x = max(0, coord.x);\n";
source += " coord.y = max(0, coord.y);\n";
source += " coord.x = min(coord.x, $input_data_0_w$ - 1);\n";
source += " coord.y = min(coord.y, $input_data_0_h$ - 1);\n";
source += R"(
value_0 = $input_data_0[coord.x, coord.y, gid.z]$;
)";
} else {

View File

@ -180,6 +180,58 @@ TEST(ResizeTest, Nearest1x2x1To2x4x1) {
Pointwise(FloatNear(1e-6), {1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0}));
}
// NEAREST resize of a 2x2 single-channel FLOAT32 tensor to 3x3 through the
// resize node shader, with align_corners enabled and half_pixel_centers
// disabled.
TEST(ResizeTest, NearestAlignCorners) {
  TensorRef<BHWC> src;
  src.ref = 0;
  src.type = DataType::FLOAT32;
  src.shape = BHWC(1, 2, 2, 1);

  TensorRef<BHWC> dst;
  dst.ref = 2;
  dst.type = DataType::FLOAT32;
  dst.shape = BHWC(1, 3, 3, 1);

  Resize2DAttributes resize_attr;
  resize_attr.type = SamplingType::NEAREST;
  resize_attr.new_shape = HW(3, 3);
  resize_attr.align_corners = true;
  resize_attr.half_pixel_centers = false;

  SingleOpModel model({ToString(OperationType::RESIZE), resize_attr}, {src},
                      {dst});
  ASSERT_TRUE(model.PopulateTensor(0, {3.0f, 6.0f, 9.0f, 12.0f}));
  ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
  // Expected 3x3 result, listed row by row.
  EXPECT_THAT(
      model.GetOutput(0),
      Pointwise(FloatNear(1e-6),
                {3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f, 9.0f, 12.0f, 12.0f}));
}
// NEAREST resize of a 2x2 single-channel FLOAT32 tensor to 3x3 through the
// resize node shader, with half_pixel_centers enabled and align_corners
// disabled.
TEST(ResizeTest, NearestHalfPixelCenters) {
  TensorRef<BHWC> src;
  src.ref = 0;
  src.type = DataType::FLOAT32;
  src.shape = BHWC(1, 2, 2, 1);

  TensorRef<BHWC> dst;
  dst.ref = 2;
  dst.type = DataType::FLOAT32;
  dst.shape = BHWC(1, 3, 3, 1);

  Resize2DAttributes resize_attr;
  resize_attr.type = SamplingType::NEAREST;
  resize_attr.new_shape = HW(3, 3);
  resize_attr.half_pixel_centers = true;
  resize_attr.align_corners = false;

  SingleOpModel model({ToString(OperationType::RESIZE), resize_attr}, {src},
                      {dst});
  ASSERT_TRUE(model.PopulateTensor(0, {3.0f, 6.0f, 9.0f, 12.0f}));
  ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
  // Expected 3x3 result, listed row by row.
  EXPECT_THAT(
      model.GetOutput(0),
      Pointwise(FloatNear(1e-6),
                {3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f, 9.0f, 12.0f, 12.0f}));
}
} // namespace
} // namespace gl
} // namespace gpu

View File

@ -31,7 +31,7 @@ namespace tflite {
namespace gpu {
namespace metal {
std::string GetResizeBilinearCode(bool half_pixel_centers) {
std::string GetResizeBilinearCode(const Resize2DAttributes& attr) {
std::string code = R"(
#include <metal_stdlib>
using namespace metal;
@ -42,7 +42,7 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) {
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
return;
})";
if (half_pixel_centers) {
if (attr.half_pixel_centers) {
code += "const float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
} else {
code += "const float2 tex_coord = float2(gid.xy) * scale;";
@ -74,8 +74,8 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) {
return code;
}
std::string GetResizeNearestCode() {
return R"(
std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
std::string code = R"(
#include <metal_stdlib>
using namespace metal;
$0
@ -85,7 +85,28 @@ std::string GetResizeNearestCode() {
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
return;
}
const int2 coord = int2(float2(gid.xy) * scale);
)";
std::string fxc;
std::string fyc;
if (attr.half_pixel_centers) {
fxc = "(float(gid.x) + 0.5f) * scale.x";
fyc = "(float(gid.y) + 0.5f) * scale.y";
} else {
fxc = "float(gid.x) * scale.x";
fyc = "float(gid.y) * scale.y";
}
if (attr.align_corners) {
fxc += " + 0.5f";
fyc += " + 0.5f";
}
code += " int2 coord;\n";
code += " coord.x = static_cast<int>(" + fxc + ");\n";
code += " coord.y = static_cast<int>(" + fyc + ");\n";
code += " coord.x = max(0, coord.x);\n";
code += " coord.y = max(0, coord.y);\n";
code += " coord.x = min(coord.x, size.x - 1);\n";
code += " coord.y = min(coord.y, size.y - 1);\n";
code += R"(
const int src_index = (gid.z * size.y + coord.y) * size.x + coord.x;
FLT4 value = src_buffer[src_index];
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
@ -93,6 +114,7 @@ std::string GetResizeNearestCode() {
output_buffer[linear_index] = value;
}
)";
return code;
}
std::vector<ComputeTaskDescriptorPtr> Resize(int id, ValueId input_id,
@ -103,10 +125,10 @@ std::vector<ComputeTaskDescriptorPtr> Resize(int id, ValueId input_id,
desc->is_linkable = false;
switch (attr.type) {
case SamplingType::BILINEAR:
desc->shader_source = GetResizeBilinearCode(attr.half_pixel_centers);
desc->shader_source = GetResizeBilinearCode(attr);
break;
case SamplingType::NEAREST:
desc->shader_source = GetResizeNearestCode();
desc->shader_source = GetResizeNearestCode(attr);
break;
default:
// Unknown sampling type

View File

@ -196,4 +196,56 @@ using ::tflite::gpu::metal::SingleOpModel;
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
// NEAREST resize of a 2x2 single-channel FLOAT32 tensor to 3x3 with
// align_corners enabled and half_pixel_centers disabled; the output is
// compared against the expected values with a 1e-6 tolerance.
- (void)testResizeNearestAlignCorners {
  TensorRef<BHWC> src;
  src.ref = 0;
  src.type = DataType::FLOAT32;
  src.shape = BHWC(1, 2, 2, 1);

  TensorRef<BHWC> dst;
  dst.ref = 2;
  dst.type = DataType::FLOAT32;
  dst.shape = BHWC(1, 3, 3, 1);

  Resize2DAttributes resizeAttr;
  resizeAttr.type = SamplingType::NEAREST;
  resizeAttr.new_shape = HW(3, 3);
  resizeAttr.align_corners = true;
  resizeAttr.half_pixel_centers = false;

  SingleOpModel model({ToString(OperationType::RESIZE), resizeAttr}, {src}, {dst});
  XCTAssertTrue(model.PopulateTensor(0, {3.0f, 6.0f, 9.0f, 12.0f}));

  // Run the model first, then reuse the same status variable for the
  // output comparison.
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
  // Expected 3x3 result, listed row by row.
  status = CompareVectors({3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f, 9.0f, 12.0f, 12.0f},
                          model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
// NEAREST resize of a 2x2 single-channel FLOAT32 tensor to 3x3 with
// half_pixel_centers enabled and align_corners disabled; the output is
// compared against the expected values with a 1e-6 tolerance.
- (void)testResizeNearestHalfPixelCenters {
  TensorRef<BHWC> src;
  src.ref = 0;
  src.type = DataType::FLOAT32;
  src.shape = BHWC(1, 2, 2, 1);

  TensorRef<BHWC> dst;
  dst.ref = 2;
  dst.type = DataType::FLOAT32;
  dst.shape = BHWC(1, 3, 3, 1);

  Resize2DAttributes resizeAttr;
  resizeAttr.type = SamplingType::NEAREST;
  resizeAttr.new_shape = HW(3, 3);
  resizeAttr.half_pixel_centers = true;
  resizeAttr.align_corners = false;

  SingleOpModel model({ToString(OperationType::RESIZE), resizeAttr}, {src}, {dst});
  XCTAssertTrue(model.PopulateTensor(0, {3.0f, 6.0f, 9.0f, 12.0f}));

  // Run the model first, then reuse the same status variable for the
  // output comparison.
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
  // Expected 3x3 result, listed row by row.
  status = CompareVectors({3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f, 9.0f, 12.0f, 12.0f},
                          model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
@end