From 2bf52af68fc5f71ae445ce0e2bf35458b1a00e8d Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Tue, 31 Mar 2020 12:52:58 -0700 Subject: [PATCH] Fixed accumulator precision for generic DepthWise implementation. Removed inlined constants for kernel sizes. Added test. PiperOrigin-RevId: 304026983 Change-Id: I4f9eac57ba1ec4e6f929d3ab7c9176f0d6f3b4ce --- .../gpu/metal/kernels/depthwise_conv.cc | 130 +++++++++--------- .../gpu/metal/kernels/depthwise_conv_test.mm | 39 ++++++ 2 files changed, 105 insertions(+), 64 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc index 9fa627bcac2..6c26a87c267 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc @@ -475,91 +475,93 @@ std::vector DepthWiseConvolution( std::string shader_source = R"( #include using namespace metal; - constant int kernel_x = $0; - constant int kernel_y = $1; struct uniforms { - int4 stride; - int4 padding; - int4 dilation; - int4 size; + int4 src_size; + int4 dst_size; + int2 stride; + int2 padding; + int2 dilation; + int2 kernel_size; int4 channel_multiplier; }; - $$0 + $0 kernel void ComputeFunction( - $$1 + $1 uint tid[[thread_index_in_threadgroup]], uint3 gid[[thread_position_in_grid]]) { - const bool outside = static_cast(gid.x) >= params.size.z || - static_cast(gid.y) >= params.size.w; - if (outside) { - return; - } - device FLT4* temp = filters + gid.z * kernel_y * kernel_x; - float4 sum0 = float4(0.0f, 0.0f, 0.0f, 0.0f); + int dst_x = static_cast(gid.x); + int dst_y = static_cast(gid.y); + int dst_z = static_cast(gid.z); - for(int ky = 0; ky < kernel_y; ++ky) { - for(int kx = 0; kx < kernel_x; ++kx) { - int2 coords = int2(gid.xy) * params.stride.xy + int2(kx, ky) * params.dilation.xy - - params.padding.xy; - const bool outside = coords.x < 0 || coords.y < 0 || - coords.x >= params.size.x || coords.y >= params.size.y; - if (outside) continue; + if (dst_x >= U.dst_size.x || dst_y >= U.dst_size.y) return; + + device FLT4* temp = filters + dst_z * U.kernel_size.x * U.kernel_size.y; + ACCUM_FLT4 sum0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f); + + int src_x = dst_x * U.stride.x + U.padding.x; + int src_y = dst_y * U.stride.y + U.padding.y; + + for(int ky = 0; ky < U.kernel_size.y; ++ky) { + int yc = ky * U.dilation.y + src_y; + if (yc < 0 || yc >= U.src_size.y) continue; + for(int kx = 0; kx < U.kernel_size.x; ++kx) { + int xc = kx * U.dilation.x + src_x; + if (xc < 0 || xc >= U.src_size.x) continue; )"; if (channels_multiplier == 1) { shader_source += R"( - const int src_layer = gid.z; - const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; - const FLT4 src_modified = src_buffer[src_index]; + int src_layer = dst_z; + int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc; + FLT4 src_modified = src_buffer[src_index]; )"; } else if (channels_multiplier == 2) { shader_source += R"( - const int src_layer = gid.z / 2; - const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; - const FLT4 src = src_buffer[src_index]; - const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw; - const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y); + int src_layer = dst_z / 2; + int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc; + FLT4 src = src_buffer[src_index]; + FLT2 t0 = dst_z % 2 == 0 ? src.xy : src.zw; + FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y); )"; } else if (channels_multiplier == 4) { shader_source += R"( - const int src_layer = gid.z / 4; - const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; - const FLT4 src = src_buffer[src_index]; - const FLT t0 = src[gid.z % 4]; - const FLT4 src_modified = FLT4(t0, t0, t0, t0); + int src_layer = dst_z / 4; + int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc; + FLT4 src = src_buffer[src_index]; + FLT t0 = src[dst_z % 4]; + FLT4 src_modified = FLT4(t0, t0, t0, t0); )"; } else { shader_source += R"( - const int src_layer = gid.z / params.channel_multiplier.x; - const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; - const FLT4 src = src_buffer[src_index]; + int src_layer = dst_z / U.channel_multiplier.x; + int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc; + FLT4 src = src_buffer[src_index]; FLT4 src_modified; - const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4; - src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x]; - src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x]; - src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x]; - src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x]; + const int src_layer_offset = (dst_z % U.channel_multiplier.x) * 4; + src_modified.x = src[(src_layer_offset + 0) / U.channel_multiplier.x]; + src_modified.y = src[(src_layer_offset + 1) / U.channel_multiplier.x]; + src_modified.z = src[(src_layer_offset + 2) / U.channel_multiplier.x]; + src_modified.w = src[(src_layer_offset + 3) / U.channel_multiplier.x]; )"; } shader_source += R"( - sum0 += float4(src_modified * temp[ky * kernel_x + kx]); + sum0 += TO_ACCUM4_TYPE(src_modified * temp[ky * U.kernel_size.x + kx]); } } - FLT4 res = FLT4(sum0 + float4(biases[gid.z])); - const int linear_index = (gid.z * params.size.w + int(gid.y)) * params.size.z + int(gid.x); + FLT4 res = FLT4(sum0) + biases[dst_z]; + const int linear_index = (dst_z * U.dst_size.y + dst_y) * U.dst_size.x + dst_x; FLT4 value = res; - $$2 - output_buffer[linear_index] = value; + $2 + dst_buffer[linear_index] = value; } )"; - desc->shader_source = absl::Substitute(shader_source, attr.weights.shape.w, - attr.weights.shape.h); + desc->shader_source = shader_source; desc->input_buffers = { {input_id, "device FLT4* const src_buffer"}, }; desc->output_buffer = { - output_id, "device FLT4* output_buffer", + output_id, "device FLT4* dst_buffer", [input_id, attr](const std::map& buffers) { auto out_shape = CalculateOutputShape(buffers.find(input_id)->second, attr); @@ -577,27 +579,27 @@ std::vector DepthWiseConvolution( }; desc->uniform_buffers = { - {"constant uniforms& params", + {"constant uniforms& U", [input_id, output_id, attr](const std::map& buffers) { const auto& dimension = buffers.find(input_id)->second; const auto& output_dimension = buffers.find(output_id)->second; std::vector uniform_params{ - attr.strides.w, - attr.strides.h, - 1, - 1, - attr.padding.prepended.w, - attr.padding.prepended.h, - 1, - 1, - attr.dilations.w, - attr.dilations.h, - 1, - 1, dimension.w, dimension.h, + IntegralDivideRoundUp(dimension.c, 4), + 0, output_dimension.w, output_dimension.h, + IntegralDivideRoundUp(output_dimension.c, 4), + 0, + attr.strides.w, + attr.strides.h, + -attr.padding.prepended.w, + -attr.padding.prepended.h, + attr.dilations.w, + attr.dilations.h, + attr.weights.shape.w, + attr.weights.shape.h, attr.weights.shape.o, 0, 0, diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm index d76507253a9..dcf550f7868 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm @@ -167,4 +167,43 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } +- (void)testShape2x2Kernel2x2 { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 2, 2, 1); + + DepthwiseConvolution2DAttributes attr; + Tensor bias; + bias.shape.v = 1; + bias.id = 1; + bias.data = {0}; + attr.bias = std::move(bias); + + Tensor weights; + weights.shape = OHWI(1, 2, 2, 1); + weights.id = 1; + weights.data = {1, 2, 3, 4}; + + attr.weights = std::move(weights); + + attr.dilations = HW(1, 1); + attr.padding.prepended = HW(0, 0); + attr.padding.appended = HW(1, 1); + attr.strides = HW(1, 1); + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 3; + output.shape = BHWC(1, 2, 2, 1); + + SingleOpModel model({ToString(OperationType::DEPTHWISE_CONVOLUTION), std::move(attr)}, {input}, + {output}); + XCTAssertTrue(model.PopulateTensor(0, {1, 4, 9, 16})); + auto status = model.Invoke(); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + status = CompareVectors({100, 52, 41, 16}, model.GetOutput(0), 1e-6f); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); +} + @end