Some fixes for private mem broadcast convolution.

PiperOrigin-RevId: 316766133
Change-Id: Icf2273be31c299c49664b9b9fe17df55a0c53693
This commit is contained in:
Raman Sarokin 2020-06-16 15:06:59 -07:00 committed by TensorFlower Gardener
parent d4c2030375
commit ef68eff537
2 changed files with 78 additions and 42 deletions

View File

@ -196,6 +196,9 @@ absl::Status ConvPowerVR::Compile(const CreationContext& creation_context) {
creation_context.device->IsPowerVR()) {
options.push_back(CompilerOptions::POWERVR_FP16);
}
if (conv_params_.IsPrivateMemBroadcast()) {
options.push_back(CompilerOptions::CL_2_0);
}
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", options, *creation_context.context,
*creation_context.device, &kernel_);
@ -311,37 +314,10 @@ std::string GenerateConv(
const int local_mem_size =
conv_params.block_size.z * 4 * conv_params.src_depth_loop_size;
const bool use_simd_broadcast =
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST;
const bool use_simd_broadcast = conv_params.IsPrivateMemBroadcast();
const int simd_size = conv_params.GetSimdSize();
int simd_size = 1;
if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
simd_size = 8;
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST) {
simd_size = 16;
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST) {
simd_size = 32;
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST) {
simd_size = 64;
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST) {
simd_size = 128;
}
bool late_oob_check = need_local_mem || use_simd_broadcast;
const bool late_oob_check = need_local_mem || use_simd_broadcast;
const std::string weights_space =
conv_params.weights_upload_type ==
@ -355,6 +331,12 @@ std::string GenerateConv(
const std::string weights_global_ptr =
weights_space + " " + weights_data_type + "*";
if (use_simd_broadcast) {
if (device.cl_version() == OpenCLVersion::CL_2_0) {
c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
}
}
const int3 work_group_size = conv_params.work_group_size;
const int3 block_size = conv_params.block_size;
if (conv_params.fixed_work_group_size) {
@ -364,7 +346,7 @@ std::string GenerateConv(
std::to_string(work_group_size.z) + ")))\n";
}
if (use_simd_broadcast && device.IsIntel()) {
c += "__attribute__((intel_reqd_work_group_size(" +
c += "__attribute__((intel_reqd_sub_group_size(" +
std::to_string(simd_size) + ")))\n";
}
c += "__kernel void main_function(\n";
@ -408,6 +390,9 @@ std::string GenerateConv(
std::to_string(work_group_size.x) + " + get_local_id(0);\n";
}
}
if (use_simd_broadcast) {
c += " int simd_id = get_sub_group_local_id();\n";
}
for (int z = 0; z < block_size.z; ++z) {
for (int y = 0; y < block_size.y; ++y) {
for (int x = 0; x < block_size.x; ++x) {
@ -555,17 +540,36 @@ std::string GenerateConv(
for (int y = 0; y < block_size.y; ++y) {
for (int x = 0; x < block_size.x; ++x) {
std::string id = std::to_string(y) + std::to_string(x);
std::string w_val = "weights_cache[" +
std::to_string(z * 4 + ch + shared_offset) +
"]";
if (use_simd_broadcast) {
int simd_id = (z * 4 + ch + shared_offset) / simd_size;
int thread_id = (z * 4 + ch + shared_offset) % simd_size;
w_val = "sub_group_broadcast(simd_w" + std::to_string(simd_id) +
", " + std::to_string(thread_id) + "u)";
std::string w_val_x = "sub_group_broadcast(simd_w" +
std::to_string(simd_id) + ".x, " +
std::to_string(thread_id) + "u)";
std::string w_val_y = "sub_group_broadcast(simd_w" +
std::to_string(simd_id) + ".y, " +
std::to_string(thread_id) + "u)";
std::string w_val_z = "sub_group_broadcast(simd_w" +
std::to_string(simd_id) + ".z, " +
std::to_string(thread_id) + "u)";
std::string w_val_w = "sub_group_broadcast(simd_w" +
std::to_string(simd_id) + ".w, " +
std::to_string(thread_id) + "u)";
c += " r" + std::to_string(z) + id + ".x += " + w_val_x +
" * src" + id + "." + channels[ch] + ";\n";
c += " r" + std::to_string(z) + id + ".y += " + w_val_y +
" * src" + id + "." + channels[ch] + ";\n";
c += " r" + std::to_string(z) + id + ".z += " + w_val_z +
" * src" + id + "." + channels[ch] + ";\n";
c += " r" + std::to_string(z) + id + ".w += " + w_val_w +
" * src" + id + "." + channels[ch] + ";\n";
} else {
std::string w_val = "weights_cache[" +
std::to_string(z * 4 + ch + shared_offset) +
"]";
c += " r" + std::to_string(z) + id + " += " + w_val +
" * src" + id + "." + channels[ch] + ";\n";
}
c += " r" + std::to_string(z) + id + " += " + w_val +
" * src" + id + "." + channels[ch] + ";\n";
}
}
}
@ -608,16 +612,15 @@ std::string GenerateConv(
int parts = local_mem_size / simd_size;
int reminder = local_mem_size % simd_size;
for (int i = 0; i < parts; ++i) {
c += " FLT4 simd_w" + std::to_string(i) +
" = filters_loc[get_sub_group_local_id() + " +
c += " FLT4 simd_w" + std::to_string(i) + " = filters_loc[simd_id + " +
std::to_string(i * simd_size) + "];\n";
}
if (reminder) {
c += " FLT4 simd_w" + std::to_string(parts) + ";\n";
c += " if (simd_id < " + std::to_string(reminder) + ") {\n";
c += " simd_w" + std::to_string(parts) +
" = filters_loc[get_sub_group_local_id() + " +
std::to_string(parts * simd_size) + "];\n";
" = filters_loc[simd_id + " + std::to_string(parts * simd_size) +
"];\n";
c += " }\n";
}
} else { // GLOBAL_MEM/CONSTANT_MEM

View File

@ -90,6 +90,39 @@ class ConvPowerVR : public GPUOperation {
WeightsUploadType weights_upload_type;
bool x_kernel_is_1;
bool y_kernel_is_1;
// Returns true when weights_upload_type is any of the
// PRIVATE_MEM_SIMD{8,16,32,64,128}_BROADCAST values, false for every other
// upload type.
bool IsPrivateMemBroadcast() const {
  switch (weights_upload_type) {
    case WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST:
    case WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST:
    case WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST:
    case WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST:
    case WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST:
      return true;
    default:
      return false;
  }
}
// Maps weights_upload_type to the SIMD size named by its
// PRIVATE_MEM_SIMD<N>_BROADCAST value (8, 16, 32, 64 or 128).
// Every other upload type yields 1.
int GetSimdSize() const {
  switch (weights_upload_type) {
    case WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST:
      return 8;
    case WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST:
      return 16;
    case WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST:
      return 32;
    case WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST:
      return 64;
    case WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST:
      return 128;
    default:
      // Not a private-memory broadcast upload type.
      return 1;
  }
}
};
ConvPowerVR(const OperationDef& definition,