Use raw strings in FC kernel.
PiperOrigin-RevId: 332055748 Change-Id: I66daa3c2f2fa87f386cf191fa0891fd6f7137e12
This commit is contained in:
parent
1de0678b95
commit
221535f56d
@ -75,38 +75,38 @@ std::string FullyConnected::GetFullyConnectedKernelCode(
|
||||
break;
|
||||
}
|
||||
|
||||
const std::string wg_x = std::to_string(work_group_size.x);
|
||||
const std::string wg_y = std::to_string(work_group_size.y);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += "$0) {\n";
|
||||
c += " int gid = get_global_id(0);\n";
|
||||
c += " int2 tid = (int2)(get_local_id(0), get_local_id(1));\n";
|
||||
c += " ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);\n";
|
||||
c += " if (gid < args.dst_tensor.Slices()) {\n";
|
||||
c += " for (uint c = tid.y; c < args.src_tensor.Slices(); c += " + wg_y +
|
||||
") {\n";
|
||||
c += " FLT4 v = args.src_tensor.Read(0, 0, c);\n";
|
||||
c += " FLT16 w = args.weights.Read(c*args.dst_tensor.Slices() + gid);\n";
|
||||
c += " s.x += dot(v, w.s0123);\n";
|
||||
c += " s.y += dot(v, w.s4567);\n";
|
||||
c += " s.z += dot(v, w.s89ab);\n";
|
||||
c += " s.w += dot(v, w.scdef);\n";
|
||||
c += " }\n";
|
||||
c += " }\n";
|
||||
c += " __local ACCUM_FLT4 temp[" + wg_x + "][" + wg_y + "];\n";
|
||||
c += " temp[tid.x][tid.y] = s;\n";
|
||||
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
|
||||
c += " if (gid >= args.dst_tensor.Slices()) {\n";
|
||||
c += " return;\n";
|
||||
c += " }\n";
|
||||
c += " if (tid.y == 0) {\n";
|
||||
c += "#define WG_X " + std::to_string(work_group_size.x) + "\n";
|
||||
c += "#define WG_Y " + std::to_string(work_group_size.y) + "\n";
|
||||
|
||||
c += R"(__kernel void main_function($0) {
|
||||
int gid = get_global_id(0);
|
||||
int2 tid = (int2)(get_local_id(0), get_local_id(1));
|
||||
ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);
|
||||
if (gid < args.dst_tensor.Slices()) {
|
||||
for (uint c = tid.y; c < args.src_tensor.Slices(); c += WG_Y) {
|
||||
FLT4 v = args.src_tensor.Read(0, 0, c);
|
||||
FLT16 w = args.weights.Read(c*args.dst_tensor.Slices() + gid);
|
||||
s.x += dot(v, w.s0123);
|
||||
s.y += dot(v, w.s4567);
|
||||
s.z += dot(v, w.s89ab);
|
||||
s.w += dot(v, w.scdef);
|
||||
}
|
||||
}
|
||||
__local ACCUM_FLT4 temp[WG_X][WG_Y];
|
||||
temp[tid.x][tid.y] = s;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (gid >= args.dst_tensor.Slices()) {
|
||||
return;
|
||||
}
|
||||
if (tid.y == 0) {
|
||||
)";
|
||||
for (int i = 1; i < work_group_size.y; ++i) {
|
||||
c += " s += temp[tid.x][" + std::to_string(i) + "];\n";
|
||||
}
|
||||
c += " FLT4 r0 = TO_FLT4(s) + args.biases.Read(gid);\n";
|
||||
c += " args.dst_tensor.Write(r0, 0, 0, gid);\n";
|
||||
c += " }\n";
|
||||
c += "}\n";
|
||||
c += R"( FLT4 r0 = TO_FLT4(s) + args.biases.Read(gid);
|
||||
args.dst_tensor.Write(r0, 0, 0, gid);
|
||||
}
|
||||
})";
|
||||
|
||||
return c;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user