Change MeanStddevNormalization's local_reduce to enforce 1D workgroups.

The OpenCL 2.0 work_group_reduce_add function, what local_reduce is reconstructing in OpenCL 1.x, only has this functionality as well.

PiperOrigin-RevId: 327266948
Change-Id: I446a212c38d9e834aae1d63289ea6fd7f32986c0
This commit is contained in:
Robert David 2020-08-18 11:14:33 -07:00 committed by TensorFlower Gardener
parent cab58b3a0f
commit 7f2bc1e4b8

View File

@ -32,7 +32,7 @@ std::string GetVectorReduceCode() {
})"; })";
} }
std::string GetReduceCode(size_t work_group_size_x, size_t work_group_size_y) { std::string GetReduceCode() {
// If it is supported, use the built-in work_group_reduce_add function. // If it is supported, use the built-in work_group_reduce_add function.
// Otherwise, implement a reduction using __local memory. Note this only works // Otherwise, implement a reduction using __local memory. Note this only works
// with power-of-two work group sizes. // with power-of-two work group sizes.
@ -45,22 +45,19 @@ std::string GetReduceCode(size_t work_group_size_x, size_t work_group_size_y) {
#ifdef __opencl_c_work_group_collective_functions #ifdef __opencl_c_work_group_collective_functions
#define local_reduce(input, tmp) work_group_reduce_add(input) #define local_reduce(input, tmp) work_group_reduce_add(input)
#else // !defined(__opencl_c_work_group_collective_functions) #else // !defined(__opencl_c_work_group_collective_functions)
static inline float local_reduce(float input, __local float tmp[)" + static inline float local_reduce(float input, __local float* tmp) {
std::to_string(work_group_size_y) + "][" + const size_t local_id = get_local_id(0);
std::to_string(work_group_size_x) + R"(]) { tmp[local_id] = input;
const size_t local_id_x = get_local_id(0);
const size_t local_id_y = get_local_id(1);
tmp[local_id_y][local_id_x] = input;
mem_fence(CLK_LOCAL_MEM_FENCE); mem_fence(CLK_LOCAL_MEM_FENCE);
size_t reduction_size = get_local_size(0) / 2; size_t reduction_size = get_local_size(0) / 2;
while (reduction_size > 0) { while (reduction_size > 0) {
if (local_id_x < reduction_size) { if (local_id < reduction_size) {
tmp[local_id_y][local_id_x] += tmp[local_id_y][local_id_x + reduction_size]; tmp[local_id] += tmp[local_id + reduction_size];
} }
mem_fence(CLK_LOCAL_MEM_FENCE); mem_fence(CLK_LOCAL_MEM_FENCE);
reduction_size /= 2; reduction_size /= 2;
} }
return tmp[local_id_y][0]; return tmp[0];
} }
#endif // defined(__opencl_c_work_group_collective_functions) #endif // defined(__opencl_c_work_group_collective_functions)
)"; )";
@ -74,8 +71,8 @@ MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition)
// that size to the kernel at runtime, and that is currently not supported. // that size to the kernel at runtime, and that is currently not supported.
// For now, fix workgroup size to 128 threads. // For now, fix workgroup size to 128 threads.
work_group_size_.x = 128; work_group_size_.x = 128;
work_group_size_.y = 1; work_group_size_.y = 1; // Required
work_group_size_.z = 1; work_group_size_.z = 1; // Required
code_ = GetNormalizationCode(); code_ = GetNormalizationCode();
} }
@ -85,16 +82,12 @@ std::string MeanStdDevNormalization::GetNormalizationCode() {
std::string c = GetCommonDefines(definition_.precision); std::string c = GetCommonDefines(definition_.precision);
c += GetVectorReduceCode(); c += GetVectorReduceCode();
c += GetReduceCode(work_group_size_.x, work_group_size_.y); c += GetReduceCode();
c += "__attribute__((reqd_work_group_size(" + c += "__attribute__((reqd_work_group_size(" +
std::to_string(work_group_size_.x) + ", " + std::to_string(work_group_size_.x) + ", 1, 1)))\n";
std::to_string(work_group_size_.y) + ", " + c += R"(__kernel void main_function($0) {
std::to_string(work_group_size_.z) + ")))\n";
c += R"(__kernel void main_function(
$0) {
#ifndef __opencl_c_work_group_collective_functions #ifndef __opencl_c_work_group_collective_functions
__local float tmp[)" + __local float tmp[)" +
std::to_string(work_group_size_.y) + "][" +
std::to_string(work_group_size_.x) + R"(]; std::to_string(work_group_size_.x) + R"(];
#endif #endif
size_t B = get_global_id(1); size_t B = get_global_id(1);