Change MeanStdDevNormalization's local_reduce to enforce 1D workgroups.
The OpenCL 2.0 work_group_reduce_add function, which local_reduce reimplements for OpenCL 1.x, offers only this functionality as well.
PiperOrigin-RevId: 327266948
Change-Id: I446a212c38d9e834aae1d63289ea6fd7f32986c0
This commit is contained in:
parent
cab58b3a0f
commit
7f2bc1e4b8
@ -32,7 +32,7 @@ std::string GetVectorReduceCode() {
|
|||||||
})";
|
})";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetReduceCode(size_t work_group_size_x, size_t work_group_size_y) {
|
std::string GetReduceCode() {
|
||||||
// If it is supported, use the built-in work_group_reduce_add function.
|
// If it is supported, use the built-in work_group_reduce_add function.
|
||||||
// Otherwise, implement a reduction using __local memory. Note this only works
|
// Otherwise, implement a reduction using __local memory. Note this only works
|
||||||
// with power-of-two work group sizes.
|
// with power-of-two work group sizes.
|
||||||
@ -45,22 +45,19 @@ std::string GetReduceCode(size_t work_group_size_x, size_t work_group_size_y) {
|
|||||||
#ifdef __opencl_c_work_group_collective_functions
|
#ifdef __opencl_c_work_group_collective_functions
|
||||||
#define local_reduce(input, tmp) work_group_reduce_add(input)
|
#define local_reduce(input, tmp) work_group_reduce_add(input)
|
||||||
#else // !defined(__opencl_c_work_group_collective_functions)
|
#else // !defined(__opencl_c_work_group_collective_functions)
|
||||||
static inline float local_reduce(float input, __local float tmp[)" +
|
static inline float local_reduce(float input, __local float* tmp) {
|
||||||
std::to_string(work_group_size_y) + "][" +
|
const size_t local_id = get_local_id(0);
|
||||||
std::to_string(work_group_size_x) + R"(]) {
|
tmp[local_id] = input;
|
||||||
const size_t local_id_x = get_local_id(0);
|
|
||||||
const size_t local_id_y = get_local_id(1);
|
|
||||||
tmp[local_id_y][local_id_x] = input;
|
|
||||||
mem_fence(CLK_LOCAL_MEM_FENCE);
|
mem_fence(CLK_LOCAL_MEM_FENCE);
|
||||||
size_t reduction_size = get_local_size(0) / 2;
|
size_t reduction_size = get_local_size(0) / 2;
|
||||||
while (reduction_size > 0) {
|
while (reduction_size > 0) {
|
||||||
if (local_id_x < reduction_size) {
|
if (local_id < reduction_size) {
|
||||||
tmp[local_id_y][local_id_x] += tmp[local_id_y][local_id_x + reduction_size];
|
tmp[local_id] += tmp[local_id + reduction_size];
|
||||||
}
|
}
|
||||||
mem_fence(CLK_LOCAL_MEM_FENCE);
|
mem_fence(CLK_LOCAL_MEM_FENCE);
|
||||||
reduction_size /= 2;
|
reduction_size /= 2;
|
||||||
}
|
}
|
||||||
return tmp[local_id_y][0];
|
return tmp[0];
|
||||||
}
|
}
|
||||||
#endif // defined(__opencl_c_work_group_collective_functions)
|
#endif // defined(__opencl_c_work_group_collective_functions)
|
||||||
)";
|
)";
|
||||||
@ -74,8 +71,8 @@ MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition)
|
|||||||
// that size to the kernel at runtime, and that is currently not supported.
|
// that size to the kernel at runtime, and that is currently not supported.
|
||||||
// For now, fix workgroup size to 128 threads.
|
// For now, fix workgroup size to 128 threads.
|
||||||
work_group_size_.x = 128;
|
work_group_size_.x = 128;
|
||||||
work_group_size_.y = 1;
|
work_group_size_.y = 1; // Required
|
||||||
work_group_size_.z = 1;
|
work_group_size_.z = 1; // Required
|
||||||
code_ = GetNormalizationCode();
|
code_ = GetNormalizationCode();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,16 +82,12 @@ std::string MeanStdDevNormalization::GetNormalizationCode() {
|
|||||||
|
|
||||||
std::string c = GetCommonDefines(definition_.precision);
|
std::string c = GetCommonDefines(definition_.precision);
|
||||||
c += GetVectorReduceCode();
|
c += GetVectorReduceCode();
|
||||||
c += GetReduceCode(work_group_size_.x, work_group_size_.y);
|
c += GetReduceCode();
|
||||||
c += "__attribute__((reqd_work_group_size(" +
|
c += "__attribute__((reqd_work_group_size(" +
|
||||||
std::to_string(work_group_size_.x) + ", " +
|
std::to_string(work_group_size_.x) + ", 1, 1)))\n";
|
||||||
std::to_string(work_group_size_.y) + ", " +
|
c += R"(__kernel void main_function($0) {
|
||||||
std::to_string(work_group_size_.z) + ")))\n";
|
|
||||||
c += R"(__kernel void main_function(
|
|
||||||
$0) {
|
|
||||||
#ifndef __opencl_c_work_group_collective_functions
|
#ifndef __opencl_c_work_group_collective_functions
|
||||||
__local float tmp[)" +
|
__local float tmp[)" +
|
||||||
std::to_string(work_group_size_.y) + "][" +
|
|
||||||
std::to_string(work_group_size_.x) + R"(];
|
std::to_string(work_group_size_.x) + R"(];
|
||||||
#endif
|
#endif
|
||||||
size_t B = get_global_id(1);
|
size_t B = get_global_id(1);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user