diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/concat.cc b/tensorflow/lite/delegates/gpu/metal/kernels/concat.cc
index 6fe1a17d6d2..c252ee0b348 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/concat.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/concat.cc
@@ -33,79 +33,99 @@ namespace gpu {
 namespace metal {
 namespace {
 
+bool IsAllChannelsX4(const std::vector<int>& channels) {
+  for (int channel : channels) {
+    if (channel % 4 != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
 std::string GetConcatZCode(const std::vector<int> channels) {
   const std::string postfix[] = {".x", ".y", ".z", ".w"};
-  const std::string postfix_2[] = {".x", ".xy", ".xyz", ""};
-  const std::string types[] = {"FLT", "FLT2", "FLT3", "FLT4"};
-  std::string code = R"(
+  std::string c = R"(
     #include <metal_stdlib>
     using namespace metal;
     struct uniforms {
      int4 src_size;
+      int4 dst_size;
     };
     $0
     kernel void ComputeFunction(
                                 $1
                                 uint2 ugid[[thread_position_in_grid]]) {
-      if (static_cast<int>(ugid.x) >= params.src_size.x ||
-          static_cast<int>(ugid.y) >= params.src_size.y) {
-        return;
+      int X = static_cast<int>(ugid.x);
+      int Y = static_cast<int>(ugid.y);
+      int Z = 0;
+      if (X >= U.dst_size.x || Y >= U.dst_size.y) return;
+
+      FLT4 value = FLT4(0.0f);
+      const int xy_offset = Y * U.src_size.x + X;
+      int linear_index = xy_offset;
+)";
+
+  if (IsAllChannelsX4(channels)) {
+    // When all channels % 4 == 0 we can read/assign/write FLT4 elements easily.
+    // Also it is easy to write a loop in this case, to prevent long kernel
+    // generation.
+    for (int i = 0; i < channels.size(); ++i) {
+      const int depth = IntegralDivideRoundUp(channels[i], 4);
+      const std::string src_buffer = "src_buffer" + std::to_string(i);
+      c += "  for (int i = 0; i < " + std::to_string(depth) + "; ++i) {\n";
+      c += "    int src_index = i * U.src_size.w + xy_offset;\n";
+      c += "    value = " + src_buffer + "[src_index];\n";
+      c += "    uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
+      c += "    $2\n";
+      c += "    dst_buffer[linear_index] = value;\n";
+      c += "    linear_index += U.src_size.w;\n";
+      c += "    Z++;\n";
+      c += "  }\n";
     }
-
-      FLT4 value = FLT4(0.0f);
-      const int xy_offset = int(ugid.y) * params.src_size.x + int(ugid.x);
-      int linear_index = xy_offset;
-  )";
-
-  int out_channel = 0;
-  int read_index = 0;
-  int dst_z = 0;
-  for (int i = 0; i < channels.size(); ++i) {
-    const int depth = IntegralDivideRoundUp(channels[i], 4);
-    code += "  {\n";
-    code += "  int src_address = xy_offset;\n";
-    for (int d = 0; d < depth; ++d) {
-      const int channels_in_group = std::min(4, channels[i] - d * 4);
-      const std::string temp_name = "t" + std::to_string(read_index);
-      code += "  " + types[channels_in_group - 1] + " " + temp_name + " = " +
-              "src_buffer" + std::to_string(i) + "[src_address]" +
-              postfix_2[channels_in_group - 1] + ";\n";
-      code += "  src_address += params.src_size.w;\n";
-      for (int c = 0; c < channels_in_group; ++c) {
-        if (channels_in_group == 1) {
-          code += "  value" + postfix[out_channel] + " = " + temp_name + ";\n";
-        } else {
-          code += "  value" + postfix[out_channel] + " = " + temp_name +
-                  postfix[c] += ";\n";
-        }
-        out_channel++;
-        if (out_channel == 4) {
-          out_channel = 0;
-          code += "  {\n";
-          code += "    uint3 gid = uint3(ugid.x, ugid.y, " +
-                  std::to_string(dst_z) + ");\n";
-          code += "    $2\n";
-          code += "    dst_buffer[linear_index] = value;\n";
-          code += "    linear_index += params.src_size.w;\n";
-          code += "  }\n";
-          dst_z++;
+  } else {
+    int out_channel = 0;
+    int read_index = 0;
+    int z = 0;
+    for (int i = 0; i < channels.size(); ++i) {
+      const int depth = IntegralDivideRoundUp(channels[i], 4);
+      const std::string src_buffer = "src_buffer" + std::to_string(i);
+      for (int d = 0; d < depth; ++d) {
+        const int channels_in_group = std::min(4, channels[i] - d * 4);
+        const std::string temp_name = "t" + std::to_string(read_index);
+        const std::string src_index =
+            std::to_string(d) + " * U.src_size.w + xy_offset";
+        c += "  FLT4 " + temp_name + " = " + src_buffer + "[" + src_index +
+             "];\n";
+        for (int ch = 0; ch < channels_in_group; ++ch) {
+          c += "  value" + postfix[out_channel] + " = ";
+          c += temp_name + postfix[ch] + ";\n";
+          out_channel++;
+          if (out_channel == 4) {
+            out_channel = 0;
+            c += "  {\n";
+            c += "    uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
+            c += "    $2\n";
+            c += "    dst_buffer[linear_index] = value;\n";
+            c += "    linear_index += U.src_size.w;\n";
+            c += "    Z++;\n";
+            c += "  }\n";
+            z++;
+          }
         }
+        read_index++;
       }
-      read_index++;
     }
-    code += "  }\n";
+    if (out_channel != 0) {
+      c += "  {\n";
+      c += "    uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
+      c += "    $2\n";
+      c += "    dst_buffer[linear_index] = value;\n";
+      c += "  }\n";
+    }
   }
-  if (out_channel != 0) {
-    code += "  {\n";
-    code += "    uint3 gid = uint3(ugid.x, ugid.y, " + std::to_string(dst_z) +
-            ");\n";
-    code += "    $2\n";
-    code += "    dst_buffer[linear_index] = value;\n";
-    code += "  }\n";
-  }
-  code += "}\n";
-  return code;
+  c += "}\n";
+  return c;
 }
 
 }  // namespace
@@ -141,26 +161,33 @@ std::vector<ComputeTaskDescriptorPtr> ConcatZ(
       }};
 
   desc->uniform_buffers = {
-      {"constant uniforms& params",
-       [input_ids](const std::map<ValueId, BHWC>& buffers) {
-         const auto& dimension = buffers.find(input_ids[0])->second;
+      {"constant uniforms& U",
+       [input_ids, output_id](const std::map<ValueId, BHWC>& buffers) {
+         const auto& src_shape = buffers.find(input_ids[0])->second;
+         const auto& dst_shape = buffers.find(output_id)->second;
         std::vector<int> uniform_params{
-             dimension.w,
-             dimension.h,
-             0,
-             dimension.w * dimension.h,
+             src_shape.w,
+             src_shape.h,
+             IntegralDivideRoundUp(src_shape.c, 4),
+             src_shape.w * src_shape.h,
+             dst_shape.w,
+             dst_shape.h,
+             IntegralDivideRoundUp(dst_shape.c, 4),
+             dst_shape.w * dst_shape.h,
         };
         return GetByteBuffer(uniform_params);
       }},
  };
 
-  desc->resize_function = [input_ids](const std::map<ValueId, BHWC>& buffers) {
-    const auto& src_dim = buffers.find(input_ids[0])->second;
-    const uint3 groups_size{16, 16, 1};
-    int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x);
-    int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y);
-    int groups_z = 1;
-    return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
+  desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
+    const auto& dst_shape = buffers.find(output_id)->second;
+    uint3 grid(dst_shape.w, dst_shape.h, 1);
+    uint3 group_size{8u, 4u, 1u};
+    uint3 groups;
+    groups.x = IntegralDivideRoundUp(grid.x, group_size.x);
+    groups.y = IntegralDivideRoundUp(grid.y, group_size.y);
+    groups.z = IntegralDivideRoundUp(grid.z, group_size.z);
+    return std::make_pair(group_size, groups);
   };
 
   return {desc};
@@ -246,7 +273,7 @@ std::vector<ComputeTaskDescriptorPtr> ConcatX(
   desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
     const auto& output_dims = buffers.find(output_id)->second;
-    const uint3 groups_size{1, 1, 1};
+    const uint3 groups_size{8, 4, 1};
     int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x);
     int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y);
     int groups_z = IntegralDivideRoundUp(output_dims.c, 4);
@@ -337,7 +364,7 @@ std::vector<ComputeTaskDescriptorPtr> ConcatY(
   desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
     const auto& output_dims = buffers.find(output_id)->second;
-    const uint3 groups_size{1, 1, 1};
+    const uint3 groups_size{8, 4, 1};
     int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x);
     int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y);
     int groups_z = IntegralDivideRoundUp(output_dims.c, 4);
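
For illustration, the Metal source that GetConcatZCode emits after this change on the fast path (every input channel count divisible by 4) looks roughly like the sketch below. It assumes two hypothetical inputs with 8 and 4 channels (so depths 2 and 1); the $0/$1 argument lists, the src_buffer*/dst_buffer declarations, the "constant uniforms& U" binding, and whatever is substituted for $2 all come from the surrounding task descriptor, and FLT4 is the framework's float4/half4 alias, so this is an approximation rather than verbatim generator output:

    #include <metal_stdlib>
    using namespace metal;
    struct uniforms {
      int4 src_size;
      int4 dst_size;
    };
    // $0 / $1 normally expand to the constant and buffer bindings, e.g.
    // device FLT4* src_buffer0, src_buffer1, dst_buffer and
    // constant uniforms& U; they are omitted here for brevity.
    kernel void ComputeFunction(
                                uint2 ugid[[thread_position_in_grid]]) {
      int X = static_cast<int>(ugid.x);
      int Y = static_cast<int>(ugid.y);
      int Z = 0;
      if (X >= U.dst_size.x || Y >= U.dst_size.y) return;

      FLT4 value = FLT4(0.0f);
      const int xy_offset = Y * U.src_size.x + X;
      int linear_index = xy_offset;
      // Emitted for src_buffer0: 8 channels -> 2 FLT4 slices.
      for (int i = 0; i < 2; ++i) {
        int src_index = i * U.src_size.w + xy_offset;
        value = src_buffer0[src_index];
        uint3 gid = uint3(ugid.x, ugid.y, uint(Z));
        // $2 (any linked per-element output code) would be inserted here.
        dst_buffer[linear_index] = value;
        linear_index += U.src_size.w;
        Z++;
      }
      // Emitted for src_buffer1: 4 channels -> 1 FLT4 slice.
      for (int i = 0; i < 1; ++i) {
        int src_index = i * U.src_size.w + xy_offset;
        value = src_buffer1[src_index];
        uint3 gid = uint3(ugid.x, ugid.y, uint(Z));
        dst_buffer[linear_index] = value;
        linear_index += U.src_size.w;
        Z++;
      }
    }

Because every slice is a full FLT4 on this path, the generator can emit one compact loop per input instead of unrolled per-channel scalar assignments, which keeps the generated kernel short for inputs with many channels; the per-channel path in the else branch is still used whenever some input's channel count is not a multiple of 4.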