Added specialized concat for channels when all channels are divisible by 4.

PiperOrigin-RevId: 305067398
Change-Id: Id30fe786a6b10a09c389ccc01494c7f87ff9b2d9
Raman Sarokin 2020-04-06 10:52:32 -07:00 committed by TensorFlower Gardener
parent 0755bc921a
commit 2557efa9ba
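
The idea behind the specialization: in the depth-4 layout used by this backend, a tensor with C channels occupies IntegralDivideRoundUp(C, 4) FLT4 slices. When every input's channel count is divisible by 4, each input slice maps onto exactly one output slice, so concatenation along channels reduces to a slice-by-slice FLT4 copy that a short generated loop can express. A minimal sketch of that slice mapping (illustrative only, not code from this commit):

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> channels = {8, 4, 16};  // all divisible by 4
  int dst_slice = 0;
  for (size_t i = 0; i < channels.size(); ++i) {
    const int slices = channels[i] / 4;  // exact: no partially filled slice
    for (int s = 0; s < slices; ++s) {
      std::printf("dst slice %d <- input %zu, slice %d\n", dst_slice++, i, s);
    }
  }
  // With channels = {6, 2}, input 0's second slice would hold only two valid
  // components, so the generic path must repack components across slices.
  return 0;
}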


@@ -33,79 +33,99 @@ namespace gpu {
 namespace metal {
 namespace {
+bool IsAllChannelsX4(const std::vector<int>& channels) {
+  for (int channel : channels) {
+    if (channel % 4 != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
 std::string GetConcatZCode(const std::vector<int> channels) {
   const std::string postfix[] = {".x", ".y", ".z", ".w"};
-  const std::string postfix_2[] = {".x", ".xy", ".xyz", ""};
-  const std::string types[] = {"FLT", "FLT2", "FLT3", "FLT4"};
-  std::string code = R"(
+  std::string c = R"(
 #include <metal_stdlib>
 using namespace metal;
 
 struct uniforms {
   int4 src_size;
   int4 dst_size;
 };
 
 $0
 kernel void ComputeFunction(
                             $1
                             uint2 ugid[[thread_position_in_grid]]) {
-  if (static_cast<int>(ugid.x) >= params.src_size.x ||
-      static_cast<int>(ugid.y) >= params.src_size.y) {
-    return;
-  }
-
-  FLT4 value = FLT4(0.0f);
-  const int xy_offset = int(ugid.y) * params.src_size.x + int(ugid.x);
-  int linear_index = xy_offset;
-)";
-  int out_channel = 0;
-  int read_index = 0;
-  int dst_z = 0;
-  for (int i = 0; i < channels.size(); ++i) {
-    const int depth = IntegralDivideRoundUp(channels[i], 4);
-    code += "  {\n";
-    code += "  int src_address = xy_offset;\n";
-    for (int d = 0; d < depth; ++d) {
-      const int channels_in_group = std::min(4, channels[i] - d * 4);
-      const std::string temp_name = "t" + std::to_string(read_index);
-      code += "  " + types[channels_in_group - 1] + " " + temp_name + " = " +
-              "src_buffer" + std::to_string(i) + "[src_address]" +
-              postfix_2[channels_in_group - 1] + ";\n";
-      code += "  src_address += params.src_size.w;\n";
-      for (int c = 0; c < channels_in_group; ++c) {
-        if (channels_in_group == 1) {
-          code += "  value" + postfix[out_channel] + " = " + temp_name + ";\n";
-        } else {
-          code += "  value" + postfix[out_channel] + " = " + temp_name +
-                  postfix[c] + ";\n";
-        }
-        out_channel++;
-        if (out_channel == 4) {
-          out_channel = 0;
-          code += "  {\n";
-          code += "  uint3 gid = uint3(ugid.x, ugid.y, " +
-                  std::to_string(dst_z) + ");\n";
-          code += "  $2\n";
-          code += "  dst_buffer[linear_index] = value;\n";
-          code += "  linear_index += params.src_size.w;\n";
-          code += "  }\n";
-          dst_z++;
-        }
-      }
-      read_index++;
-    }
-    code += "  }\n";
-  }
-  if (out_channel != 0) {
-    code += "  {\n";
-    code += "  uint3 gid = uint3(ugid.x, ugid.y, " + std::to_string(dst_z) +
-            ");\n";
-    code += "  $2\n";
-    code += "  dst_buffer[linear_index] = value;\n";
-    code += "  }\n";
-  }
-  code += "}\n";
-  return code;
+  int X = static_cast<int>(ugid.x);
+  int Y = static_cast<int>(ugid.y);
+  int Z = 0;
+  if (X >= U.dst_size.x || Y >= U.dst_size.y) return;
+
+  FLT4 value = FLT4(0.0f);
+  const int xy_offset = Y * U.src_size.x + X;
+  int linear_index = xy_offset;
+)";
+  if (IsAllChannelsX4(channels)) {
+    // When all channels % 4 == 0 we can read/assign/write FLT4 elements easily.
+    // Also it is easy to write a loop in this case, to prevent long kernel
+    // generation.
+    for (int i = 0; i < channels.size(); ++i) {
+      const int depth = IntegralDivideRoundUp(channels[i], 4);
+      const std::string src_buffer = "src_buffer" + std::to_string(i);
+      c += "  for (int i = 0; i < " + std::to_string(depth) + "; ++i) {\n";
+      c += "    int src_index = i * U.src_size.w + xy_offset;\n";
+      c += "    value = " + src_buffer + "[src_index];\n";
+      c += "    uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
+      c += "    $2\n";
+      c += "    dst_buffer[linear_index] = value;\n";
+      c += "    linear_index += U.src_size.w;\n";
+      c += "    Z++;\n";
+      c += "  }\n";
+    }
+  } else {
+    int out_channel = 0;
+    int read_index = 0;
+    int z = 0;
+    for (int i = 0; i < channels.size(); ++i) {
+      const int depth = IntegralDivideRoundUp(channels[i], 4);
+      const std::string src_buffer = "src_buffer" + std::to_string(i);
+      for (int d = 0; d < depth; ++d) {
+        const int channels_in_group = std::min(4, channels[i] - d * 4);
+        const std::string temp_name = "t" + std::to_string(read_index);
+        const std::string src_index =
+            std::to_string(d) + " * U.src_size.w + xy_offset";
+        c += "  FLT4 " + temp_name + " = " + src_buffer + "[" + src_index +
+             "];\n";
+        for (int ch = 0; ch < channels_in_group; ++ch) {
+          c += "  value" + postfix[out_channel] + " = ";
+          c += temp_name + postfix[ch] + ";\n";
+          out_channel++;
+          if (out_channel == 4) {
+            out_channel = 0;
+            c += "  {\n";
+            c += "    uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
+            c += "    $2\n";
+            c += "    dst_buffer[linear_index] = value;\n";
+            c += "    linear_index += U.src_size.w;\n";
+            c += "    Z++;\n";
+            c += "  }\n";
+            z++;
+          }
+        }
+        read_index++;
+      }
+    }
+    if (out_channel != 0) {
+      c += "  {\n";
+      c += "    uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
+      c += "    $2\n";
+      c += "    dst_buffer[linear_index] = value;\n";
+      c += "  }\n";
+    }
+  }
+  c += "}\n";
+  return c;
 }
 
 }  // namespace
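
To make the generated code concrete: for channels = {8}, depth is IntegralDivideRoundUp(8, 4) = 2, so the specialized branch emits a single two-iteration loop. The snippet below is a hand-traced rendering of that output (a sketch, assuming the string concatenation above; $2 is the placeholder later substituted by the task descriptor):

#include <cstdio>

int main() {
  // Hand-traced generator output for channels = {8}.
  const char* emitted =
      "  for (int i = 0; i < 2; ++i) {\n"
      "    int src_index = i * U.src_size.w + xy_offset;\n"
      "    value = src_buffer0[src_index];\n"
      "    uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n"
      "    $2\n"
      "    dst_buffer[linear_index] = value;\n"
      "    linear_index += U.src_size.w;\n"
      "    Z++;\n"
      "  }\n";
  std::printf("%s", emitted);
  return 0;
}

The loop form is what keeps shader size bounded; the generic path unrolls code per slice, so large channel counts previously produced very long kernels.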
@@ -141,26 +161,33 @@ std::vector<ComputeTaskDescriptorPtr> ConcatZ(
   }};
   desc->uniform_buffers = {
-      {"constant uniforms& params",
-       [input_ids](const std::map<ValueId, BHWC>& buffers) {
-         const auto& dimension = buffers.find(input_ids[0])->second;
+      {"constant uniforms& U",
+       [input_ids, output_id](const std::map<ValueId, BHWC>& buffers) {
+         const auto& src_shape = buffers.find(input_ids[0])->second;
+         const auto& dst_shape = buffers.find(output_id)->second;
          std::vector<int> uniform_params{
-             dimension.w,
-             dimension.h,
-             0,
-             dimension.w * dimension.h,
+             src_shape.w,
+             src_shape.h,
+             IntegralDivideRoundUp(src_shape.c, 4),
+             src_shape.w * src_shape.h,
+             dst_shape.w,
+             dst_shape.h,
+             IntegralDivideRoundUp(dst_shape.c, 4),
+             dst_shape.w * dst_shape.h,
          };
          return GetByteBuffer(uniform_params);
        }},
   };
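
Each int4 in the uniform block packs one shape as (width, height, slice count, width * height); the shader reads U.dst_size.x/.y for its bounds check and U.src_size.w as the per-slice element stride. A small sketch of that packing (Shape is a hypothetical stand-in for the real BHWC type, not code from this file):

#include <vector>

int IntegralDivideRoundUp(int n, int d) { return (n + d - 1) / d; }

struct Shape { int w, h, c; };  // hypothetical stand-in for BHWC

// Mirrors the packing above: .x = width, .y = height, .z = FLT4 slice
// count, .w = width * height (per-slice element stride in the shader).
std::vector<int> PackUniforms(const Shape& src, const Shape& dst) {
  return {src.w, src.h, IntegralDivideRoundUp(src.c, 4), src.w * src.h,
          dst.w, dst.h, IntegralDivideRoundUp(dst.c, 4), dst.w * dst.h};
}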
-  desc->resize_function = [input_ids](const std::map<ValueId, BHWC>& buffers) {
-    const auto& src_dim = buffers.find(input_ids[0])->second;
-    const uint3 groups_size{16, 16, 1};
-    int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x);
-    int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y);
-    int groups_z = 1;
-    return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
+  desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
+    const auto& dst_shape = buffers.find(output_id)->second;
+    uint3 grid(dst_shape.w, dst_shape.h, 1);
+    uint3 group_size{8u, 4u, 1u};
+    uint3 groups;
+    groups.x = IntegralDivideRoundUp(grid.x, group_size.x);
+    groups.y = IntegralDivideRoundUp(grid.y, group_size.y);
+    groups.z = IntegralDivideRoundUp(grid.z, group_size.z);
+    return std::make_pair(group_size, groups);
   };
   return {desc};
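
The dispatch now rounds the destination grid up to whole workgroups: ConcatZ's workgroup shrinks from 16x16x1 to 8x4x1, and the two hunks below move ConcatX and ConcatY from 1x1x1 to the same 8x4x1. A quick check of the round-up arithmetic with example dimensions (the numbers are illustrative, not from the commit):

#include <cstdio>

int IntegralDivideRoundUp(int n, int d) { return (n + d - 1) / d; }

int main() {
  const int w = 640, h = 480;  // example destination size
  std::printf("groups = %d x %d x 1\n",
              IntegralDivideRoundUp(w, 8),   // 80
              IntegralDivideRoundUp(h, 4));  // 120
  return 0;
}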
@@ -246,7 +273,7 @@ std::vector<ComputeTaskDescriptorPtr> ConcatX(
   desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
     const auto& output_dims = buffers.find(output_id)->second;
-    const uint3 groups_size{1, 1, 1};
+    const uint3 groups_size{8, 4, 1};
     int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x);
     int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y);
     int groups_z = IntegralDivideRoundUp(output_dims.c, 4);
@@ -337,7 +364,7 @@ std::vector<ComputeTaskDescriptorPtr> ConcatY(
   desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
     const auto& output_dims = buffers.find(output_id)->second;
-    const uint3 groups_size{1, 1, 1};
+    const uint3 groups_size{8, 4, 1};
     int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x);
     int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y);
     int groups_z = IntegralDivideRoundUp(output_dims.c, 4);