Added a specialized concat for channels when all channel counts are divisible by 4.
PiperOrigin-RevId: 305067398 Change-Id: Id30fe786a6b10a09c389ccc01494c7f87ff9b2d9
This commit is contained in:
parent
0755bc921a
commit
2557efa9ba
@ -33,79 +33,99 @@ namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
|
||||
// Reports whether every entry of `channels` is a multiple of 4.
//
// Returns false as soon as one entry leaves a non-zero remainder when divided
// by 4, and true otherwise. An empty list therefore yields true.
bool IsAllChannelsX4(const std::vector<int>& channels) {
  for (const int channel : channels) {
    if (channel % 4 != 0) {
      return false;
    }
  }
  return true;
}
|
||||
|
||||
// Builds and returns, as a std::string, the Metal source of the kernel
// `ComputeFunction`, driven by the per-input channel counts in `channels`.
//
// NOTE(review): this span is a rendered diff without +/- markers. Lines of the
// removed implementation (accumulating into `code`, uniforms named `params`)
// are interleaved with lines of the added implementation (accumulating into
// `c`, uniforms named `U`). The `|` / `||||` lines are table residue from the
// page, not C++.
//
// In the added implementation:
//  - when IsAllChannelsX4(channels) holds, one generated loop per input copies
//    whole FLT4 elements from src_bufferN to dst_buffer;
//  - otherwise the generated code assigns components one at a time into
//    `value` and writes it to dst_buffer each time four components have been
//    collected, with one final write when out_channel != 0.
//
// The placeholders $0, $1 and $2 are left in the returned text; presumably
// they are substituted by the caller -- TODO confirm.
std::string GetConcatZCode(const std::vector<int> channels) {
|
||||
const std::string postfix[] = {".x", ".y", ".z", ".w"};
|
||||
const std::string postfix_2[] = {".x", ".xy", ".xyz", ""};
|
||||
const std::string types[] = {"FLT", "FLT2", "FLT3", "FLT4"};
|
||||
std::string code = R"(
|
||||
std::string c = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
struct uniforms {
|
||||
int4 src_size;
|
||||
int4 dst_size;
|
||||
};
|
||||
|
||||
$0
|
||||
kernel void ComputeFunction(
|
||||
$1
|
||||
uint2 ugid[[thread_position_in_grid]]) {
|
||||
if (static_cast<int>(ugid.x) >= params.src_size.x ||
|
||||
static_cast<int>(ugid.y) >= params.src_size.y) {
|
||||
return;
|
||||
int X = static_cast<int>(ugid.x);
|
||||
int Y = static_cast<int>(ugid.y);
|
||||
int Z = 0;
|
||||
if (X >= U.dst_size.x || Y >= U.dst_size.y) return;
|
||||
|
||||
FLT4 value = FLT4(0.0f);
|
||||
const int xy_offset = Y * U.src_size.x + X;
|
||||
int linear_index = xy_offset;
|
||||
)";
|
||||
|
||||
if (IsAllChannelsX4(channels)) {
|
||||
// When all channels % 4 == 0 we can read/assign/write FLT4 elements easily.
|
||||
// Also it is easy to write a loop in this case, to prevent long kernel
|
||||
// generation.
|
||||
for (int i = 0; i < channels.size(); ++i) {
|
||||
const int depth = IntegralDivideRoundUp(channels[i], 4);
|
||||
const std::string src_buffer = "src_buffer" + std::to_string(i);
|
||||
c += " for (int i = 0; i < " + std::to_string(depth) + "; ++i) {\n";
|
||||
c += " int src_index = i * U.src_size.w + xy_offset;\n";
|
||||
c += " value = " + src_buffer + "[src_index];\n";
|
||||
c += " uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
|
||||
c += " $2\n";
|
||||
c += " dst_buffer[linear_index] = value;\n";
|
||||
c += " linear_index += U.src_size.w;\n";
|
||||
c += " Z++;\n";
|
||||
c += " }\n";
|
||||
}
|
||||
|
||||
FLT4 value = FLT4(0.0f);
|
||||
const int xy_offset = int(ugid.y) * params.src_size.x + int(ugid.x);
|
||||
int linear_index = xy_offset;
|
||||
)";
|
||||
|
||||
int out_channel = 0;
|
||||
int read_index = 0;
|
||||
int dst_z = 0;
|
||||
for (int i = 0; i < channels.size(); ++i) {
|
||||
const int depth = IntegralDivideRoundUp(channels[i], 4);
|
||||
code += " {\n";
|
||||
code += " int src_address = xy_offset;\n";
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
const int channels_in_group = std::min(4, channels[i] - d * 4);
|
||||
const std::string temp_name = "t" + std::to_string(read_index);
|
||||
code += " " + types[channels_in_group - 1] + " " + temp_name + " = " +
|
||||
"src_buffer" + std::to_string(i) + "[src_address]" +
|
||||
postfix_2[channels_in_group - 1] + ";\n";
|
||||
code += " src_address += params.src_size.w;\n";
|
||||
for (int c = 0; c < channels_in_group; ++c) {
|
||||
if (channels_in_group == 1) {
|
||||
code += " value" + postfix[out_channel] + " = " + temp_name + ";\n";
|
||||
} else {
|
||||
code += " value" + postfix[out_channel] + " = " + temp_name +
|
||||
postfix[c] += ";\n";
|
||||
}
|
||||
out_channel++;
|
||||
if (out_channel == 4) {
|
||||
out_channel = 0;
|
||||
code += " {\n";
|
||||
code += " uint3 gid = uint3(ugid.x, ugid.y, " +
|
||||
std::to_string(dst_z) + ");\n";
|
||||
code += " $2\n";
|
||||
code += " dst_buffer[linear_index] = value;\n";
|
||||
code += " linear_index += params.src_size.w;\n";
|
||||
code += " }\n";
|
||||
dst_z++;
|
||||
} else {
|
||||
int out_channel = 0;
|
||||
int read_index = 0;
|
||||
int z = 0;
|
||||
for (int i = 0; i < channels.size(); ++i) {
|
||||
const int depth = IntegralDivideRoundUp(channels[i], 4);
|
||||
const std::string src_buffer = "src_buffer" + std::to_string(i);
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
const int channels_in_group = std::min(4, channels[i] - d * 4);
|
||||
const std::string temp_name = "t" + std::to_string(read_index);
|
||||
const std::string src_index =
|
||||
std::to_string(d) + " * U.src_size.w + xy_offset";
|
||||
c += " FLT4 " + temp_name + " = " + src_buffer + "[" + src_index +
|
||||
"];\n";
|
||||
for (int ch = 0; ch < channels_in_group; ++ch) {
|
||||
c += " value" + postfix[out_channel] + " = ";
|
||||
c += temp_name + postfix[ch] + ";\n";
|
||||
out_channel++;
|
||||
if (out_channel == 4) {
|
||||
out_channel = 0;
|
||||
c += " {\n";
|
||||
c += " uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
|
||||
c += " $2\n";
|
||||
c += " dst_buffer[linear_index] = value;\n";
|
||||
c += " linear_index += U.src_size.w;\n";
|
||||
c += " Z++;\n";
|
||||
c += " }\n";
|
||||
z++;
|
||||
}
|
||||
}
|
||||
read_index++;
|
||||
}
|
||||
read_index++;
|
||||
}
|
||||
code += " }\n";
|
||||
if (out_channel != 0) {
|
||||
c += " {\n";
|
||||
c += " uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
|
||||
c += " $2\n";
|
||||
c += " dst_buffer[linear_index] = value;\n";
|
||||
c += " }\n";
|
||||
}
|
||||
}
|
||||
if (out_channel != 0) {
|
||||
code += " {\n";
|
||||
code += " uint3 gid = uint3(ugid.x, ugid.y, " + std::to_string(dst_z) +
|
||||
");\n";
|
||||
code += " $2\n";
|
||||
code += " dst_buffer[linear_index] = value;\n";
|
||||
code += " }\n";
|
||||
}
|
||||
code += "}\n";
|
||||
return code;
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -141,26 +161,33 @@ std::vector<ComputeTaskDescriptorPtr> ConcatZ(
|
||||
}};
|
||||
|
||||
desc->uniform_buffers = {
|
||||
{"constant uniforms& params",
|
||||
[input_ids](const std::map<ValueId, BHWC>& buffers) {
|
||||
const auto& dimension = buffers.find(input_ids[0])->second;
|
||||
{"constant uniforms& U",
|
||||
[input_ids, output_id](const std::map<ValueId, BHWC>& buffers) {
|
||||
const auto& src_shape = buffers.find(input_ids[0])->second;
|
||||
const auto& dst_shape = buffers.find(output_id)->second;
|
||||
std::vector<int> uniform_params{
|
||||
dimension.w,
|
||||
dimension.h,
|
||||
0,
|
||||
dimension.w * dimension.h,
|
||||
src_shape.w,
|
||||
src_shape.h,
|
||||
IntegralDivideRoundUp(src_shape.c, 4),
|
||||
src_shape.w * src_shape.h,
|
||||
dst_shape.w,
|
||||
dst_shape.h,
|
||||
IntegralDivideRoundUp(dst_shape.c, 4),
|
||||
dst_shape.w * dst_shape.h,
|
||||
};
|
||||
return GetByteBuffer(uniform_params);
|
||||
}},
|
||||
};
|
||||
|
||||
desc->resize_function = [input_ids](const std::map<ValueId, BHWC>& buffers) {
|
||||
const auto& src_dim = buffers.find(input_ids[0])->second;
|
||||
const uint3 groups_size{16, 16, 1};
|
||||
int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x);
|
||||
int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y);
|
||||
int groups_z = 1;
|
||||
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
|
||||
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
|
||||
const auto& dst_shape = buffers.find(output_id)->second;
|
||||
uint3 grid(dst_shape.w, dst_shape.h, 1);
|
||||
uint3 group_size{8u, 4u, 1u};
|
||||
uint3 groups;
|
||||
groups.x = IntegralDivideRoundUp(grid.x, group_size.x);
|
||||
groups.y = IntegralDivideRoundUp(grid.y, group_size.y);
|
||||
groups.z = IntegralDivideRoundUp(grid.z, group_size.z);
|
||||
return std::make_pair(group_size, groups);
|
||||
};
|
||||
|
||||
return {desc};
|
||||
@ -246,7 +273,7 @@ std::vector<ComputeTaskDescriptorPtr> ConcatX(
|
||||
|
||||
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
|
||||
const auto& output_dims = buffers.find(output_id)->second;
|
||||
const uint3 groups_size{1, 1, 1};
|
||||
const uint3 groups_size{8, 4, 1};
|
||||
int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x);
|
||||
int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y);
|
||||
int groups_z = IntegralDivideRoundUp(output_dims.c, 4);
|
||||
@ -337,7 +364,7 @@ std::vector<ComputeTaskDescriptorPtr> ConcatY(
|
||||
|
||||
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
|
||||
const auto& output_dims = buffers.find(output_id)->second;
|
||||
const uint3 groups_size{1, 1, 1};
|
||||
const uint3 groups_size{8, 4, 1};
|
||||
int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x);
|
||||
int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y);
|
||||
int groups_z = IntegralDivideRoundUp(output_dims.c, 4);
|
||||
|
Loading…
Reference in New Issue
Block a user