Reading of half_pixel_centers attribute for TfLiteResizeNearestNeighborParams in model_builder.
Added support of half_pixel_centers for ResizeNearest in GPU backends. PiperOrigin-RevId: 324134037 Change-Id: I52ee5b02f50f20a5e409bd01c90f1d1d332fa50a
This commit is contained in:
parent
3223ec93a0
commit
a904462bf4
@ -36,8 +36,7 @@ Resize& Resize::operator=(Resize&& operation) {
|
||||
}
|
||||
|
||||
std::string Resize::GetResizeCode(const OperationDef& op_def,
|
||||
SamplingType sampling_type,
|
||||
bool half_pixel_centers) {
|
||||
const Resize2DAttributes& attr) {
|
||||
auto src_desc = op_def.src_tensors[0];
|
||||
if (op_def.IsBatchSupported()) {
|
||||
src_desc.SetStateVar("BatchedWidth", "true");
|
||||
@ -69,16 +68,34 @@ std::string Resize::GetResizeCode(const OperationDef& op_def,
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
|
||||
"|| Z >= args.dst_tensor.Slices()) return;\n";
|
||||
}
|
||||
if (sampling_type == SamplingType::NEAREST) {
|
||||
c += " int2 coord = (int2)(X * args.scale_factor_x, Y * "
|
||||
"args.scale_factor_y);\n";
|
||||
if (attr.type == SamplingType::NEAREST) {
|
||||
std::string fxc;
|
||||
std::string fyc;
|
||||
if (attr.half_pixel_centers) {
|
||||
fxc = "(X + 0.5f) * args.scale_factor_x";
|
||||
fyc = "(Y + 0.5f) * args.scale_factor_y";
|
||||
} else {
|
||||
fxc = "X * args.scale_factor_x";
|
||||
fyc = "Y * args.scale_factor_y";
|
||||
}
|
||||
if (attr.align_corners) {
|
||||
fxc += " + 0.5f";
|
||||
fyc += " + 0.5f";
|
||||
}
|
||||
c += " int2 coord;\n";
|
||||
c += " coord.x = (int)(" + fxc + ");\n";
|
||||
c += " coord.y = (int)(" + fyc + ");\n";
|
||||
c += " coord.x = max(0, coord.x);\n";
|
||||
c += " coord.y = max(0, coord.y);\n";
|
||||
c += " coord.x = min(coord.x, args.border_x);\n";
|
||||
c += " coord.y = min(coord.y, args.border_y);\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
|
||||
c += " X = X * args.src_tensor.Batch() + B;\n";
|
||||
}
|
||||
c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, Z);\n";
|
||||
} else {
|
||||
if (half_pixel_centers) {
|
||||
if (attr.half_pixel_centers) {
|
||||
c += " float2 f_coords = ((float2)(X, Y) + 0.5f) * "
|
||||
"(float2)(args.scale_factor_x, args.scale_factor_y) - "
|
||||
"0.5f;\n";
|
||||
@ -111,8 +128,7 @@ std::string Resize::GetResizeCode(const OperationDef& op_def,
|
||||
}
|
||||
|
||||
absl::Status Resize::Compile(const CreationContext& creation_context) {
|
||||
std::string code =
|
||||
GetResizeCode(definition_, attr_.type, attr_.half_pixel_centers);
|
||||
std::string code = GetResizeCode(definition_, attr_);
|
||||
std::string element_wise_code;
|
||||
RETURN_IF_ERROR(
|
||||
MergeOperations(linked_operations_, &args_, &element_wise_code));
|
||||
@ -160,7 +176,7 @@ Resize3D& Resize3D::operator=(Resize3D&& operation) {
|
||||
}
|
||||
|
||||
std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
|
||||
SamplingType sampling_type) {
|
||||
const Resize3DAttributes& attr) {
|
||||
auto src_desc = op_def.src_tensors[0];
|
||||
if (op_def.IsBatchSupported()) {
|
||||
src_desc.SetStateVar("BatchedWidth", "true");
|
||||
@ -196,10 +212,34 @@ std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
|
||||
"|| Z >= args.dst_tensor.Depth()) return;\n";
|
||||
}
|
||||
if (sampling_type == SamplingType::NEAREST) {
|
||||
c += " int4 coord = (int4)(X * args.scale_factor_x, Y * "
|
||||
"args.scale_factor_y, Z * "
|
||||
"args.scale_factor_z, 0);\n";
|
||||
if (attr.type == SamplingType::NEAREST) {
|
||||
std::string fxc;
|
||||
std::string fyc;
|
||||
std::string fzc;
|
||||
if (attr.half_pixel_centers) {
|
||||
fxc = "(X + 0.5f) * args.scale_factor_x";
|
||||
fyc = "(Y + 0.5f) * args.scale_factor_y";
|
||||
fzc = "(Z + 0.5f) * args.scale_factor_z";
|
||||
} else {
|
||||
fxc = "X * args.scale_factor_x";
|
||||
fyc = "Y * args.scale_factor_y";
|
||||
fzc = "Z * args.scale_factor_z";
|
||||
}
|
||||
if (attr.align_corners) {
|
||||
fxc += " + 0.5f";
|
||||
fyc += " + 0.5f";
|
||||
fzc += " + 0.5f";
|
||||
}
|
||||
c += " int4 coord;\n";
|
||||
c += " coord.x = (int)(" + fxc + ");\n";
|
||||
c += " coord.y = (int)(" + fyc + ");\n";
|
||||
c += " coord.z = (int)(" + fzc + ");\n";
|
||||
c += " coord.x = max(0, coord.x);\n";
|
||||
c += " coord.y = max(0, coord.y);\n";
|
||||
c += " coord.z = max(0, coord.z);\n";
|
||||
c += " coord.x = min(coord.x, args.border_x);\n";
|
||||
c += " coord.y = min(coord.y, args.border_y);\n";
|
||||
c += " coord.z = min(coord.z, args.border_z);\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
|
||||
c += " X = X * args.src_tensor.Batch() + B;\n";
|
||||
@ -249,7 +289,7 @@ std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
|
||||
}
|
||||
|
||||
absl::Status Resize3D::Compile(const CreationContext& creation_context) {
|
||||
std::string code = GetResize3DCode(definition_, attr_.type);
|
||||
std::string code = GetResize3DCode(definition_, attr_);
|
||||
std::string element_wise_code;
|
||||
RETURN_IF_ERROR(
|
||||
MergeOperations(linked_operations_, &args_, &element_wise_code));
|
||||
|
@ -45,8 +45,7 @@ class Resize : public GPUOperation {
|
||||
: GPUOperation(definition), attr_(attr) {}
|
||||
|
||||
std::string GetResizeCode(const OperationDef& op_def,
|
||||
SamplingType sampling_type,
|
||||
bool half_pixel_centers);
|
||||
const Resize2DAttributes& attr);
|
||||
|
||||
Resize2DAttributes attr_;
|
||||
};
|
||||
@ -74,7 +73,7 @@ class Resize3D : public GPUOperation {
|
||||
: GPUOperation(definition), attr_(attr) {}
|
||||
|
||||
std::string GetResize3DCode(const OperationDef& op_def,
|
||||
SamplingType sampling_type);
|
||||
const Resize3DAttributes& attr);
|
||||
|
||||
Resize3DAttributes attr_;
|
||||
};
|
||||
|
@ -161,6 +161,7 @@ TEST_F(OpenCLOperationTest, ResizeNearest) {
|
||||
|
||||
Resize2DAttributes attr;
|
||||
attr.align_corners = false;
|
||||
attr.half_pixel_centers = false;
|
||||
attr.new_shape = HW(2, 4);
|
||||
attr.type = SamplingType::NEAREST;
|
||||
|
||||
@ -183,6 +184,66 @@ TEST_F(OpenCLOperationTest, ResizeNearest) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ResizeNearestAlignCorners) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 1);
|
||||
src_tensor.data = {3.0f, 6.0f, 9.0f, 12.0f};
|
||||
|
||||
Resize2DAttributes attr;
|
||||
attr.align_corners = true;
|
||||
attr.half_pixel_centers = false;
|
||||
attr.new_shape = HW(3, 3);
|
||||
attr.type = SamplingType::NEAREST;
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-5f : 1e-2f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
Resize operation = CreateResize(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 3, 3, 1), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {3.0f, 6.0f, 6.0f, 9.0f, 12.0f,
|
||||
12.0f, 9.0f, 12.0f, 12.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ResizeNearestHalfPixelCenters) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 1);
|
||||
src_tensor.data = {3.0f, 6.0f, 9.0f, 12.0f};
|
||||
|
||||
Resize2DAttributes attr;
|
||||
attr.align_corners = false;
|
||||
attr.half_pixel_centers = true;
|
||||
attr.new_shape = HW(3, 3);
|
||||
attr.type = SamplingType::NEAREST;
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-5f : 1e-2f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
Resize operation = CreateResize(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 3, 3, 1), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {3.0f, 6.0f, 6.0f, 9.0f, 12.0f,
|
||||
12.0f, 9.0f, 12.0f, 12.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -1638,7 +1638,9 @@ class Resize2DOperationParser : public TFLiteOperationParser {
|
||||
}
|
||||
*half_pixel_centers = tf_options->half_pixel_centers;
|
||||
} else {
|
||||
*half_pixel_centers = false;
|
||||
const TfLiteResizeNearestNeighborParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
*half_pixel_centers = tf_options->half_pixel_centers;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -388,8 +388,7 @@ struct Resize2DAttributes {
|
||||
// If true, the centers of the 4 corner pixels of the input and output tensors
|
||||
// are aligned, preserving the values at the corner pixels. Defaults to false.
|
||||
bool align_corners = false;
|
||||
// half_pixel_centers assumes pixels are of half the actual dimensions, and
|
||||
// yields more accurate resizes. Only applicable to BILINEAR sampling.
|
||||
|
||||
bool half_pixel_centers = false;
|
||||
};
|
||||
|
||||
@ -402,8 +401,7 @@ struct Resize3DAttributes {
|
||||
// If true, the centers of the 8 corner pixels of the input and output tensors
|
||||
// are aligned, preserving the values at the corner pixels. Defaults to false.
|
||||
bool align_corners = false;
|
||||
// half_pixel_centers assumes pixels are of half the actual dimensions, and
|
||||
// yields more accurate resizes. Only applicable to BILINEAR sampling.
|
||||
|
||||
bool half_pixel_centers = false;
|
||||
};
|
||||
|
||||
|
@ -97,8 +97,27 @@ class Resize : public NodeShader {
|
||||
|
||||
value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y);)";
|
||||
} else if (attr.type == SamplingType::NEAREST) {
|
||||
source = R"(
|
||||
ivec2 coord = ivec2(vec2(gid.xy) * $scale_factor$);
|
||||
std::string fxc;
|
||||
std::string fyc;
|
||||
if (attr.half_pixel_centers) {
|
||||
fxc = "(float(gid.x) + 0.5) * $scale_factor.x$";
|
||||
fyc = "(float(gid.y) + 0.5) * $scale_factor.y$";
|
||||
} else {
|
||||
fxc = "float(gid.x) * $scale_factor.x$";
|
||||
fyc = "float(gid.y) * $scale_factor.y$";
|
||||
}
|
||||
if (attr.align_corners) {
|
||||
fxc += " + 0.5";
|
||||
fyc += " + 0.5";
|
||||
}
|
||||
source += " ivec2 coord;\n";
|
||||
source += " coord.x = int(" + fxc + ");\n";
|
||||
source += " coord.y = int(" + fyc + ");\n";
|
||||
source += " coord.x = max(0, coord.x);\n";
|
||||
source += " coord.y = max(0, coord.y);\n";
|
||||
source += " coord.x = min(coord.x, $input_data_0_w$ - 1);\n";
|
||||
source += " coord.y = min(coord.y, $input_data_0_h$ - 1);\n";
|
||||
source += R"(
|
||||
value_0 = $input_data_0[coord.x, coord.y, gid.z]$;
|
||||
)";
|
||||
} else {
|
||||
|
@ -180,6 +180,58 @@ TEST(ResizeTest, Nearest1x2x1To2x4x1) {
|
||||
Pointwise(FloatNear(1e-6), {1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0}));
|
||||
}
|
||||
|
||||
TEST(ResizeTest, NearestAlignCorners) {
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 3, 3, 1);
|
||||
|
||||
Resize2DAttributes attr;
|
||||
attr.align_corners = true;
|
||||
attr.half_pixel_centers = false;
|
||||
attr.new_shape = HW(3, 3);
|
||||
attr.type = SamplingType::NEAREST;
|
||||
|
||||
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input},
|
||||
{output});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {3.0f, 6.0f, 9.0f, 12.0f}));
|
||||
ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
|
||||
EXPECT_THAT(model.GetOutput(0),
|
||||
Pointwise(FloatNear(1e-6), {3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f,
|
||||
9.0f, 12.0f, 12.0f}));
|
||||
}
|
||||
|
||||
TEST(ResizeTest, NearestHalfPixelCenters) {
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 3, 3, 1);
|
||||
|
||||
Resize2DAttributes attr;
|
||||
attr.align_corners = false;
|
||||
attr.half_pixel_centers = true;
|
||||
attr.new_shape = HW(3, 3);
|
||||
attr.type = SamplingType::NEAREST;
|
||||
|
||||
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input},
|
||||
{output});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {3.0f, 6.0f, 9.0f, 12.0f}));
|
||||
ASSERT_OK(model.Invoke(*NewResizeNodeShader()));
|
||||
EXPECT_THAT(model.GetOutput(0),
|
||||
Pointwise(FloatNear(1e-6), {3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f,
|
||||
9.0f, 12.0f, 12.0f}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gl
|
||||
} // namespace gpu
|
||||
|
@ -31,7 +31,7 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
std::string GetResizeBilinearCode(bool half_pixel_centers) {
|
||||
std::string GetResizeBilinearCode(const Resize2DAttributes& attr) {
|
||||
std::string code = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
@ -42,7 +42,7 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) {
|
||||
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
|
||||
return;
|
||||
})";
|
||||
if (half_pixel_centers) {
|
||||
if (attr.half_pixel_centers) {
|
||||
code += "const float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
|
||||
} else {
|
||||
code += "const float2 tex_coord = float2(gid.xy) * scale;";
|
||||
@ -74,8 +74,8 @@ std::string GetResizeBilinearCode(bool half_pixel_centers) {
|
||||
return code;
|
||||
}
|
||||
|
||||
std::string GetResizeNearestCode() {
|
||||
return R"(
|
||||
std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
|
||||
std::string code = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
$0
|
||||
@ -85,7 +85,28 @@ std::string GetResizeNearestCode() {
|
||||
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
|
||||
return;
|
||||
}
|
||||
const int2 coord = int2(float2(gid.xy) * scale);
|
||||
)";
|
||||
std::string fxc;
|
||||
std::string fyc;
|
||||
if (attr.half_pixel_centers) {
|
||||
fxc = "(float(gid.x) + 0.5f) * scale.x";
|
||||
fyc = "(float(gid.y) + 0.5f) * scale.y";
|
||||
} else {
|
||||
fxc = "float(gid.x) * scale.x";
|
||||
fyc = "float(gid.y) * scale.y";
|
||||
}
|
||||
if (attr.align_corners) {
|
||||
fxc += " + 0.5f";
|
||||
fyc += " + 0.5f";
|
||||
}
|
||||
code += " int2 coord;\n";
|
||||
code += " coord.x = static_cast<int>(" + fxc + ");\n";
|
||||
code += " coord.y = static_cast<int>(" + fyc + ");\n";
|
||||
code += " coord.x = max(0, coord.x);\n";
|
||||
code += " coord.y = max(0, coord.y);\n";
|
||||
code += " coord.x = min(coord.x, size.x - 1);\n";
|
||||
code += " coord.y = min(coord.y, size.y - 1);\n";
|
||||
code += R"(
|
||||
const int src_index = (gid.z * size.y + coord.y) * size.x + coord.x;
|
||||
FLT4 value = src_buffer[src_index];
|
||||
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
|
||||
@ -93,6 +114,7 @@ std::string GetResizeNearestCode() {
|
||||
output_buffer[linear_index] = value;
|
||||
}
|
||||
)";
|
||||
return code;
|
||||
}
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> Resize(int id, ValueId input_id,
|
||||
@ -103,10 +125,10 @@ std::vector<ComputeTaskDescriptorPtr> Resize(int id, ValueId input_id,
|
||||
desc->is_linkable = false;
|
||||
switch (attr.type) {
|
||||
case SamplingType::BILINEAR:
|
||||
desc->shader_source = GetResizeBilinearCode(attr.half_pixel_centers);
|
||||
desc->shader_source = GetResizeBilinearCode(attr);
|
||||
break;
|
||||
case SamplingType::NEAREST:
|
||||
desc->shader_source = GetResizeNearestCode();
|
||||
desc->shader_source = GetResizeNearestCode(attr);
|
||||
break;
|
||||
default:
|
||||
// Unknown sampling type
|
||||
|
@ -196,4 +196,56 @@ using ::tflite::gpu::metal::SingleOpModel;
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testResizeNearestAlignCorners {
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 3, 3, 1);
|
||||
|
||||
Resize2DAttributes attr;
|
||||
attr.align_corners = true;
|
||||
attr.half_pixel_centers = false;
|
||||
attr.new_shape = HW(3, 3);
|
||||
attr.type = SamplingType::NEAREST;
|
||||
|
||||
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {3.0f, 6.0f, 9.0f, 12.0f}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
status = CompareVectors({3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f, 9.0f, 12.0f, 12.0f},
|
||||
model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testResizeNearestHalfPixelCenters {
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 3, 3, 1);
|
||||
|
||||
Resize2DAttributes attr;
|
||||
attr.align_corners = false;
|
||||
attr.half_pixel_centers = true;
|
||||
attr.new_shape = HW(3, 3);
|
||||
attr.type = SamplingType::NEAREST;
|
||||
|
||||
SingleOpModel model({ToString(OperationType::RESIZE), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {3.0f, 6.0f, 9.0f, 12.0f}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
status = CompareVectors({3.0f, 6.0f, 6.0f, 9.0f, 12.0f, 12.0f, 9.0f, 12.0f, 12.0f},
|
||||
model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
@end
|
||||
|
Loading…
x
Reference in New Issue
Block a user