Added support for the NEAREST sampling type in OpenCL Resize.

Fixed compilation of cl_test.

PiperOrigin-RevId: 291945642
Change-Id: I99b25d4e5898aa97725c1382c818aaf854077570
Raman Sarokin 2020-01-28 09:09:06 -08:00 committed by TensorFlower Gardener
parent 770165c954
commit 7e124e7b66
3 changed files with 109 additions and 57 deletions
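
For orientation (not part of the patch): the new NEAREST path reads a single source texel at the truncated scaled coordinate, while the existing BILINEAR path reads a 2x2 neighborhood and blends it with the fractional part of the coordinate. Below is a minimal single-channel CPU sketch of the two modes, assuming scale_factor = src_size / dst_size (the align_corners == false case exercised by the new test); ResizeSample is a hypothetical helper, not delegate code.

    #include <algorithm>
    #include <vector>

    // Hypothetical single-channel reference for the two sampling modes; the
    // real kernel operates on FLT4 channel groups and clamps the bilinear
    // taps with `border` on the GPU.
    float ResizeSample(const std::vector<float>& src, int src_w, int src_h,
                       int dst_x, int dst_y, float scale_x, float scale_y,
                       bool nearest) {
      if (nearest) {
        // NEAREST: truncate the scaled coordinate and read one texel.
        const int x = static_cast<int>(dst_x * scale_x);
        const int y = static_cast<int>(dst_y * scale_y);
        return src[y * src_w + x];
      }
      // BILINEAR: read the 2x2 neighborhood and blend by the fractional offsets.
      const float fx = dst_x * scale_x;
      const float fy = dst_y * scale_y;
      const int x0 = static_cast<int>(fx);
      const int y0 = static_cast<int>(fy);
      const int x1 = std::min(x0 + 1, src_w - 1);  // clamp to the border
      const int y1 = std::min(y0 + 1, src_h - 1);
      const float tx = fx - static_cast<float>(x0);
      const float ty = fy - static_cast<float>(y0);
      const float top = src[y0 * src_w + x0] * (1.0f - tx) + src[y0 * src_w + x1] * tx;
      const float bot = src[y1 * src_w + x0] * (1.0f - tx) + src[y1 * src_w + x1] * tx;
      return top * (1.0f - ty) + bot * ty;
    }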
tensorflow/lite/delegates/gpu/cl/kernels

@@ -30,8 +30,9 @@ Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
   std::vector<Tensor> src(src_cpu.size());
   for (int i = 0; i < src_cpu.size(); ++i) {
     auto src_shape = src_cpu[i].shape;
-    if (src_shape.b != 1 && !op_def.batch_support) {
-      return InvalidArgumentError("op_def.batch_support must be enabled");
+    if (src_shape.b != 1 && !op_def.IsBatchSupported()) {
+      return InvalidArgumentError(
+          "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context,
                                  *creation_context.device, src_shape,
@@ -43,8 +44,9 @@ Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
   std::vector<Tensor> dst(dst_cpu.size());
   for (int i = 0; i < dst_cpu.size(); ++i) {
     auto dst_shape = dst_sizes[i];
-    if (dst_shape.b != 1 && !op_def.batch_support) {
-      return InvalidArgumentError("op_def.batch_support must be enabled");
+    if (dst_shape.b != 1 && !op_def.IsBatchSupported()) {
+      return InvalidArgumentError(
+          "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context,
                                  *creation_context.device, dst_shape,

@@ -26,7 +26,7 @@ namespace cl {
 namespace {
 std::string GetResizeCode(
-    const OperationDef& op_def,
+    const OperationDef& op_def, SamplingType sampling_type,
     const std::vector<ElementwiseOperation*>& linked_operations) {
   TensorCodeGenerator src_tensor(
       "src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
@@ -58,26 +58,35 @@ std::string GetResizeCode(
     c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) "
          "return;\n";
   }
-  c += " float2 f_coords = (float2)(X, Y) * scale_factor;\n";
-  c += " int4 st;\n";
-  c += " st.xy = (int2)(f_coords.x, f_coords.y);\n";
-  c += " st.zw = min(st.xy + (int2)(1, 1), border);\n";
-  c += " float2 t = f_coords - (float2)(st.x, st.y);\n";
-  if (op_def.IsBatchSupported()) {
-    c += " st.x = st.x * src_size.w + B;\n";
-    c += " st.z = st.z * src_size.w + B;\n";
-    c += " X = X * dst_size.w + B;\n";
-  }
-  c += " float4 src0 = " + src_tensor.ReadAsFloatWHS("st.x", "st.y", "Z") +
-       ";\n";
-  c += " float4 src1 = " + src_tensor.ReadAsFloatWHS("st.z", "st.y", "Z") +
-       ";\n";
-  c += " float4 src2 = " + src_tensor.ReadAsFloatWHS("st.x", "st.w", "Z") +
-       ";\n";
-  c += " float4 src3 = " + src_tensor.ReadAsFloatWHS("st.z", "st.w", "Z") +
-       ";\n";
-  c += " FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), "
-       "t.y));\n";
+  if (sampling_type == SamplingType::NEAREST) {
+    c += " int2 coord = (int2)(X * scale_factor.x, Y * scale_factor.y);\n";
+    if (op_def.IsBatchSupported()) {
+      c += " coord.x = coord.x * src_size.w + B;\n";
+      c += " X = X * src_size.w + B;\n";
+    }
+    c += " FLT4 r0 = " + src_tensor.ReadWHS("coord.x", "coord.y", "Z") + ";\n";
+  } else {
+    c += " float2 f_coords = (float2)(X, Y) * scale_factor;\n";
+    c += " int4 st;\n";
+    c += " st.xy = (int2)(f_coords.x, f_coords.y);\n";
+    c += " st.zw = min(st.xy + (int2)(1, 1), border);\n";
+    c += " float2 t = f_coords - (float2)(st.x, st.y);\n";
+    if (op_def.IsBatchSupported()) {
+      c += " st.x = st.x * src_size.w + B;\n";
+      c += " st.z = st.z * src_size.w + B;\n";
+      c += " X = X * src_size.w + B;\n";
+    }
+    c += " float4 src0 = " + src_tensor.ReadAsFloatWHS("st.x", "st.y", "Z") +
+         ";\n";
+    c += " float4 src1 = " + src_tensor.ReadAsFloatWHS("st.z", "st.y", "Z") +
+         ";\n";
+    c += " float4 src2 = " + src_tensor.ReadAsFloatWHS("st.x", "st.w", "Z") +
+         ";\n";
+    c += " float4 src3 = " + src_tensor.ReadAsFloatWHS("st.z", "st.w", "Z") +
+         ";\n";
+    c += " FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), "
+         "t.y));\n";
+  }
   const LinkingContext context{"r0", "X", "Y", "Z"};
   c += PostProcess(linked_operations, context);
   c += " " + dst_tensor.WriteWHS("r0", "X", "Y", "Z");
@@ -86,7 +95,7 @@ std::string GetResizeCode(
 }
 std::string GetResize3DCode(
-    const OperationDef& op_def,
+    const OperationDef& op_def, SamplingType sampling_type,
     const std::vector<ElementwiseOperation*>& linked_operations) {
   TensorCodeGenerator src_tensor(
       "src_data",
@@ -125,34 +134,48 @@ std::string GetResize3DCode(
     c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) "
          "return;\n";
   }
-  c += " float4 f_coords = (float4)(X, Y, Z, 0) * scale_factor;\n";
-  c += " int4 start = (int4)(f_coords.x, f_coords.y, f_coords.z, 0);\n";
-  c += " int4 end = min(start + (int4)(1, 1, 1, 0), border);\n";
-  c += " float4 t = f_coords - (float4)(start.x, start.y, start.z, 0.0f);\n";
-  if (op_def.IsBatchSupported()) {
-    c += " start.x = start.x * batch_size + B;\n";
-    c += " end.x = end.x * batch_size + B;\n";
-    c += " X = X * batch_size + B;\n";
-  }
-  c += " float4 src0 = " +
-       src_tensor.ReadAsFloatWHDS("start.x", "start.y", "start.z", "S") + ";\n";
-  c += " float4 src1 = " +
-       src_tensor.ReadAsFloatWHDS("end.x", "start.y", "start.z", "S") + ";\n";
-  c += " float4 src2 = " +
-       src_tensor.ReadAsFloatWHDS("start.x", "end.y", "start.z", "S") + ";\n";
-  c += " float4 src3 = " +
-       src_tensor.ReadAsFloatWHDS("end.x", "end.y", "start.z", "S") + ";\n";
-  c += " float4 src4 = " +
-       src_tensor.ReadAsFloatWHDS("start.x", "start.y", "end.z", "S") + ";\n";
-  c += " float4 src5 = " +
-       src_tensor.ReadAsFloatWHDS("end.x", "start.y", "end.z", "S") + ";\n";
-  c += " float4 src6 = " +
-       src_tensor.ReadAsFloatWHDS("start.x", "end.y", "end.z", "S") + ";\n";
-  c += " float4 src7 = " +
-       src_tensor.ReadAsFloatWHDS("end.x", "end.y", "end.z", "S") + ";\n";
-  c += " float4 t0 = mix(mix(src0, src1, t.x), mix(src2, src3, t.x), t.y);\n";
-  c += " float4 t1 = mix(mix(src4, src5, t.x), mix(src6, src7, t.x), t.y);\n";
-  c += " FLT4 r0 = TO_FLT4(mix(t0, t1, t.z));\n";
+  if (sampling_type == SamplingType::NEAREST) {
+    c += " int4 coord = (int4)(X * scale_factor.x, Y * scale_factor.y, Z * "
+         "scale_factor.z, 0);\n";
+    if (op_def.IsBatchSupported()) {
+      c += " coord.x = coord.x * batch_size + B;\n";
+      c += " X = X * batch_size + B;\n";
+    }
+    c += " FLT4 r0 = " +
+         src_tensor.ReadWHDS("coord.x", "coord.y", "coord.z", "S") + ";\n";
+  } else {
+    c += " float4 f_coords = (float4)(X, Y, Z, 0) * scale_factor;\n";
+    c += " int4 start = (int4)(f_coords.x, f_coords.y, f_coords.z, 0);\n";
+    c += " int4 end = min(start + (int4)(1, 1, 1, 0), border);\n";
+    c += " float4 t = f_coords - (float4)(start.x, start.y, start.z, 0.0f);\n";
+    if (op_def.IsBatchSupported()) {
+      c += " start.x = start.x * batch_size + B;\n";
+      c += " end.x = end.x * batch_size + B;\n";
+      c += " X = X * batch_size + B;\n";
+    }
+    c += " float4 src0 = " +
+         src_tensor.ReadAsFloatWHDS("start.x", "start.y", "start.z", "S") +
+         ";\n";
+    c += " float4 src1 = " +
+         src_tensor.ReadAsFloatWHDS("end.x", "start.y", "start.z", "S") + ";\n";
+    c += " float4 src2 = " +
+         src_tensor.ReadAsFloatWHDS("start.x", "end.y", "start.z", "S") + ";\n";
+    c += " float4 src3 = " +
+         src_tensor.ReadAsFloatWHDS("end.x", "end.y", "start.z", "S") + ";\n";
+    c += " float4 src4 = " +
+         src_tensor.ReadAsFloatWHDS("start.x", "start.y", "end.z", "S") + ";\n";
+    c += " float4 src5 = " +
+         src_tensor.ReadAsFloatWHDS("end.x", "start.y", "end.z", "S") + ";\n";
+    c += " float4 src6 = " +
+         src_tensor.ReadAsFloatWHDS("start.x", "end.y", "end.z", "S") + ";\n";
+    c += " float4 src7 = " +
+         src_tensor.ReadAsFloatWHDS("end.x", "end.y", "end.z", "S") + ";\n";
+    c +=
+        " float4 t0 = mix(mix(src0, src1, t.x), mix(src2, src3, t.x), t.y);\n";
+    c +=
+        " float4 t1 = mix(mix(src4, src5, t.x), mix(src6, src7, t.x), t.y);\n";
+    c += " FLT4 r0 = TO_FLT4(mix(t0, t1, t.z));\n";
+  }
   const LinkingContext context{"r0", "X", "Y", "S"};
   c += PostProcess(linked_operations, context);
   c += " " + dst_tensor.WriteWHDS("r0", "X", "Y", "Z", "S");
@@ -179,10 +202,7 @@ Resize& Resize::operator=(Resize&& operation) {
 }
 Status Resize::Compile(const CreationContext& creation_context) {
-  if (attr_.type != SamplingType::BILINEAR) {
-    return InternalError("Only bilinear sampling is currently supported");
-  }
-  const auto code = GetResizeCode(definition_, linked_operations_);
+  const auto code = GetResizeCode(definition_, attr_.type, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
@@ -243,7 +263,8 @@ Resize3D& Resize3D::operator=(Resize3D&& operation) {
 }
 Status Resize3D::Compile(const CreationContext& creation_context) {
-  const auto code = GetResize3DCode(definition_, linked_operations_);
+  const auto code =
+      GetResize3DCode(definition_, attr_.type, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);


@@ -93,6 +93,35 @@ TEST_F(OpenCLOperationTest, ResizeBilinearNonAligned) {
   }
 }
+TEST_F(OpenCLOperationTest, ResizeNearest) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 1, 2, 1);
+  src_tensor.data = {1.0f, 2.0f};
+  Resize2DAttributes attr;
+  attr.align_corners = false;
+  attr.new_shape = HW(2, 4);
+  attr.type = SamplingType::NEAREST;
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-5f : 1e-2f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      Resize operation = CreateResize(op_def, attr);
+      ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
+                                    BHWC(1, 2, 4, 1), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data,
+                  Pointwise(FloatNear(eps),
+                            {1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 2.0f}));
+    }
+  }
+}
 } // namespace
 } // namespace cl
 } // namespace gpu
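
As a sanity check on the ResizeNearest expectation above (not part of the patch): with align_corners == false the scale is src/dst, i.e. 2/4 = 0.5 in width and 1/2 = 0.5 in height, so destination columns 0..3 truncate to source columns 0, 0, 1, 1 and both destination rows map to source row 0, giving {1, 1, 2, 2} twice. A small stand-alone C++ check of that arithmetic (hypothetical, not test code):

    #include <cassert>
    #include <vector>

    int main() {
      const std::vector<float> src = {1.0f, 2.0f};  // 1x2 input (H = 1, W = 2)
      const int src_w = 2, src_h = 1, dst_w = 4, dst_h = 2;
      // align_corners == false: scale = src / dst for each axis.
      const float scale_x = static_cast<float>(src_w) / dst_w;  // 0.5
      const float scale_y = static_cast<float>(src_h) / dst_h;  // 0.5
      std::vector<float> dst;
      for (int y = 0; y < dst_h; ++y) {
        for (int x = 0; x < dst_w; ++x) {
          const int sx = static_cast<int>(x * scale_x);  // 0, 0, 1, 1
          const int sy = static_cast<int>(y * scale_y);  // 0 for both rows
          dst.push_back(src[sy * src_w + sx]);
        }
      }
      const std::vector<float> expected = {1.0f, 1.0f, 2.0f, 2.0f,
                                           1.0f, 1.0f, 2.0f, 2.0f};
      assert(dst == expected);  // matches the Pointwise expectation in the test
      return 0;
    }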