Use raw strings in FC kernel.

PiperOrigin-RevId: 332055748
Change-Id: I66daa3c2f2fa87f386cf191fa0891fd6f7137e12
This commit is contained in:
Robert David 2020-09-16 11:56:43 -07:00 committed by TensorFlower Gardener
parent 1de0678b95
commit 221535f56d

View File

@ -75,38 +75,38 @@ std::string FullyConnected::GetFullyConnectedKernelCode(
break;
}
const std::string wg_x = std::to_string(work_group_size.x);
const std::string wg_y = std::to_string(work_group_size.y);
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += " int gid = get_global_id(0);\n";
c += " int2 tid = (int2)(get_local_id(0), get_local_id(1));\n";
c += " ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);\n";
c += " if (gid < args.dst_tensor.Slices()) {\n";
c += " for (uint c = tid.y; c < args.src_tensor.Slices(); c += " + wg_y +
") {\n";
c += " FLT4 v = args.src_tensor.Read(0, 0, c);\n";
c += " FLT16 w = args.weights.Read(c*args.dst_tensor.Slices() + gid);\n";
c += " s.x += dot(v, w.s0123);\n";
c += " s.y += dot(v, w.s4567);\n";
c += " s.z += dot(v, w.s89ab);\n";
c += " s.w += dot(v, w.scdef);\n";
c += " }\n";
c += " }\n";
c += " __local ACCUM_FLT4 temp[" + wg_x + "][" + wg_y + "];\n";
c += " temp[tid.x][tid.y] = s;\n";
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
c += " if (gid >= args.dst_tensor.Slices()) {\n";
c += " return;\n";
c += " }\n";
c += " if (tid.y == 0) {\n";
c += "#define WG_X " + std::to_string(work_group_size.x) + "\n";
c += "#define WG_Y " + std::to_string(work_group_size.y) + "\n";
c += R"(__kernel void main_function($0) {
int gid = get_global_id(0);
int2 tid = (int2)(get_local_id(0), get_local_id(1));
ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);
if (gid < args.dst_tensor.Slices()) {
for (uint c = tid.y; c < args.src_tensor.Slices(); c += WG_Y) {
FLT4 v = args.src_tensor.Read(0, 0, c);
FLT16 w = args.weights.Read(c*args.dst_tensor.Slices() + gid);
s.x += dot(v, w.s0123);
s.y += dot(v, w.s4567);
s.z += dot(v, w.s89ab);
s.w += dot(v, w.scdef);
}
}
__local ACCUM_FLT4 temp[WG_X][WG_Y];
temp[tid.x][tid.y] = s;
barrier(CLK_LOCAL_MEM_FENCE);
if (gid >= args.dst_tensor.Slices()) {
return;
}
if (tid.y == 0) {
)";
for (int i = 1; i < work_group_size.y; ++i) {
c += " s += temp[tid.x][" + std::to_string(i) + "];\n";
}
c += " FLT4 r0 = TO_FLT4(s) + args.biases.Read(gid);\n";
c += " args.dst_tensor.Write(r0, 0, 0, gid);\n";
c += " }\n";
c += "}\n";
c += R"( FLT4 r0 = TO_FLT4(s) + args.biases.Read(gid);
args.dst_tensor.Write(r0, 0, 0, gid);
}
})";
return c;
}