Added specialized concat for channels when all channels divisible by 4.

PiperOrigin-RevId: 305067398
Change-Id: Id30fe786a6b10a09c389ccc01494c7f87ff9b2d9
This commit is contained in:
Raman Sarokin 2020-04-06 10:52:32 -07:00 committed by TensorFlower Gardener
parent 0755bc921a
commit 2557efa9ba

View File

@ -33,79 +33,99 @@ namespace gpu {
namespace metal { namespace metal {
namespace { namespace {
bool IsAllChannelsX4(const std::vector<int>& channels) {
for (int channel : channels) {
if (channel % 4 != 0) {
return false;
}
}
return true;
}
std::string GetConcatZCode(const std::vector<int> channels) { std::string GetConcatZCode(const std::vector<int> channels) {
const std::string postfix[] = {".x", ".y", ".z", ".w"}; const std::string postfix[] = {".x", ".y", ".z", ".w"};
const std::string postfix_2[] = {".x", ".xy", ".xyz", ""}; std::string c = R"(
const std::string types[] = {"FLT", "FLT2", "FLT3", "FLT4"};
std::string code = R"(
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
struct uniforms { struct uniforms {
int4 src_size; int4 src_size;
int4 dst_size;
}; };
$0 $0
kernel void ComputeFunction( kernel void ComputeFunction(
$1 $1
uint2 ugid[[thread_position_in_grid]]) { uint2 ugid[[thread_position_in_grid]]) {
if (static_cast<int>(ugid.x) >= params.src_size.x || int X = static_cast<int>(ugid.x);
static_cast<int>(ugid.y) >= params.src_size.y) { int Y = static_cast<int>(ugid.y);
return; int Z = 0;
} if (X >= U.dst_size.x || Y >= U.dst_size.y) return;
FLT4 value = FLT4(0.0f); FLT4 value = FLT4(0.0f);
const int xy_offset = int(ugid.y) * params.src_size.x + int(ugid.x); const int xy_offset = Y * U.src_size.x + X;
int linear_index = xy_offset; int linear_index = xy_offset;
)"; )";
int out_channel = 0; if (IsAllChannelsX4(channels)) {
int read_index = 0; // When all channels % 4 == 0 we can read/assign/write FLT4 elements easily.
int dst_z = 0; // Also it is easy to write a loop in this case, to prevent long kernel
// generation.
for (int i = 0; i < channels.size(); ++i) { for (int i = 0; i < channels.size(); ++i) {
const int depth = IntegralDivideRoundUp(channels[i], 4); const int depth = IntegralDivideRoundUp(channels[i], 4);
code += " {\n"; const std::string src_buffer = "src_buffer" + std::to_string(i);
code += " int src_address = xy_offset;\n"; c += " for (int i = 0; i < " + std::to_string(depth) + "; ++i) {\n";
c += " int src_index = i * U.src_size.w + xy_offset;\n";
c += " value = " + src_buffer + "[src_index];\n";
c += " uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
c += " $2\n";
c += " dst_buffer[linear_index] = value;\n";
c += " linear_index += U.src_size.w;\n";
c += " Z++;\n";
c += " }\n";
}
} else {
int out_channel = 0;
int read_index = 0;
int z = 0;
for (int i = 0; i < channels.size(); ++i) {
const int depth = IntegralDivideRoundUp(channels[i], 4);
const std::string src_buffer = "src_buffer" + std::to_string(i);
for (int d = 0; d < depth; ++d) { for (int d = 0; d < depth; ++d) {
const int channels_in_group = std::min(4, channels[i] - d * 4); const int channels_in_group = std::min(4, channels[i] - d * 4);
const std::string temp_name = "t" + std::to_string(read_index); const std::string temp_name = "t" + std::to_string(read_index);
code += " " + types[channels_in_group - 1] + " " + temp_name + " = " + const std::string src_index =
"src_buffer" + std::to_string(i) + "[src_address]" + std::to_string(d) + " * U.src_size.w + xy_offset";
postfix_2[channels_in_group - 1] + ";\n"; c += " FLT4 " + temp_name + " = " + src_buffer + "[" + src_index +
code += " src_address += params.src_size.w;\n"; "];\n";
for (int c = 0; c < channels_in_group; ++c) { for (int ch = 0; ch < channels_in_group; ++ch) {
if (channels_in_group == 1) { c += " value" + postfix[out_channel] + " = ";
code += " value" + postfix[out_channel] + " = " + temp_name + ";\n"; c += temp_name + postfix[ch] + ";\n";
} else {
code += " value" + postfix[out_channel] + " = " + temp_name +
postfix[c] += ";\n";
}
out_channel++; out_channel++;
if (out_channel == 4) { if (out_channel == 4) {
out_channel = 0; out_channel = 0;
code += " {\n"; c += " {\n";
code += " uint3 gid = uint3(ugid.x, ugid.y, " + c += " uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
std::to_string(dst_z) + ");\n"; c += " $2\n";
code += " $2\n"; c += " dst_buffer[linear_index] = value;\n";
code += " dst_buffer[linear_index] = value;\n"; c += " linear_index += U.src_size.w;\n";
code += " linear_index += params.src_size.w;\n"; c += " Z++;\n";
code += " }\n"; c += " }\n";
dst_z++; z++;
} }
} }
read_index++; read_index++;
} }
code += " }\n";
} }
if (out_channel != 0) { if (out_channel != 0) {
code += " {\n"; c += " {\n";
code += " uint3 gid = uint3(ugid.x, ugid.y, " + std::to_string(dst_z) + c += " uint3 gid = uint3(ugid.x, ugid.y, uint(Z));\n";
");\n"; c += " $2\n";
code += " $2\n"; c += " dst_buffer[linear_index] = value;\n";
code += " dst_buffer[linear_index] = value;\n"; c += " }\n";
code += " }\n";
} }
code += "}\n"; }
return code; c += "}\n";
return c;
} }
} // namespace } // namespace
@ -141,26 +161,33 @@ std::vector<ComputeTaskDescriptorPtr> ConcatZ(
}}; }};
desc->uniform_buffers = { desc->uniform_buffers = {
{"constant uniforms& params", {"constant uniforms& U",
[input_ids](const std::map<ValueId, BHWC>& buffers) { [input_ids, output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& dimension = buffers.find(input_ids[0])->second; const auto& src_shape = buffers.find(input_ids[0])->second;
const auto& dst_shape = buffers.find(output_id)->second;
std::vector<int> uniform_params{ std::vector<int> uniform_params{
dimension.w, src_shape.w,
dimension.h, src_shape.h,
0, IntegralDivideRoundUp(src_shape.c, 4),
dimension.w * dimension.h, src_shape.w * src_shape.h,
dst_shape.w,
dst_shape.h,
IntegralDivideRoundUp(dst_shape.c, 4),
dst_shape.w * dst_shape.h,
}; };
return GetByteBuffer(uniform_params); return GetByteBuffer(uniform_params);
}}, }},
}; };
desc->resize_function = [input_ids](const std::map<ValueId, BHWC>& buffers) { desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& src_dim = buffers.find(input_ids[0])->second; const auto& dst_shape = buffers.find(output_id)->second;
const uint3 groups_size{16, 16, 1}; uint3 grid(dst_shape.w, dst_shape.h, 1);
int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x); uint3 group_size{8u, 4u, 1u};
int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y); uint3 groups;
int groups_z = 1; groups.x = IntegralDivideRoundUp(grid.x, group_size.x);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); groups.y = IntegralDivideRoundUp(grid.y, group_size.y);
groups.z = IntegralDivideRoundUp(grid.z, group_size.z);
return std::make_pair(group_size, groups);
}; };
return {desc}; return {desc};
@ -246,7 +273,7 @@ std::vector<ComputeTaskDescriptorPtr> ConcatX(
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) { desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& output_dims = buffers.find(output_id)->second; const auto& output_dims = buffers.find(output_id)->second;
const uint3 groups_size{1, 1, 1}; const uint3 groups_size{8, 4, 1};
int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x); int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x);
int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y); int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y);
int groups_z = IntegralDivideRoundUp(output_dims.c, 4); int groups_z = IntegralDivideRoundUp(output_dims.c, 4);
@ -337,7 +364,7 @@ std::vector<ComputeTaskDescriptorPtr> ConcatY(
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) { desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& output_dims = buffers.find(output_id)->second; const auto& output_dims = buffers.find(output_id)->second;
const uint3 groups_size{1, 1, 1}; const uint3 groups_size{8, 4, 1};
int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x); int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x);
int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y); int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y);
int groups_z = IntegralDivideRoundUp(output_dims.c, 4); int groups_z = IntegralDivideRoundUp(output_dims.c, 4);