diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc
index 5eb7d284ad1..4c0af17090e 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -51,39 +51,6 @@ namespace tflite {
 namespace gpu {
 namespace metal {
 namespace {
-
-std::vector<ComputeTaskDescriptor> SelectConvolution(
-    const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) {
-  // Special precise version, in case we cover dst_shape poorly with standard
-  // work group size.
-  auto gpu_type = GetGpuType();
-  bool a11_12 = gpu_type == GpuType::kA11 || gpu_type == GpuType::kA12;
-  const auto dst_shape = graph.FindOutputs(id)[0]->tensor.shape;
-  if (GetThreadsRatioUsualToPreciseConvolution(dst_shape) >= 1.2f) {
-    // Special version for PowerVR >= IPhone6S/SE
-    // Metal has bad driver for PowerVR in IPhone6, so for Iphone6 we should use
-    // default kernel with shared memory.
-    if ((gpu_type == GpuType::kA9 || gpu_type == GpuType::kA10) &&
-        CheckConvolutionPrecise1x1Support(attr)) {
-      return ConvolutionPrecise1x1PowerVR(id, input_id, output_id, attr,
-                                          options);
-    }
-    if (a11_12 && GetThreadsRatioUsualToPreciseConvolution(dst_shape) >= 1.2f) {
-      return ConvolutionPrecise(id, input_id, output_id, attr, options);
-    }
-  }
-  if (a11_12) {
-    if (CheckConvolution1x1Support(attr)) {
-      return Convolution1x1(id, input_id, output_id, attr, options);
-    } else {
-      return ConvolutionGeneric(id, input_id, output_id, attr, options);
-    }
-  } else {
-    return Convolution(id, input_id, output_id, attr, options);
-  }
-}
-
 std::vector<ComputeTaskDescriptor> SelectDepthWiseConv(
     int id, ValueId input_id, ValueId output_id,
     const DepthwiseConvolution2DAttributes& attr,
@@ -182,12 +149,14 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
                                    input_shapes);
       break;
     }
-    case OperationType::CONVOLUTION_2D:
-      *tasks = SelectConvolution(
-          graph, node_id, inputs[0], outputs[0],
-          absl::any_cast<Convolution2DAttributes>(node->operation.attributes),
-          options);
+    case OperationType::CONVOLUTION_2D: {
+      const auto dst_shape = graph.FindOutputs(node_id)[0]->tensor.shape;
+      auto attr =
+          absl::any_cast<Convolution2DAttributes>(node->operation.attributes);
+      *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
+                                  attr, options);
       break;
+    }
     case OperationType::CONVOLUTION_TRANSPOSED:
       *tasks = SelectConvolutionTransposed(
           node_id, inputs[0], outputs[0],
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 70d882bb05b..f22fe642ca3 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -127,6 +127,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
+        "//tensorflow/lite/delegates/gpu/metal:environment",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc
index 60ac73abfaa..73f152412a9 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc
@@ -30,519 +30,181 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { namespace metal { + +enum class WeightsUploadType { + LOCAL_MEM_BY_THREADS, + GLOBAL_MEM, + CONSTANT_MEM, +}; + +struct ConvParams { + int3 block_size; + int3 work_group_size; + int3 work_group_launch_order; + int src_depth_loop_size; + bool need_src_loop = true; + bool need_dst_loop = true; + bool linear_wh; + bool linear_whs; + WeightsUploadType weights_upload_type; + bool x_kernel_is_1; + bool y_kernel_is_1; +}; + namespace { int GetNumOutputSlices(int dst_channels) { const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); - if (dst_depth % 4 == 0) { + if (dst_depth % 4 == 0 || dst_depth >= 16) { return 4; - } else if (dst_depth % 2 == 0) { + } else if (dst_depth % 2 == 0 || dst_depth >= 4) { return 2; } else { return 1; } } -int GetSrcBatchSize(int dst_channels) { - const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); - if (dst_depth % 4 == 0) { - return 2; - } else if (dst_depth % 2 == 0) { - return 4; - } else { - return 8; - } -} - -std::string GetValuesDeclarationPart(int num_output_slices, bool is_1x1) { - std::string code; - for (int d = 0; d < num_output_slices; ++d) { - code += absl::Substitute(R"( - float4 sum$0 = float4(0.0f, 0.0f, 0.0f, 0.0f); - )", - d); - } - if (is_1x1) { - code += absl::Substitute(R"( - threadgroup FLT4 temp[32]; - device FLT4* f_offseted = weights + (gid.z + params.z_offset.x) * $0 * src_offset; - )", - num_output_slices * 4); - } else { - code += absl::Substitute(R"( - threadgroup FLT4 temp[32]; - device FLT4* f_offseted = weights + (gid.z + params.z_offset.x) * $0 * src_offset * - kernel_y * kernel_x; - )", - num_output_slices * 4); - } - return code; -} - -std::string GetLocalMemoryUploadPart() { - std::string code = R"( - BARRIER(mem_flags::mem_none); - temp[tid] = f_offseted[tid]; - f_offseted += 32; - BARRIER(mem_flags::mem_threadgroup); - )"; - return code; -} - -std::string GetSummationPart(int num_output_slices, int index) { - std::string code = R"( - { - const FLT4 src = src_buffer[src_address]; - src_address += params.dilation_layer_offsets.z; - )"; - for (int d = 0; d < num_output_slices; ++d) { - code += absl::Substitute(R"( - sum$6.x += dot(temp[$0 * $1 + $2], src) * multiplier; - sum$6.y += dot(temp[$0 * $1 + $3], src) * multiplier; - sum$6.z += dot(temp[$0 * $1 + $4], src) * multiplier; - sum$6.w += dot(temp[$0 * $1 + $5], src) * multiplier; - )", - index, num_output_slices * 4, d * 4 + 0, d * 4 + 1, - d * 4 + 2, d * 4 + 3, d); - } - code += "}"; - return code; -} - -std::string GetBiasReadingPart(int num_output_slices) { - std::string code = absl::Substitute(R"( - { - gid.z = (gid.z + params.z_offset.x) * $0; - BARRIER(mem_flags::mem_none); - if (tid < $0) { - temp[tid] = biases[gid.z + tid]; - } - BARRIER(mem_flags::mem_threadgroup); - if (outside) { - return; - } - })", - num_output_slices); - return code; -} - -std::string GetWritingPart(int num_output_slices) { - std::string code; - for (int d = 0; d < num_output_slices; ++d) { - code += absl::Substitute(R"( - { - int dst_address = int(gid.y) * params.size.z + int(gid.x); - FLT4 value = FLT4(sum$0) + temp[$0]; - const int linear_index = gid.z * params.dilation_layer_offsets.w + dst_address; - $$2 - dst_buffer[linear_index + 
params.z_offset.y] = value; - gid.z += 1; - })", - d); - } - return code; -} - -std::string GetKernelForConv(const Convolution2DAttributes& params) { - const int num_output_slices = GetNumOutputSlices(params.weights.shape.o); - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. - const bool is_1x1 = - params.weights.shape.w == 1 && params.weights.shape.h == 1; - const bool is_strided = params.strides.w > 1 || params.strides.h > 1; - const int src_group_size = GetSrcBatchSize(params.weights.shape.o); - - const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); - const int src_groups = src_depth / src_group_size; - const int src_depth_aligned = AlignByN(src_depth, src_group_size); - const int reminder_src_depth = src_depth - src_groups * src_group_size; - - code = absl::Substitute(R"( - #include - using namespace metal; - constant int src_depth_groups = $0; - constant int src_offset = $1; - constant int kernel_x = $2; - constant int kernel_y = $3; - struct uniforms { - int4 stride_padding; - int4 dilation_layer_offsets; - int4 size; - int4 z_offset; - }; - $$0 - kernel void ComputeFunction( - $$1 - uint tid[[thread_index_in_threadgroup]], - uint3 gid[[thread_position_in_grid]]) - { - const bool outside = static_cast(gid.x) >= params.size.z || - static_cast(gid.y) >= params.size.w; - )", - src_groups, src_depth_aligned, params.weights.shape.w, - params.weights.shape.h); - code += GetValuesDeclarationPart(num_output_slices, is_1x1); - - if (!is_1x1) { - code += R"( - for(int ky = 0; ky < kernel_y; ++ky) { - for(int kx = 0; kx < kernel_x; ++kx) { - int2 coords = int2(gid.xy) * params.stride_padding.xy + int2(kx, ky) * - params.dilation_layer_offsets.xy - params.stride_padding.zw; - const bool el_outside = coords.x < 0 || coords.y < 0 || coords.x >= params.size.x || - coords.y >= params.size.y; - const FLT multiplier = el_outside ? 
0.0f : 1.0f; - )"; - } else { - code += "const FLT multiplier = 1.0f;\n"; - code += "int2 coords = int2(gid.xy)"; - if (is_strided) { - code += " * params.stride_padding.xy"; - } - code += ";\n"; - } - code += R"( - coords = clamp(coords, int2(0, 0), int2(params.size.x - 1, params.size.y - 1)); - int src_address = coords.y * params.size.x + coords.x; - for(int s = 0; s < src_depth_groups; ++s) { - )"; - code += GetLocalMemoryUploadPart(); - for (int sub_s = 0; sub_s < src_group_size; ++sub_s) { - code += GetSummationPart(num_output_slices, sub_s); - } - code += R"( - } - )"; - if (reminder_src_depth != 0) { - code += GetLocalMemoryUploadPart(); - for (int sub_s = 0; sub_s < reminder_src_depth; ++sub_s) { - code += GetSummationPart(num_output_slices, sub_s); - } - } - if (!is_1x1) { - code += R"( - } - } - )"; - } - code += GetBiasReadingPart(num_output_slices); - code += GetWritingPart(num_output_slices); - code += " }"; - return code; -} - -// Reorder weights to make the weights memory access pattern cache friendly for -// GPU -std::vector ReorderWeightsForConvShared( - const Convolution2DAttributes& params) { - const int dst_batch_size = GetNumOutputSlices(params.weights.shape.o) * 4; - const int src_batch_size = GetSrcBatchSize(params.weights.shape.o); - BHWC input_dimensions{params.weights.shape.o, params.weights.shape.h, - params.weights.shape.w, params.weights.shape.i}; - const int gpu_simd_size = dst_batch_size * src_batch_size; - const int weights_width = AlignByN(input_dimensions.c, gpu_simd_size); - const int weights_height = AlignByN(input_dimensions.b, dst_batch_size); - const int weights_channels = params.weights.shape.w * params.weights.shape.h; - const int weights_aligned_size = - weights_width * weights_height * weights_channels; - std::vector weights_reordered(weights_aligned_size); - float* destination = weights_reordered.data(); - const int dst_groups = - IntegralDivideRoundUp(input_dimensions.b, dst_batch_size); - const int src_sub_groups = - IntegralDivideRoundUp(input_dimensions.c, 4 * src_batch_size); - for (int group = 0; group < dst_groups; ++group) { - for (int y = 0; y < params.weights.shape.h; ++y) { - for (int x = 0; x < params.weights.shape.w; ++x) { - for (int sub_group = 0; sub_group < src_sub_groups; ++sub_group) { - for (int s = 0; s < src_batch_size; ++s) { - for (int d = 0; d < dst_batch_size; ++d) { - int output_index = group * dst_batch_size + d; - for (int i = 0; i < 4; ++i) { - int input_index = (sub_group * src_batch_size + s) * 4 + i; - if (input_index >= input_dimensions.c || - output_index >= input_dimensions.b) { - // Padding with zero - *destination++ = 0.0f; - } else { - int linear_index = - input_index + - input_dimensions.c * - (x + input_dimensions.w * - (y + input_dimensions.h * output_index)); - *destination++ = params.weights.data[linear_index]; - } - } - } - } - } - } - } - } - return weights_reordered; -} - -std::vector GetUniformBufferForConvShared( - const BHWC& input_dimensions, const BHWC& output_dimensions, - const Convolution2DAttributes& params) { - std::vector uniform_params = { - params.strides.w, - params.strides.h, - params.padding.prepended.w, - params.padding.prepended.h, - params.dilations.w, - params.dilations.h, - input_dimensions.w * input_dimensions.h, - output_dimensions.w * output_dimensions.h, - input_dimensions.w, - input_dimensions.h, - output_dimensions.w, - output_dimensions.h, - // TODO(chirkov): use z_offset for concat table optimization - /*z_offset.x=*/0, - /*z_offset.y=*/0, - /*z_offset.z=*/0, - 
/*z_offset.w=*/0, - }; - return GetByteBuffer(uniform_params); -} - -std::string GetKernelForConv1x1(const Convolution2DAttributes& params, - int z_out) { - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. - std::string channels[4] = {"x", "y", "z", "w"}; - code += R"( -#include -using namespace metal; - -struct uniforms { - int4 src_size; - int4 dst_size; - int4 stride_padding; - int4 kernel_dilation; - uint4 work_group_size; +struct GlobalIdsParams { + std::vector global_ids; + std::vector group_ids; + std::vector local_sizes; + std::vector local_ids; + int3 block_size; + int3 launch_order; + bool linear_wh; + bool linear_whs; + std::string task_size_w; // must be filled if linear_wh or linear_whs enabled + std::string task_size_wh; // must be filled if linear_whs enabled }; -$0 -kernel void ComputeFunction( - $1 - uint3 group_id[[threadgroup_position_in_grid]], - uint3 tid3d[[thread_position_in_threadgroup]]) -{ - int gid_x = group_id.y * params.work_group_size.x + tid3d.x; - int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) << 1u; - )"; - code += " int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " + - std::to_string(z_out) + "u;\n"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; +std::string GlobalIdsGen(const GlobalIdsParams& params) { + std::string c; + int3 launch_remap; + launch_remap[params.launch_order.x] = 0; + launch_remap[params.launch_order.y] = 1; + launch_remap[params.launch_order.z] = 2; + if (params.linear_whs) { + c += " int linear_whs = " + params.global_ids[0] + ";\n"; + c += " int Z = (linear_whs / " + params.task_size_wh + ") * " + + std::to_string(params.block_size.z) + ";\n"; + c += " int linear_wh = linear_whs % " + params.task_size_wh + ";\n"; + c += " int Y = (linear_wh / " + params.task_size_w + ") * " + + std::to_string(params.block_size.y) + ";\n"; + c += " int X = (linear_wh % " + params.task_size_w + ") * " + + std::to_string(params.block_size.x) + ";\n"; + } else if (params.linear_wh) { + if (params.launch_order.x == 0) { + c += " int linear_wh = " + params.global_ids[0] + ";\n"; + } else { + c += " int linear_wh = " + params.group_ids[launch_remap.x] + " * " + + params.local_sizes[0] + " + " + params.local_ids[0] + ";\n"; + } + c += " int Y = (linear_wh / " + params.task_size_w + ") * " + + std::to_string(params.block_size.y) + ";\n"; + c += " int X = (linear_wh % " + params.task_size_w + ") * " + + std::to_string(params.block_size.x) + ";\n"; + if (params.launch_order.y == 1) { + c += " int Z = " + params.global_ids[1] + " * " + + std::to_string(params.block_size.z) + ";\n"; + } else { + c += " int Z = (" + params.group_ids[launch_remap.y] + " * " + + params.local_sizes[1] + " + " + params.local_ids[1] + ") * " + + std::to_string(params.block_size.z) + ";\n"; + } + } else { + if (params.launch_order.x == 0) { + c += " int X = " + params.global_ids[0] + " * " + + std::to_string(params.block_size.x) + ";\n"; + } else { + c += " int X = (" + params.group_ids[launch_remap.x] + " * " + + params.local_sizes[0] + " + " + params.local_ids[0] + ") * " + + std::to_string(params.block_size.x) + ";\n"; + } + if (params.launch_order.y == 1) { + c += " int Y = " + params.global_ids[1] + " * " + + std::to_string(params.block_size.y) + ";\n"; + } else { + c += " int Y = (" + params.group_ids[launch_remap.y] + " * " + + 
params.local_sizes[1] + " + " + params.local_ids[1] + ") * " + + std::to_string(params.block_size.y) + ";\n"; + } + if (params.launch_order.z == 2) { + c += " int Z = " + params.global_ids[2] + " * " + + std::to_string(params.block_size.z) + ";\n"; + } else { + c += " int Z = (" + params.group_ids[launch_remap.z] + " * " + + params.local_sizes[2] + " + " + params.local_ids[2] + ") * " + + std::to_string(params.block_size.z) + ";\n"; + } } - code += R"( - device FLT4* tmp = filters + gid_z * 4 * params.src_size.w; - - int y0 = clamp(gid_y, 0, params.src_size.y - 1); - int y1 = clamp(gid_y + 1, 0, params.src_size.y - 1); - int x0 = clamp(gid_x, 0, params.src_size.x - 1); - - int s = 0; - - device FLT4* src_loc_0 = src_buffer + y0 * params.src_size.x + x0; - device FLT4* src_loc_1 = src_buffer + y1 * params.src_size.x + x0; - do { - FLT4 src_0 = *src_loc_0; - FLT4 src_1 = *src_loc_1; - src_loc_0 += params.src_size.z; - src_loc_1 += params.src_size.z; - )"; - for (int i = 0; i < z_out * 4; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + - " += dot(tmp[" + s_i + "], src_0);\n"; - code += " l" + std::to_string(i / 4) + "." + channels[i % 4] + - " += dot(tmp[" + s_i + "], src_1);\n"; - } - - code += " tmp += " + std::to_string(z_out * 4) + ";\n"; - code += R"( - s += 1; - } while (s < params.src_size.w); - const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x; - const int offset_1 = offset_0 + params.dst_size.x; - bool y0_in = gid_y < params.dst_size.y; - bool y1_in = gid_y + 1 < params.dst_size.y; - - device FLT4* bias_loc = biases + gid_z; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - } - code += R"( - if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) { - return; - } - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; - code += " if (y0_in) {\n"; - code += " FLT4 value = FLT4(r" + s_i + ");\n"; - code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " if (y1_in) {\n"; - code += " FLT4 value = FLT4(l" + s_i + ");\n"; - code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " }\n"; - } - code += " }\n"; - return code; + return c; } -std::string GetKernelForConvGeneric(const Convolution2DAttributes& params, - int z_out) { - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. 
- std::string channels[4] = {"x", "y", "z", "w"}; - code += R"( -#include -using namespace metal; - -struct uniforms { - int4 src_size; - int4 dst_size; - int4 stride_padding; - int4 kernel_dilation; - uint4 work_group_size; -}; -$0 - -kernel void ComputeFunction( - $1 - uint3 group_id[[threadgroup_position_in_grid]], - uint3 tid3d[[thread_position_in_threadgroup]]) -{ - int gid_x = group_id.y * params.work_group_size.x + tid3d.x; - int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) * 2; - )"; - code += " int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " + - std::to_string(z_out) + "u;\n"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; +std::string GenerateUploadByThreads(const std::string& local_ptr_name, + const std::string& global_ptr_name, + const std::string& global_offset_name, + const std::string& lid_name, + int total_work_items, + int elements_to_upload) { + std::string c; + std::string offset = + global_offset_name.empty() ? "" : global_offset_name + " + "; + const int groups = elements_to_upload / total_work_items; + const int reminder = elements_to_upload % total_work_items; + for (int i = 0; i < groups; ++i) { + c += " " + local_ptr_name + "[" + lid_name + " + " + + std::to_string(total_work_items * i) + "] = " + global_ptr_name + "[" + + offset + lid_name + " + " + std::to_string(total_work_items * i) + + "];\n"; } - code += R"( - device FLT4* tmp = filters + gid_z * 4 * params.src_size.w * params.kernel_dilation.x * params.kernel_dilation.y; - - int y0 = gid_y * params.stride_padding.y + params.stride_padding.w; - int y1 = (gid_y + 1) * params.stride_padding.y + params.stride_padding.w; - int x0 = gid_x * params.stride_padding.x + params.stride_padding.z; - - int y = 0; - do { - int coord_y0 = y * params.kernel_dilation.w + y0; - int coord_y1 = y * params.kernel_dilation.w + y1; - bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y; - bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y; - coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1); - coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1); - int x = 0; - do { - int coord_x0 = x * params.kernel_dilation.z + x0; - bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x; - coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1); - FLT m0 = !(y0_out || x0_out); - FLT m1 = !(y1_out || x0_out); - int s = 0; - device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0; - device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x0; - do { - FLT4 src_0 = *src_loc_0 * m0; - FLT4 src_1 = *src_loc_1 * m1; - src_loc_0 += params.src_size.z; - src_loc_1 += params.src_size.z; - )"; - for (int i = 0; i < z_out * 4; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + - " += dot(tmp[" + s_i + "], src_0);\n"; - code += " l" + std::to_string(i / 4) + "." 
+ channels[i % 4] + - " += dot(tmp[" + s_i + "], src_1);\n"; + if (reminder != 0) { + c += " if (" + lid_name + " < " + std::to_string(reminder) + ") {\n"; + c += " " + local_ptr_name + "[" + lid_name + " + " + + std::to_string(total_work_items * groups) + "] = " + global_ptr_name + + "[" + offset + lid_name + " + " + + std::to_string(total_work_items * groups) + "];\n"; + c += " }\n"; } - - code += " tmp += " + std::to_string(z_out * 4) + ";\n"; - code += R"( - s += 1; - } while (s < params.src_size.w); - x++; - } while (x < params.kernel_dilation.x); - y++; - } while (y < params.kernel_dilation.y); - const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x; - const int offset_1 = offset_0 + params.dst_size.x; - bool p0_in = gid_x < params.dst_size.x && gid_y < params.dst_size.y; - bool p1_in = gid_x < params.dst_size.x && gid_y + 1 < params.dst_size.y; - - device FLT4* bias_loc = biases + gid_z; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - } - code += R"( - if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) { - return; - } - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; - code += " if (p0_in) {\n"; - code += " FLT4 value = FLT4(r" + s_i + ");\n"; - code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " if (p1_in) {\n"; - code += " FLT4 value = FLT4(l" + s_i + ");\n"; - code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " }\n"; - } - code += " }\n"; - return code; + return c; } -std::string GetKernelForConvPrecise(int z_out) { +std::string GenerateConvolution(const ConvParams& params) { + GlobalIdsParams ids_params; + ids_params.group_ids = {"group_id.x", "group_id.y", "group_id.z"}; + ids_params.global_ids = {"ugid.x", "ugid.y", "ugid.z"}; + ids_params.local_ids = {"tid3d.x", "tid3d.y", "tid3d.z"}; + ids_params.local_sizes = {"params.work_group_size.x", + "params.work_group_size.y", + "params.work_group_size.z"}; + ids_params.linear_wh = params.linear_wh; + ids_params.task_size_w = "params.task_sizes.x"; + ids_params.task_size_wh = "params.task_sizes.y"; + ids_params.linear_whs = params.linear_whs; + ids_params.block_size = params.block_size; + ids_params.launch_order = params.work_group_launch_order; + + std::string addr_space = + params.weights_upload_type == WeightsUploadType::CONSTANT_MEM ? "constant" + : "device"; + const bool use_local_mem = + params.weights_upload_type == WeightsUploadType::LOCAL_MEM_BY_THREADS; + const int local_mem_size = + params.block_size.z * 4 * params.src_depth_loop_size; + + const bool use_filters_constants = + !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 && + params.y_kernel_is_1; + std::string channels[4] = {"x", "y", "z", "w"}; - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. - code += R"( + std::string c; + c.reserve(16 * 1024); // Reserve large enough buffer. 
+ c += R"( #include using namespace metal; @@ -551,209 +213,298 @@ struct uniforms { int4 dst_size; int4 stride_padding; int4 kernel_dilation; - int4 slices; + int4 task_sizes; + uint4 work_group_size; }; $0 kernel void ComputeFunction( $1 + uint tid[[thread_index_in_threadgroup]], + uint3 group_id[[threadgroup_position_in_grid]], + uint3 tid3d[[thread_position_in_threadgroup]], uint3 ugid[[thread_position_in_grid]]) { - int linear_id = ugid.x; - int gid_z = linear_id / params.slices.y; - int linear_xy = (linear_id - gid_z * params.slices.y) << 1; - )"; - code += " gid_z *= " + std::to_string(z_out) + ";\n"; - code += R"( - int gid_y0 = linear_xy / params.slices.x; - int gid_x0 = linear_xy - gid_y0 * params.slices.x; - linear_xy += 1; - int gid_y1 = linear_xy / params.slices.x; - int gid_x1 = linear_xy - gid_y1 * params.slices.x; - - if (gid_z >= params.dst_size.w) return; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - } - code += R"( - device FLT4* tmp = filters + gid_z * 4 * params.src_size.w * - params.kernel_dilation.x * params.kernel_dilation.y; - - int y0 = gid_y0 * params.stride_padding.y + params.stride_padding.w; - int y1 = gid_y1 * params.stride_padding.y + params.stride_padding.w; - int x0 = gid_x0 * params.stride_padding.x + params.stride_padding.z; - int x1 = gid_x1 * params.stride_padding.x + params.stride_padding.z; )"; - code += R"( - int y = 0; - do { - int coord_y0 = y * params.kernel_dilation.w + y0; - int coord_y1 = y * params.kernel_dilation.w + y1; - bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y; - bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y; - coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1); - coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1); - int x = 0; - do { - int coord_x0 = x * params.kernel_dilation.z + x0; - int coord_x1 = x * params.kernel_dilation.z + x1; - bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x; - bool x1_out = coord_x1 < 0 || coord_x1 >= params.src_size.x; - coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1); - coord_x1 = clamp(coord_x1, 0, params.src_size.x - 1); - FLT m0 = !(y0_out || x0_out); - FLT m1 = !(y1_out || x1_out); - device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0; - device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x1; - int s = 0; - do { - FLT4 src_0 = *src_loc_0 * m0; - FLT4 src_1 = *src_loc_1 * m1; - src_loc_0 += params.src_size.z; - src_loc_1 += params.src_size.z; -)"; - for (int i = 0; i < z_out * 4; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + - " += dot(tmp[" + s_i + "], src_0);\n"; - code += " l" + std::to_string(i / 4) + "." 
+ channels[i % 4] + - " += dot(tmp[" + s_i + "], src_1);\n"; + c += GlobalIdsGen(ids_params); + c += " if (Z >= params.dst_size.w) return;\n"; + if (!use_local_mem && !params.linear_whs) { + c += " if (X >= params.dst_size.x || Y >= params.dst_size.y) return;\n"; + } + for (int z = 0; z < params.block_size.z; ++z) { + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_i = + std::to_string(z) + std::to_string(y) + std::to_string(x); + c += + " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } + } + } + auto for_every_yx = + [&](std::function + lambda) { + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + const std::string s_yx = s_y + s_x; + c += lambda(s_yx, s_x, s_y, x, y) + "\n"; + } + } + }; + if (!use_filters_constants) { + std::string kern_x = + params.x_kernel_is_1 ? "" : " * params.kernel_dilation.x"; + std::string kern_y = + params.y_kernel_is_1 ? "" : " * params.kernel_dilation.y"; + std::string dst_offset = + params.need_dst_loop ? " + Z * 4 * params.src_size.w" : ""; + if (!params.need_dst_loop) { + c += " " + addr_space + " FLT4* tmp = filters;\n"; + } else { + c += " " + addr_space + + " FLT4* tmp = filters + Z * 4 * params.src_size.w" + kern_x + + kern_y + ";\n"; + } + } + if (!params.x_kernel_is_1) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + c += " int x" + s_x + " = (X + " + s_x + + ") * params.stride_padding.x + params.stride_padding.z;\n"; + } + } + if (!params.y_kernel_is_1) { + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + c += " int y" + s_y + " = (Y + " + s_y + + ") * params.stride_padding.y + params.stride_padding.w;\n"; + } + } + if (use_local_mem) { + c += " threadgroup FLT4 weights_cache[" + std::to_string(local_mem_size) + + "];\n"; + } + if (!params.y_kernel_is_1) { + c += " int y = 0;\n"; + c += " do {\n"; + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + c += " int c_y" + s_y + " = y * params.kernel_dilation.w + y" + s_y + + ";\n"; + c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y + + " >= params.src_size.y;\n"; + c += " c_y" + s_y + " = clamp(c_y" + s_y + + ", 0, params.src_size.y - 1);\n"; + } + } else { + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + c += " int c_y" + s_y + " = clamp(Y + " + s_y + + ", 0, params.src_size.y - 1);\n"; + } + } + if (!params.x_kernel_is_1) { + c += " int x = 0;\n"; + c += " do {\n"; + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + c += " int c_x" + s_x + " = x * params.kernel_dilation.z + x" + s_x + + ";\n"; + c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x + + " >= params.src_size.x;\n"; + c += " c_x" + s_x + " = clamp(c_x" + s_x + + ", 0, params.src_size.x - 1);\n"; + } + } else { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + c += " int c_x" + s_x + " = clamp(X + " + s_x + + ", 0, params.src_size.x - 1);\n"; + } + } + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + const std::string s_yx = s_y + s_x; + if (!params.y_kernel_is_1 && 
!params.x_kernel_is_1) { + c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x + "_out);\n"; + } else if (!params.y_kernel_is_1) { + c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n"; + } else if (!params.x_kernel_is_1) { + c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n"; + } + } + } + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_x = std::to_string(x); + const std::string s_yx = s_y + s_x; + c += " device FLT4* src_loc_" + s_yx + " = src_buffer + c_y" + s_y + + " * params.src_size.x + c_x" + s_x + ";\n"; + } + } + c += " int s = 0;\n"; + if (params.need_src_loop) { + c += " do {\n"; + } + if (use_local_mem) { + const int total_work_items = params.work_group_size.x * + params.work_group_size.y * + params.work_group_size.z; + c += " BARRIER(mem_flags::mem_none);\n"; + c += GenerateUploadByThreads("weights_cache", "tmp", + /*global_offset_name*/ "", "tid", + total_work_items, local_mem_size); + c += " BARRIER(mem_flags::mem_threadgroup);\n"; + } + auto declare_src = [&]() { + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_yx = std::to_string(y) + std::to_string(x); + c += " FLT4 src" + s_yx + ";\n"; + } + } + }; + auto read_src = [&]() { + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_yx = std::to_string(y) + std::to_string(x); + if (!params.y_kernel_is_1 || !params.x_kernel_is_1) { + c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx + ";\n"; + } else { + c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n"; + } + } + } + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + const std::string s_yx = std::to_string(y) + std::to_string(x); + c += " src_loc_" + s_yx + " += params.src_size.z;\n"; + } + } + }; + auto conv_core = [&](int offset) { + std::string name = use_local_mem ? "weights_cache" : "tmp"; + if (use_filters_constants) { + name = "filters"; + } + for (int z = 0; z < params.block_size.z; ++z) { + for (int ch = 0; ch < 4; ++ch) { + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + std::string s_id = std::to_string(y) + std::to_string(x); + std::string r_id = + std::to_string(z) + std::to_string(y) + std::to_string(x); + c += " r" + r_id + "." 
+ channels[ch] + " += dot(" + name + "[" + + std::to_string(z * 4 + ch + offset) + "], src" + s_id + ");\n"; + } + } + } + } + }; + declare_src(); + read_src(); + c += " s += 1;\n"; + conv_core(0); + for (int i = 1; i < params.src_depth_loop_size; ++i) { + read_src(); + conv_core(i * params.block_size.z * 4); + c += " s += 1;\n"; + } + if (!use_filters_constants) { + c += " tmp += " + + std::to_string(params.block_size.z * 4 * params.src_depth_loop_size) + + ";\n"; + } + if (params.need_src_loop) { + c += " } while (s < params.src_size.w);\n"; + } + if (!params.x_kernel_is_1) { + c += " x++;\n"; + c += " } while (x < params.kernel_dilation.x);\n"; + } + if (!params.y_kernel_is_1) { + c += " y++;\n"; + c += " } while (y < params.kernel_dilation.y);\n"; } - code += " tmp += " + std::to_string(z_out * 4) + ";\n"; - code += R"( - s += 1; - } while (s < params.src_size.w); - x++; - } while (x < params.kernel_dilation.x); - y++; - } while (y < params.kernel_dilation.y); - const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0; - const int offset_1 = gid_z * params.dst_size.z + gid_y1 * params.dst_size.x + gid_x1; - bool p0_in = gid_x0 < params.dst_size.x && gid_y0 < params.dst_size.y; - bool p1_in = gid_x1 < params.dst_size.x && gid_y1 < params.dst_size.y; + if (use_local_mem && !params.linear_whs) { + c += " if (X >= params.dst_size.x || Y >= params.dst_size.y) return;\n"; + } - device FLT4* bias_loc = biases + gid_z; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; - code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; + for_every_yx([](const std::string& s_yx, const std::string& s_x, + const std::string& s_y, int x, int y) { + return " const int offset_" + s_yx + " = Z * params.dst_size.z + (Y + " + + s_y + ") * params.dst_size.x + X + " + s_x + ";"; + }); + + std::string bias_name = "biases"; + if (params.need_dst_loop) { + c += " device FLT4* bias_loc = biases + Z;\n"; + bias_name = "bias_loc"; } - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; - code += " if (p0_in) {\n"; - code += " FLT4 value = FLT4(r" + s_i + ");\n"; - code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " if (p1_in) {\n"; - code += " FLT4 value = FLT4(l" + s_i + ");\n"; - code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + - ";\n"; - code += " uint3 gid = uint3(gid_x1, gid_y1, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - code += " }\n"; + for (int y = 0; y < params.block_size.y; ++y) { + for (int x = 0; x < params.block_size.x; ++x) { + for (int z = 0; z < params.block_size.z; ++z) { + std::string r_id = + std::to_string(z) + std::to_string(y) + std::to_string(x); + c += " r" + r_id + " += TO_ACCUM4_TYPE(" + bias_name + "[" + + std::to_string(z) + "]);\n"; + } + } } - code += " }\n"; - return code; + for (int z = 0; z < params.block_size.z; ++z) { + const std::string s_z = std::to_string(z); + c += " if (Z + " + s_z + " < params.dst_size.w) {\n"; + for (int y = 0; y < params.block_size.y; ++y) { + const std::string s_y = std::to_string(y); + for (int x = 0; x < params.block_size.x; ++x) { + const 
std::string s_x = std::to_string(x); + const std::string s_yx = s_y + s_x; + const std::string s_zyx = s_z + s_yx; + bool need_check_x = x >= 1; + bool need_check_y = y >= 1; + std::string check; + if (need_check_x) { + check += "(X + " + s_x + ") < params.dst_size.x"; + } + if (need_check_y) { + check += check.empty() ? "" : " && "; + check += "(Y + " + s_y + ") < params.dst_size.y"; + } + if (!check.empty()) { + c += " if (" + check + ") {\n"; + } else { + c += " {\n"; + } + c += " FLT4 value = FLT4(r" + s_zyx + ");\n"; + c += " int linear_index = offset_" + s_yx + + " + params.dst_size.z * " + s_z + ";\n"; + c += " uint3 gid = uint3(X + " + s_x + ", Y + " + s_y + ", Z + " + + s_z + ");\n"; + c += " $2\n"; + c += " dst_buffer[linear_index] = value;\n"; + c += " }\n"; + } + } + c += " }\n"; + } + c += "}\n"; + return c; } -std::string GetKernelForConvPrecise1x1PowerVR(int z_out) { - std::string channels[4] = {"x", "y", "z", "w"}; - std::string code; - code.reserve(16 * 1024); // Reserve large enough buffer. - code += R"( -#include -using namespace metal; - -struct uniforms { - int4 src_size; - int4 dst_size; - int4 slices; - int4 dummy0; -}; -$0 - -kernel void ComputeFunction( - $1 - uint3 ugid[[thread_position_in_grid]]) -{ - int linear_id = ugid.x; - int gid_z = linear_id / params.slices.y; - int linear_xy = linear_id - gid_z * params.slices.y; -)"; - code += " gid_z *= " + std::to_string(z_out) + ";\n"; - code += R"( - int gid_y0 = linear_xy / params.slices.x; - int gid_x0 = linear_xy - gid_y0 * params.slices.x; - - if (gid_z >= params.dst_size.w) return; -)"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " float4 r" + s_i + " = float4(0.0f, 0.0f, 0.0f, 0.0f);\n"; - } - code += R"( - device FLT4* tmp = filters + gid_z * 4 * params.src_size.w; - - device FLT4* src_loc_0 = src_buffer + gid_y0 * params.src_size.x + gid_x0; - int s = 0; - do { - FLT4 src_0 = *src_loc_0; - src_loc_0 += params.src_size.z; -)"; - for (int i = 0; i < z_out * 4; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + std::to_string(i / 4) + "." 
+ channels[i % 4] + - " += dot(tmp[" + s_i + "], src_0);\n"; - } - - code += " tmp += " + std::to_string(z_out * 4) + ";\n"; - code += R"( - s += 1; - } while (s < params.src_size.w); - const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0; - - device FLT4* bias_loc = biases + gid_z; - )"; - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " r" + s_i + " += float4(bias_loc[" + s_i + "]);\n"; - } - for (int i = 0; i < z_out; ++i) { - const std::string s_i = std::to_string(i); - code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; - code += " FLT4 value = FLT4(r" + s_i + ");\n"; - code += - " int linear_index = offset_0 + params.dst_size.z * " + s_i + ";\n"; - code += " uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n"; - code += " $2\n"; - code += " dst_buffer[linear_index] = value;\n"; - code += " }\n"; - } - code += " }\n"; - return code; -} - -// Reorder weights to make the weights memory access pattern cache friendly for -// Convolution1x1/ConvolutionGeneric std::vector ReorderWeightsForConv(const Convolution2DAttributes& params, int z_out) { const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); - std::vector weights_reordered(params.weights.shape.w * - params.weights.shape.h * dst_depth * 4 * - src_depth * 4); + std::vector weights_reordered( + params.weights.shape.w * params.weights.shape.h * + AlignByN(dst_depth, z_out) * 4 * src_depth * 4); int counter = 0; for (int d = 0; d < IntegralDivideRoundUp(dst_depth, z_out); ++d) { for (int y = 0; y < params.weights.shape.h; ++y) { @@ -768,7 +519,7 @@ std::vector ReorderWeightsForConv(const Convolution2DAttributes& params, dst_ch >= params.weights.shape.o) { weights_reordered[counter++] = 0.0f; } else { - const int f_index = + const size_t f_index = params.weights.shape.LinearIndex({dst_ch, y, x, src_ch}); weights_reordered[counter++] = params.weights.data[f_index]; } @@ -782,13 +533,12 @@ std::vector ReorderWeightsForConv(const Convolution2DAttributes& params, return weights_reordered; } -uint3 GetWorkGroupForConv() { return {8, 4, 1}; } -uint3 GetWorkGroupForConvPrecise() { return {32, 1, 1}; } - -std::vector GetUniformBufferForConv( - const BHWC& src_size, const BHWC& dst_size, - const Convolution2DAttributes& params) { - const int3 group_size = GetWorkGroupForConv(); +std::vector GetUniformBuffer(const BHWC& src_size, + const BHWC& dst_size, + const Convolution2DAttributes& attr, + const ConvParams& params) { + const int grid_x = IntegralDivideRoundUp(dst_size.w, params.block_size.x); + const int grid_y = IntegralDivideRoundUp(dst_size.h, params.block_size.y); std::vector uniform_params = { src_size.w, src_size.h, @@ -798,240 +548,280 @@ std::vector GetUniformBufferForConv( dst_size.h, dst_size.w * dst_size.h, IntegralDivideRoundUp(dst_size.c, 4), - params.strides.w, - params.strides.h, - -params.padding.prepended.w, - -params.padding.prepended.h, - params.weights.shape.w, - params.weights.shape.h, - params.dilations.w, - params.dilations.h, - group_size.x, - group_size.y, - group_size.z, - 1u, // dummy, for alignment + attr.strides.w, + attr.strides.h, + -attr.padding.prepended.w, + -attr.padding.prepended.h, + attr.weights.shape.w, + attr.weights.shape.h, + attr.dilations.w, + attr.dilations.h, + grid_x, + grid_x * grid_y, + 0, // dummy, for alignment + 0, // dummy, for alignment + params.work_group_size.x, + params.work_group_size.y, + 
params.work_group_size.z, + 0, // dummy, for alignment }; return GetByteBuffer(uniform_params); } -std::vector GetUniformBufferForConvPrecise( - const BHWC& src_size, const BHWC& dst_size, - const Convolution2DAttributes& params) { - std::vector uniform_params = { - src_size.w, - src_size.h, - src_size.w * src_size.h, - IntegralDivideRoundUp(src_size.c, 4), - dst_size.w, - dst_size.h, - dst_size.w * dst_size.h, - IntegralDivideRoundUp(dst_size.c, 4), - params.strides.w, - params.strides.h, - -params.padding.prepended.w, - -params.padding.prepended.h, - params.weights.shape.w, - params.weights.shape.h, - params.dilations.w, - params.dilations.h, - dst_size.w, - IntegralDivideRoundUp(dst_size.w * dst_size.h, 2), - 0u, // dummy, for alignment - 0u, // dummy, for alignment - }; - return GetByteBuffer(uniform_params); +int GetGroupsCount(const BHWC& dst_shape, const int3& wg_size, + const int3& block_size) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + + int grid_x = IntegralDivideRoundUp(dst_shape.w, block_size.x); + int grid_y = IntegralDivideRoundUp(dst_shape.h, block_size.y); + int grid_z = IntegralDivideRoundUp(dst_slices, block_size.z); + + return IntegralDivideRoundUp(grid_x, wg_size.x) * + IntegralDivideRoundUp(grid_y, wg_size.y) * + IntegralDivideRoundUp(grid_z, wg_size.z); } -std::vector GetUniformBufferForConvPrecise1x1( - const BHWC& src_size, const BHWC& dst_size, - const Convolution2DAttributes& params) { - std::vector uniform_params = { - src_size.w, - src_size.h, - src_size.w * src_size.h, - IntegralDivideRoundUp(src_size.c, 4), - dst_size.w, - dst_size.h, - dst_size.w * dst_size.h, - IntegralDivideRoundUp(dst_size.c, 4), - dst_size.w, - IntegralDivideRoundUp(dst_size.w * dst_size.h, 1), - 0u, // dummy, for alignment - 0u, // dummy, for alignment - 0u, // dummy, for alignment - 0u, // dummy, for alignment - 0u, // dummy, for alignment - 0u, // dummy, for alignment - }; - return GetByteBuffer(uniform_params); +int GetGroupsCountForLinearWH(const BHWC& dst_shape, const int3& wg_size, + const int3& block_size) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + + int grid_x = IntegralDivideRoundUp(dst_shape.w, block_size.x); + int grid_y = IntegralDivideRoundUp(dst_shape.h, block_size.y); + int grid_z = IntegralDivideRoundUp(dst_slices, block_size.z); + + return IntegralDivideRoundUp(grid_x * grid_y, wg_size.x) * + IntegralDivideRoundUp(grid_z, wg_size.y); } -uint3 GetGroupsCountForConv(const uint3& group_size, const BHWC& dst_shape) { - const int dst_depth = IntegralDivideRoundUp(dst_shape.c, 4); - int groups_x = IntegralDivideRoundUp(dst_shape.w, group_size.x); - int groups_y = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_shape.h, 2), - group_size.y); - const int z_out = GetNumOutputSlices(dst_shape.c); - int groups_z = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_depth, z_out), - group_size.z); - return {groups_x, groups_y, groups_z}; +int GetGroupsCountForLinearWHS(const BHWC& dst_shape, const int3& wg_size, + const int3& block_size) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + + int grid_x = IntegralDivideRoundUp(dst_shape.w, block_size.x); + int grid_y = IntegralDivideRoundUp(dst_shape.h, block_size.y); + int grid_z = IntegralDivideRoundUp(dst_slices, block_size.z); + + return IntegralDivideRoundUp(grid_x * grid_y * grid_z, wg_size.x); } -uint3 GetGroupsCountForConvPrecise(const uint3& group_size, - const BHWC& dst_shape, int xy_pixels) { - const int z_out = GetNumOutputSlices(dst_shape.c); - const int 
dst_depth = IntegralDivideRoundUp(dst_shape.c, 4); - int xy_size = IntegralDivideRoundUp(dst_shape.w * dst_shape.h, xy_pixels); - int z_size = IntegralDivideRoundUp(dst_depth, z_out); - int task_size = xy_size * z_size; - return {IntegralDivideRoundUp(task_size, group_size.x), 1, 1}; -} - -int GetConvolutionThreadsCount(const BHWC& dst_shape) { - const uint3 group_size = GetWorkGroupForConv(); - const uint3 groups_count = GetGroupsCountForConv(group_size, dst_shape); - return groups_count.x * groups_count.y * groups_count.z * group_size.x * - group_size.y * group_size.z; -} - -int GetConvolutionPreciseThreadsCount(const BHWC& dst_shape, int xy_pixels) { - const uint3 group_size = GetWorkGroupForConvPrecise(); - const uint3 groups_count = - GetGroupsCountForConvPrecise(group_size, dst_shape, xy_pixels); - return groups_count.x * groups_count.y * groups_count.z * group_size.x * - group_size.y * group_size.z; -} - -bool IsConv1x1(const Convolution2DAttributes& attr) { - return attr.weights.shape.h == 1 && attr.weights.shape.w == 1 && - attr.strides.h == 1 && attr.strides.w == 1 && attr.dilations.h == 1 && - attr.dilations.w == 1 && attr.padding.prepended.h == 0 && - attr.padding.prepended.w == 0 && attr.padding.appended.h == 0 && +bool IsKernelXIs1(const Convolution2DAttributes& attr) { + return attr.weights.shape.w == 1 && attr.strides.w == 1 && + attr.dilations.w == 1 && attr.padding.prepended.w == 0 && attr.padding.appended.w == 0; } +bool IsKernelYIs1(const Convolution2DAttributes& attr) { + return attr.weights.shape.h == 1 && attr.strides.h == 1 && + attr.dilations.h == 1 && attr.padding.prepended.h == 0 && + attr.padding.appended.h == 0; +} + +int GetMaximumPossibleWavesCount(const BHWC& dst_shape, GpuType gpu) { + if (gpu == GpuType::kA7 || gpu == GpuType::kA8) { + return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1}); + } else { + return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, {1, 1, 1}); + } +} + +int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) { + const int max_waves = GetMaximumPossibleWavesCount(dst_shape, gpu); + int base_threshold; + if (gpu == GpuType::kA7 || gpu == GpuType::kA8) { + base_threshold = 32; + } else if (gpu == GpuType::kA11) { + base_threshold = 48; + } else { + base_threshold = 64; + } + if (max_waves >= base_threshold * 4) { + return 8; + } else if (max_waves >= base_threshold * 2) { + return 4; + } else if (max_waves >= base_threshold) { + return 2; + } else { + return 1; + } +} + +ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr, + const BHWC& dst_shape, GpuType gpu) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4); + + ConvParams params; + params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; + params.x_kernel_is_1 = IsKernelXIs1(attr); + params.y_kernel_is_1 = IsKernelYIs1(attr); + params.src_depth_loop_size = 1; + params.block_size = int3(1, 1, 1); + params.linear_wh = false; + params.linear_whs = false; + params.work_group_launch_order = int3(0, 1, 2); + + int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu); + + if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) { + params.block_size.z = 4; + blk_total_size /= 4; + } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) { + params.block_size.z = 2; + blk_total_size /= 2; + } + if (blk_total_size >= 4) { + params.block_size.x = 2; + params.block_size.y = 2; + blk_total_size /= 4; + } else if 
(blk_total_size >= 2) { + if (dst_shape.w % 2 != 0 && dst_shape.h % 2 == 0) { + params.block_size.y = 2; + } else { + params.block_size.x = 2; + } + blk_total_size /= 2; + } + + params.work_group_size = params.block_size.x <= params.block_size.y + ? int3(8, 4, 1) + : int3(4, 8, 1); + + int g1 = GetGroupsCount(dst_shape, params.work_group_size, params.block_size); + int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, params.block_size); + int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, params.block_size); + + if (g2 < g1) { + params.linear_wh = true; + params.work_group_size = int3(32, 1, 1); + params.work_group_launch_order = int3(0, 1, 2); + } + float precise_threshold = 3.1f; + float precise_ratio = static_cast(g2) / static_cast(g3); + if (precise_ratio > precise_threshold) { + params.linear_wh = false; + params.linear_whs = true; + params.work_group_size = int3(32, 1, 1); + params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + } + + if (params.src_depth_loop_size == src_slices) { + params.need_src_loop = false; + } + if (params.block_size.z == dst_slices) { + params.need_dst_loop = false; + } + const bool use_filters_constants = + !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 && + params.y_kernel_is_1; + if (use_filters_constants) { + params.weights_upload_type = WeightsUploadType::CONSTANT_MEM; + } + + return params; +} + +ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr, + const BHWC& dst_shape, GpuType gpu) { + const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); + const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4); + int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu); + bool apple_gpu = gpu == GpuType::kA11 || gpu == GpuType::kA12; + int3 block_size = int3(1, 1, 1); + if (blk_total_size >= 2 && apple_gpu) { + if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) { + block_size.x = 2; + } else { + block_size.y = 2; + } + blk_total_size /= 2; + } + if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) { + block_size.z = 4; + blk_total_size /= 4; + } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) { + block_size.z = 2; + blk_total_size /= 2; + } + if (blk_total_size >= 4 && dst_slices == 3) { + block_size.z = 3; + blk_total_size /= 4; + } + + ConvParams params; + params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + params.x_kernel_is_1 = IsKernelXIs1(attr); + params.y_kernel_is_1 = IsKernelYIs1(attr); + params.src_depth_loop_size = 1; + params.block_size = block_size; + params.linear_wh = false; + params.linear_whs = false; + params.work_group_size = int3(8, 4, 1); + params.work_group_launch_order = int3(2, 0, 1); + int g1 = GetGroupsCount(dst_shape, {8, 4, 1}, block_size); + int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, block_size); + int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, block_size); + if (g2 < g1) { + params.linear_wh = true; + params.work_group_size = int3(32, 1, 1); + params.work_group_launch_order = int3(0, 1, 2); + } + float precise_threshold = gpu == GpuType::kA12 ? 
1.0f : 1.04f;
+  float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
+  if (precise_ratio > precise_threshold) {
+    params.linear_wh = false;
+    params.linear_whs = true;
+    params.work_group_size = int3(32, 1, 1);
+  }
+  int total_elements =
+      params.block_size.x * params.block_size.y * params.block_size.z;
+  if (total_elements == 1) {
+    if (src_slices % 4 == 0) {
+      params.src_depth_loop_size = 4;
+    } else if (src_slices % 2 == 0) {
+      params.src_depth_loop_size = 2;
+    }
+  } else if (total_elements == 2) {
+    if (src_slices % 2 == 0) {
+      params.src_depth_loop_size = 2;
+    }
+  }
+  if (params.src_depth_loop_size == src_slices) {
+    params.need_src_loop = false;
+  }
+  if (params.block_size.z == dst_slices) {
+    params.need_dst_loop = false;
+  }
+  const bool use_filters_constants =
+      !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
+      params.y_kernel_is_1;
+  if (use_filters_constants) {
+    params.weights_upload_type = WeightsUploadType::CONSTANT_MEM;
+  }
+
+  return params;
+}
+
+ConvParams GetConvParams(const Convolution2DAttributes& attr,
+                         const BHWC& dst_shape) {
+  auto gpu_type = GetGpuType();
+  if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) {
+    return GetConvParamsForA7A8(attr, dst_shape, gpu_type);
+  } else {
+    return GetConvParamsForA9AndHigher(attr, dst_shape, gpu_type);
+  }
+}
+
 }  // namespace
 
-std::vector<ComputeTaskDescriptorPtr> Convolution(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params, const RuntimeOptions& options) {
-  auto desc = std::make_shared<ComputeTaskDescriptor>();
-  desc->id = id;
-  desc->is_linkable = false;
-  desc->shader_source = GetKernelForConv(params);
-
-  desc->input_buffers = {
-      {input_id, "device FLT4* const src_buffer"},
-  };
-
-  desc->output_buffer = {
-      output_id, "device FLT4* dst_buffer",
-      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
-        return CalculateOutputShape(buffers.find(input_id)->second, params);
-      }};
-
-  auto weights_reordered = ReorderWeightsForConvShared(params);
-  desc->immutable_buffers = {
-      {"device FLT4* const weights",
-       GetByteBufferConverted(weights_reordered, options.storage_precision)},
-      {"device FLT4* const biases",
-       GetByteBufferConvertedResized(params.bias.data,
-                                     options.storage_precision,
-                                     params.weights.shape.o)},
-  };
-
-  desc->uniform_buffers = {
-      {"constant uniforms& params",
-       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
-         const auto& input_dimensions = buffers.find(input_id)->second;
-         const auto& output_dimensions = buffers.find(output_id)->second;
-         return GetUniformBufferForConvShared(input_dimensions,
-                                              output_dimensions, params);
-       }},
-  };
-
-  desc->resize_function = [output_id,
-                           params](const std::map<ValueId, BHWC>& buffers) {
-    const auto& output_dims = buffers.find(output_id)->second;
-    const int num_output_slices = GetNumOutputSlices(params.weights.shape.o);
-    const uint3 group_size{8, 4, 1};
-    int groups_x = IntegralDivideRoundUp(output_dims.w, group_size.x);
-    int groups_y = IntegralDivideRoundUp(output_dims.h, group_size.y);
-    const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4);
-    int groups_z = IntegralDivideRoundUp(dst_depth, num_output_slices);
-    return std::make_pair(group_size, uint3{groups_x, groups_y, groups_z});
-  };
-
-  return {desc};
-}
-
-std::vector<ComputeTaskDescriptorPtr> Convolution1x1(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params,
-    const metal::RuntimeOptions& options) {
-  auto desc = std::make_shared<ComputeTaskDescriptor>();
-  desc->id = id;
-  desc->is_linkable = false;
-  const int z_out = GetNumOutputSlices(params.weights.shape.o);
-  desc->shader_source = GetKernelForConv1x1(params, z_out);
-
-  desc->input_buffers = {
-      {input_id, "device FLT4* const src_buffer"},
-  };
-
-  desc->output_buffer = {
-      output_id, "device FLT4* dst_buffer",
-      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
-        auto out_shape =
-            CalculateOutputShape(buffers.find(input_id)->second, params);
-        return out_shape;
-      }};
-
-  auto weights_reordered = ReorderWeightsForConv(params, z_out);
-  desc->immutable_buffers = {
-      {"device FLT4* const filters",
-       GetByteBufferConverted(weights_reordered, options.storage_precision)},
-      {"device FLT4* const biases",
-       GetByteBufferConvertedResized(params.bias.data,
-                                     options.storage_precision,
-                                     params.weights.shape.o)},
-  };
-
-  desc->uniform_buffers = {
-      {"constant uniforms& params",
-       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
-         const auto& input_dimensions = buffers.find(input_id)->second;
-         const auto& output_dimensions = buffers.find(output_id)->second;
-         return GetUniformBufferForConv(input_dimensions, output_dimensions,
-                                        params);
-       }},
-  };
-
-  desc->resize_function = [output_id,
-                           params](const std::map<ValueId, BHWC>& buffers) {
-    const auto& output_dims = buffers.find(output_id)->second;
-    const uint3 group_size = GetWorkGroupForConv();
-    const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims);
-    return std::make_pair(
-        group_size, uint3{groups_count.z, groups_count.x, groups_count.y});
-  };
-
-  return {desc};
-}
-
-bool CheckConvolution1x1Support(const Convolution2DAttributes& attr) {
-  return IsConv1x1(attr);
-}
-
 std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params,
-    const metal::RuntimeOptions& options) {
+    int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
+    const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) {
+  ConvParams params = GetConvParams(attr, dst_shape);
+
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
-  const int z_out = GetNumOutputSlices(params.weights.shape.o);
-  desc->shader_source = GetKernelForConvGeneric(params, z_out);
+  desc->shader_source = GenerateConvolution(params);
 
   desc->input_buffers = {
       {input_id, "device FLT4* const src_buffer"},
   };
@@ -1039,160 +829,72 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
 
   desc->output_buffer = {
       output_id, "device FLT4* dst_buffer",
-      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
+      [input_id, attr](const std::map<ValueId, BHWC>& buffers) {
        auto out_shape =
-            CalculateOutputShape(buffers.find(input_id)->second, params);
+            CalculateOutputShape(buffers.find(input_id)->second, attr);
        return out_shape;
      }};
 
-  auto weights_reordered = ReorderWeightsForConv(params, z_out);
+  auto weights_reordered = ReorderWeightsForConv(attr, params.block_size.z);
+  std::string addr_space =
+      params.weights_upload_type == WeightsUploadType::CONSTANT_MEM
+          ? "constant"
+          : "device";
+  const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
   desc->immutable_buffers = {
-      {"device FLT4* const filters",
+      {addr_space + " FLT4* const filters",
       GetByteBufferConverted(weights_reordered, options.storage_precision)},
-      {"device FLT4* const biases",
-       GetByteBufferConvertedResized(params.bias.data,
-                                     options.storage_precision,
-                                     params.weights.shape.o)},
+      {addr_space + " FLT4* const biases",
+       GetByteBufferConvertedResized(
+           attr.bias.data, options.storage_precision,
+           AlignByN(dst_depth, params.block_size.z) * 4)},
   };
 
   desc->uniform_buffers = {
      {"constant uniforms& params",
-       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
-         const auto& input_dimensions = buffers.find(input_id)->second;
-         const auto& output_dimensions = buffers.find(output_id)->second;
-         return GetUniformBufferForConv(input_dimensions, output_dimensions,
-                                        params);
+       [input_id, output_id, attr,
+        params](const std::map<ValueId, BHWC>& buffers) {
+         const auto& src_shape = buffers.find(input_id)->second;
+         const auto& dst_shape = buffers.find(output_id)->second;
+         return GetUniformBuffer(src_shape, dst_shape, attr, params);
       }},
  };
 
   desc->resize_function = [output_id,
                            params](const std::map<ValueId, BHWC>& buffers) {
     const auto& output_dims = buffers.find(output_id)->second;
-    const uint3 group_size = GetWorkGroupForConv();
-    const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims);
-    return std::make_pair(
-        group_size, uint3{groups_count.z, groups_count.x, groups_count.y});
-  };
+    const int dst_slices = IntegralDivideRoundUp(output_dims.c, 4);
 
-  return {desc};
-}
+    int grid_x = IntegralDivideRoundUp(output_dims.w, params.block_size.x);
+    int grid_y = IntegralDivideRoundUp(output_dims.h, params.block_size.y);
+    int grid_z = IntegralDivideRoundUp(dst_slices, params.block_size.z);
 
-std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params,
-    const metal::RuntimeOptions& options) {
-  auto desc = std::make_shared<ComputeTaskDescriptor>();
-  desc->id = id;
-  desc->is_linkable = false;
-  const int z_out = GetNumOutputSlices(params.weights.shape.o);
-  desc->shader_source = GetKernelForConvPrecise(z_out);
-
-  desc->input_buffers = {
-      {input_id, "device FLT4* const src_buffer"},
-  };
-
-  desc->output_buffer = {
-      output_id, "device FLT4* dst_buffer",
-      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
-        auto out_shape =
-            CalculateOutputShape(buffers.find(input_id)->second, params);
-        return out_shape;
-      }};
-
-  auto weights_reordered = ReorderWeightsForConv(params, z_out);
-  desc->immutable_buffers = {
-      {"device FLT4* const filters",
-       GetByteBufferConverted(weights_reordered, options.storage_precision)},
-      {"device FLT4* const biases",
-       GetByteBufferConvertedResized(params.bias.data,
-                                     options.storage_precision,
-                                     params.weights.shape.o)},
-  };
-
-  desc->uniform_buffers = {
-      {"constant uniforms& params",
-       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
-         const auto& input_dimensions = buffers.find(input_id)->second;
-         const auto& output_dimensions = buffers.find(output_id)->second;
-         return GetUniformBufferForConvPrecise(input_dimensions,
-                                               output_dimensions, params);
-       }},
-  };
-
-  desc->resize_function = [output_id,
-                           params](const std::map<ValueId, BHWC>& buffers) {
-    const auto& output_dims = buffers.find(output_id)->second;
-    const uint3 group_size = GetWorkGroupForConvPrecise();
-    const uint3 groups_count =
-        GetGroupsCountForConvPrecise(group_size, output_dims, 2);
+    const uint3 group_size(params.work_group_size.x, params.work_group_size.y,
+                           params.work_group_size.z);
+    int3 wg;
+    uint3 groups_count;
+    if (params.linear_whs) {
+      wg.x = IntegralDivideRoundUp(grid_x * grid_y * grid_z,
+                                   params.work_group_size.x);
+      groups_count = uint3(wg.x, 1, 1);
+    } else if (params.linear_wh) {
+      wg.x = IntegralDivideRoundUp(grid_x * grid_y, params.work_group_size.x);
+      wg.y = IntegralDivideRoundUp(grid_z, params.work_group_size.y);
+      groups_count = uint3(wg[params.work_group_launch_order.x],
+                           wg[params.work_group_launch_order.y], 1);
+    } else {
+      wg.x = IntegralDivideRoundUp(grid_x, params.work_group_size.x);
+      wg.y = IntegralDivideRoundUp(grid_y, params.work_group_size.y);
+      wg.z = IntegralDivideRoundUp(grid_z, params.work_group_size.z);
+      groups_count = uint3(wg[params.work_group_launch_order.x],
+                           wg[params.work_group_launch_order.y],
+                           wg[params.work_group_launch_order.z]);
+    }
     return std::make_pair(group_size, groups_count);
   };
 
   return {desc};
 }
 
-float GetThreadsRatioUsualToPreciseConvolution(const BHWC& dst_shape) {
-  return static_cast<float>(GetConvolutionThreadsCount(dst_shape)) /
-         static_cast<float>(GetConvolutionPreciseThreadsCount(dst_shape, 2));
-}
-
-std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise1x1PowerVR(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params, const RuntimeOptions& options) {
-  auto desc = std::make_shared<ComputeTaskDescriptor>();
-  desc->id = id;
-  desc->is_linkable = false;
-  const int z_out = GetNumOutputSlices(params.weights.shape.o);
-  desc->shader_source = GetKernelForConvPrecise1x1PowerVR(z_out);
-
-  desc->input_buffers = {
-      {input_id, "device FLT4* const src_buffer"},
-  };
-
-  desc->output_buffer = {
-      output_id, "device FLT4* dst_buffer",
-      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
-        auto out_shape =
-            CalculateOutputShape(buffers.find(input_id)->second, params);
-        return out_shape;
-      }};
-
-  auto weights_reordered = ReorderWeightsForConv(params, z_out);
-  desc->immutable_buffers = {
-      {"device FLT4* const filters",
-       GetByteBufferConverted(weights_reordered, options.storage_precision)},
-      {"device FLT4* const biases",
-       GetByteBufferConvertedResized(params.bias.data,
-                                     options.storage_precision,
-                                     params.weights.shape.o)},
-  };
-
-  desc->uniform_buffers = {
-      {"constant uniforms& params",
-       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
-         const auto& input_dimensions = buffers.find(input_id)->second;
-         const auto& output_dimensions = buffers.find(output_id)->second;
-         return GetUniformBufferForConvPrecise1x1(input_dimensions,
-                                                  output_dimensions, params);
-       }},
-  };
-
-  desc->resize_function = [output_id,
-                           params](const std::map<ValueId, BHWC>& buffers) {
-    const auto& output_dims = buffers.find(output_id)->second;
-    const uint3 group_size = GetWorkGroupForConvPrecise();
-    const uint3 groups_count =
-        GetGroupsCountForConvPrecise(group_size, output_dims, 1);
-    return std::make_pair(group_size, groups_count);
-  };
-
-  return {desc};
-}
-
-bool CheckConvolutionPrecise1x1Support(const Convolution2DAttributes& attr) {
-  return IsConv1x1(attr);
-}
-
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h
index 692145678cb..2853631abe8 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h
@@ -27,67 +27,10 @@ namespace tflite {
 namespace gpu {
 namespace metal {
 
-std::vector<ComputeTaskDescriptorPtr> Convolution(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params,
-    const metal::RuntimeOptions& options);
-
-// Convolution for kernel 1x1
-// require:
-//   kernel_size = 1x1;
-//   padding prepended and appended = 0x0
-//   dilation = 1x1;
-//   stride = 1x1;
-// Works very good on A12 (IPhoneXS, etc).
-// Works good on A9/A10/A11 (IPhone6S, IPhone7, IPhoneX, etc).
-// Works bad on A7/A8 (IPhone5S, IPhone6, etc).
-std::vector<ComputeTaskDescriptorPtr> Convolution1x1(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params, const RuntimeOptions& options);
-
-// TODO(impjdi): Move it inside module.
-bool CheckConvolution1x1Support(const Convolution2DAttributes& attr);
-
-// This convolution pass all conv parameters (beside output_channels)
-// as dynamic arguments (uniform buffer) to kernel.
-// Depending on output_channels can be generated different kernels
-// Kernel can proceed 4/8/12/16 output channels per one thread.
-// 16 channels output is the fastest but the least flexible.
 std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
-    int id, ValueId input_id, ValueId output_id,
+    int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
     const Convolution2DAttributes& params, const RuntimeOptions& options);
 
-// This convolution makes more precise mapping of threads on elements.
-// For example, if we have output tensor 12x7 and work group = 8x4,
-// then we need 4 workgroups to cover this tensor in usual case.
-// But in general we have only 84 elements(12*7), and we can cover it with 3
-// workgroups of size 32. So this version of convolution use this precise
-// mapping.
-// But this convolution, due to some hardware limitations, doesn't work better
-// always. In general it works good on A12.
-// Each thread process 2 pixels in XY dimension and variable amount of pixels
-// in Z dimension(depends on dst_channels).
-std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params, const RuntimeOptions& options);
-
-// As previous, but specific for 1x1 and each thread process 1 pixel in XY
-// dimension.
-// This convolution for PowerVR in FP16 mode with FP32 accumulator
-// It will work in other modes also, but not with good performance
-std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise1x1PowerVR(
-    int id, ValueId input_id, ValueId output_id,
-    const Convolution2DAttributes& params, const RuntimeOptions& options);
-
-// TODO(impjdi): Move it inside module.
-bool CheckConvolutionPrecise1x1Support(const Convolution2DAttributes& attr);
-
-// This function calculates amount of threads that should be launched for
-// ConvolutionGeneric or Convolution1x1 (threads_count1) and amount of threads
-// that should be launched for ConvolutionPrecise (threads_count2) and returns
-// threads_count1 / threads_count2.
-float GetThreadsRatioUsualToPreciseConvolution(const BHWC& dst_shape);
-
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
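
The grid and work-group selection added to ConvolutionGeneric's resize_function is the subtlest part of the patch. Below is a minimal, self-contained C++ sketch of that dispatch math only; Int3, DivideRoundUp, the function name and the example shapes are local stand-ins introduced for illustration, not the tflite::gpu::metal types or API.

// Hypothetical stand-alone illustration (not part of the patch): mirrors how
// block_size / work_group_size and the linear_wh / linear_whs flags translate
// an output shape into Metal threadgroup counts.
#include <cstdio>

struct Int3 {
  int x, y, z;
  int& operator[](int i) { return i == 0 ? x : (i == 1 ? y : z); }
};

int DivideRoundUp(int n, int d) { return (n + d - 1) / d; }

Int3 GroupsCount(bool linear_whs, bool linear_wh, Int3 block_size,
                 Int3 work_group_size, Int3 launch_order, int dst_w, int dst_h,
                 int dst_slices) {
  // Each thread block covers block_size output elements per dimension.
  int grid_x = DivideRoundUp(dst_w, block_size.x);
  int grid_y = DivideRoundUp(dst_h, block_size.y);
  int grid_z = DivideRoundUp(dst_slices, block_size.z);
  Int3 wg{1, 1, 1};
  Int3 groups{1, 1, 1};
  if (linear_whs) {
    // Whole W*H*S grid flattened into a single dimension.
    groups.x = DivideRoundUp(grid_x * grid_y * grid_z, work_group_size.x);
  } else if (linear_wh) {
    // Spatial dimensions flattened, slices kept as the second dimension.
    wg.x = DivideRoundUp(grid_x * grid_y, work_group_size.x);
    wg.y = DivideRoundUp(grid_z, work_group_size.y);
    groups.x = wg[launch_order.x];
    groups.y = wg[launch_order.y];
  } else {
    // Plain 3D dispatch, permuted by the launch order.
    wg.x = DivideRoundUp(grid_x, work_group_size.x);
    wg.y = DivideRoundUp(grid_y, work_group_size.y);
    wg.z = DivideRoundUp(grid_z, work_group_size.z);
    groups.x = wg[launch_order.x];
    groups.y = wg[launch_order.y];
    groups.z = wg[launch_order.z];
  }
  return groups;
}

int main() {
  // 57x37 output with 96 channels (24 slices), 1x1x4 block, 8x4x1 work group,
  // identity launch order.
  Int3 block{1, 1, 4}, wg_size{8, 4, 1}, order{0, 1, 2};
  Int3 g = GroupsCount(/*linear_whs=*/false, /*linear_wh=*/false, block,
                       wg_size, order, /*dst_w=*/57, /*dst_h=*/37,
                       /*dst_slices=*/24);
  std::printf("groups = %d x %d x %d\n", g.x, g.y, g.z);  // groups = 8 x 10 x 6
  return 0;
}

For the example shape the 3D case dispatches 8 x 10 x 6 threadgroups; with linear_whs the same grid collapses into a single flattened dimension, which is what the precise-ratio heuristic above switches to when the usual tiling covers the output poorly.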