Some fixes for private mem broadcast convolution.
PiperOrigin-RevId: 316766133 Change-Id: Icf2273be31c299c49664b9b9fe17df55a0c53693
This commit is contained in:
parent
d4c2030375
commit
ef68eff537
@ -196,6 +196,9 @@ absl::Status ConvPowerVR::Compile(const CreationContext& creation_context) {
|
||||
creation_context.device->IsPowerVR()) {
|
||||
options.push_back(CompilerOptions::POWERVR_FP16);
|
||||
}
|
||||
if (conv_params_.IsPrivateMemBroadcast()) {
|
||||
options.push_back(CompilerOptions::CL_2_0);
|
||||
}
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", options, *creation_context.context,
|
||||
*creation_context.device, &kernel_);
|
||||
@ -311,37 +314,10 @@ std::string GenerateConv(
|
||||
const int local_mem_size =
|
||||
conv_params.block_size.z * 4 * conv_params.src_depth_loop_size;
|
||||
|
||||
const bool use_simd_broadcast =
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST;
|
||||
const bool use_simd_broadcast = conv_params.IsPrivateMemBroadcast();
|
||||
const int simd_size = conv_params.GetSimdSize();
|
||||
|
||||
int simd_size = 1;
|
||||
if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
|
||||
simd_size = 8;
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST) {
|
||||
simd_size = 16;
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST) {
|
||||
simd_size = 32;
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST) {
|
||||
simd_size = 64;
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST) {
|
||||
simd_size = 128;
|
||||
}
|
||||
|
||||
bool late_oob_check = need_local_mem || use_simd_broadcast;
|
||||
const bool late_oob_check = need_local_mem || use_simd_broadcast;
|
||||
|
||||
const std::string weights_space =
|
||||
conv_params.weights_upload_type ==
|
||||
@ -355,6 +331,12 @@ std::string GenerateConv(
|
||||
const std::string weights_global_ptr =
|
||||
weights_space + " " + weights_data_type + "*";
|
||||
|
||||
if (use_simd_broadcast) {
|
||||
if (device.cl_version() == OpenCLVersion::CL_2_0) {
|
||||
c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
|
||||
}
|
||||
}
|
||||
|
||||
const int3 work_group_size = conv_params.work_group_size;
|
||||
const int3 block_size = conv_params.block_size;
|
||||
if (conv_params.fixed_work_group_size) {
|
||||
@ -364,7 +346,7 @@ std::string GenerateConv(
|
||||
std::to_string(work_group_size.z) + ")))\n";
|
||||
}
|
||||
if (use_simd_broadcast && device.IsIntel()) {
|
||||
c += "__attribute__((intel_reqd_work_group_size(" +
|
||||
c += "__attribute__((intel_reqd_sub_group_size(" +
|
||||
std::to_string(simd_size) + ")))\n";
|
||||
}
|
||||
c += "__kernel void main_function(\n";
|
||||
@ -408,6 +390,9 @@ std::string GenerateConv(
|
||||
std::to_string(work_group_size.x) + " + get_local_id(0);\n";
|
||||
}
|
||||
}
|
||||
if (use_simd_broadcast) {
|
||||
c += " int simd_id = get_sub_group_local_id();\n";
|
||||
}
|
||||
for (int z = 0; z < block_size.z; ++z) {
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
@ -555,17 +540,36 @@ std::string GenerateConv(
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
std::string id = std::to_string(y) + std::to_string(x);
|
||||
std::string w_val = "weights_cache[" +
|
||||
std::to_string(z * 4 + ch + shared_offset) +
|
||||
"]";
|
||||
if (use_simd_broadcast) {
|
||||
int simd_id = (z * 4 + ch + shared_offset) / simd_size;
|
||||
int thread_id = (z * 4 + ch + shared_offset) % simd_size;
|
||||
w_val = "sub_group_broadcast(simd_w" + std::to_string(simd_id) +
|
||||
", " + std::to_string(thread_id) + "u)";
|
||||
std::string w_val_x = "sub_group_broadcast(simd_w" +
|
||||
std::to_string(simd_id) + ".x, " +
|
||||
std::to_string(thread_id) + "u)";
|
||||
std::string w_val_y = "sub_group_broadcast(simd_w" +
|
||||
std::to_string(simd_id) + ".y, " +
|
||||
std::to_string(thread_id) + "u)";
|
||||
std::string w_val_z = "sub_group_broadcast(simd_w" +
|
||||
std::to_string(simd_id) + ".z, " +
|
||||
std::to_string(thread_id) + "u)";
|
||||
std::string w_val_w = "sub_group_broadcast(simd_w" +
|
||||
std::to_string(simd_id) + ".w, " +
|
||||
std::to_string(thread_id) + "u)";
|
||||
c += " r" + std::to_string(z) + id + ".x += " + w_val_x +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
c += " r" + std::to_string(z) + id + ".y += " + w_val_y +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
c += " r" + std::to_string(z) + id + ".z += " + w_val_z +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
c += " r" + std::to_string(z) + id + ".w += " + w_val_w +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
} else {
|
||||
std::string w_val = "weights_cache[" +
|
||||
std::to_string(z * 4 + ch + shared_offset) +
|
||||
"]";
|
||||
c += " r" + std::to_string(z) + id + " += " + w_val +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
}
|
||||
c += " r" + std::to_string(z) + id + " += " + w_val +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -608,16 +612,15 @@ std::string GenerateConv(
|
||||
int parts = local_mem_size / simd_size;
|
||||
int reminder = local_mem_size % simd_size;
|
||||
for (int i = 0; i < parts; ++i) {
|
||||
c += " FLT4 simd_w" + std::to_string(i) +
|
||||
" = filters_loc[get_sub_group_local_id() + " +
|
||||
c += " FLT4 simd_w" + std::to_string(i) + " = filters_loc[simd_id + " +
|
||||
std::to_string(i * simd_size) + "];\n";
|
||||
}
|
||||
if (reminder) {
|
||||
c += " FLT4 simd_w" + std::to_string(parts) + ";\n";
|
||||
c += " if (simd_id < " + std::to_string(reminder) + ") {\n";
|
||||
c += " simd_w" + std::to_string(parts) +
|
||||
" = filters_loc[get_sub_group_local_id() + " +
|
||||
std::to_string(parts * simd_size) + "];\n";
|
||||
" = filters_loc[simd_id + " + std::to_string(parts * simd_size) +
|
||||
"];\n";
|
||||
c += " }\n";
|
||||
}
|
||||
} else { // GLOBAL_MEM/CONSTANT_MEM
|
||||
|
@ -90,6 +90,39 @@ class ConvPowerVR : public GPUOperation {
|
||||
WeightsUploadType weights_upload_type;
|
||||
bool x_kernel_is_1;
|
||||
bool y_kernel_is_1;
|
||||
|
||||
bool IsPrivateMemBroadcast() const {
|
||||
return weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
|
||||
weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
|
||||
weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST ||
|
||||
weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST ||
|
||||
weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST;
|
||||
}
|
||||
|
||||
int GetSimdSize() const {
|
||||
if (weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
|
||||
return 8;
|
||||
} else if (weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST) {
|
||||
return 16;
|
||||
} else if (weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST) {
|
||||
return 32;
|
||||
} else if (weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST) {
|
||||
return 64;
|
||||
} else if (weights_upload_type ==
|
||||
WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST) {
|
||||
return 128;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
ConvPowerVR(const OperationDef& definition,
|
||||
|
Loading…
Reference in New Issue
Block a user