Added new ways of uploading weights.
PiperOrigin-RevId: 316556676 Change-Id: I343e4f6461a26a7d921699b23ee3ccf65ecb3bee
This commit is contained in:
parent
072c2f5d0d
commit
5d4c6e105f
@ -308,6 +308,41 @@ std::string GenerateConv(
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
|
||||
|
||||
const int local_mem_size =
|
||||
conv_params.block_size.z * 4 * conv_params.src_depth_loop_size;
|
||||
|
||||
const bool use_simd_broadcast =
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST;
|
||||
|
||||
int simd_size = 1;
|
||||
if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
|
||||
simd_size = 8;
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST) {
|
||||
simd_size = 16;
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST) {
|
||||
simd_size = 32;
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST) {
|
||||
simd_size = 64;
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST) {
|
||||
simd_size = 128;
|
||||
}
|
||||
|
||||
bool late_oob_check = need_local_mem || use_simd_broadcast;
|
||||
|
||||
const std::string weights_space =
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::CONSTANT_MEM
|
||||
@ -328,6 +363,10 @@ std::string GenerateConv(
|
||||
std::to_string(work_group_size.y) + ", " +
|
||||
std::to_string(work_group_size.z) + ")))\n";
|
||||
}
|
||||
if (use_simd_broadcast && device.IsIntel()) {
|
||||
c += "__attribute__((intel_reqd_work_group_size(" +
|
||||
std::to_string(simd_size) + ")))\n";
|
||||
}
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
|
||||
c += " " + weights_global_ptr + " filters_buffer, \n";
|
||||
@ -355,7 +394,7 @@ std::string GenerateConv(
|
||||
for (int y = 0; y < conv_params.block_size.y; ++y) {
|
||||
dst_y[y] = "(Y + " + std::to_string(y) + ")";
|
||||
}
|
||||
if (!need_local_mem) {
|
||||
if (!late_oob_check) {
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) {\n";
|
||||
c += " return;\n";
|
||||
c += " }\n";
|
||||
@ -396,13 +435,8 @@ std::string GenerateConv(
|
||||
}
|
||||
if (need_local_mem) {
|
||||
c += " __local " + weights_data_type + " weights_cache[" +
|
||||
std::to_string(block_size.z * 4 * conv_params.src_depth_loop_size) +
|
||||
"];\n";
|
||||
}
|
||||
if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::GLOBAL_MEM ||
|
||||
conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::CONSTANT_MEM) {
|
||||
std::to_string(local_mem_size) + "];\n";
|
||||
} else {
|
||||
c += " " + weights_global_ptr + " weights_cache;\n";
|
||||
}
|
||||
if (is1x1) {
|
||||
@ -521,9 +555,17 @@ std::string GenerateConv(
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
std::string id = std::to_string(y) + std::to_string(x);
|
||||
c += " r" + std::to_string(z) + id + " += weights_cache[" +
|
||||
std::to_string(z * 4 + ch + shared_offset) + "] * src" + id +
|
||||
"." + channels[ch] + ";\n";
|
||||
std::string w_val = "weights_cache[" +
|
||||
std::to_string(z * 4 + ch + shared_offset) +
|
||||
"]";
|
||||
if (use_simd_broadcast) {
|
||||
int simd_id = (z * 4 + ch + shared_offset) / simd_size;
|
||||
int thread_id = (z * 4 + ch + shared_offset) % simd_size;
|
||||
w_val = "sub_group_broadcast(simd_w" + std::to_string(simd_id) +
|
||||
", " + std::to_string(thread_id) + "u)";
|
||||
}
|
||||
c += " r" + std::to_string(z) + id + " += " + w_val +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -554,17 +596,30 @@ std::string GenerateConv(
|
||||
work_group_size.x * work_group_size.y * work_group_size.z;
|
||||
if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) {
|
||||
c +=
|
||||
GenerateAsyncUpload("weights_cache", "filters_loc",
|
||||
/*global_offset_name*/ "",
|
||||
block_size.z * 4 * conv_params.src_depth_loop_size);
|
||||
c += GenerateAsyncUpload("weights_cache", "filters_loc",
|
||||
/*global_offset_name*/ "", local_mem_size);
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
|
||||
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
|
||||
c += GenerateUploadByThreads(
|
||||
"weights_cache", "filters_loc",
|
||||
/*global_offset_name*/ "", "lid", total_work_items,
|
||||
block_size.z * 4 * conv_params.src_depth_loop_size);
|
||||
c += GenerateUploadByThreads("weights_cache", "filters_loc",
|
||||
/*global_offset_name*/ "", "lid",
|
||||
total_work_items, local_mem_size);
|
||||
} else if (use_simd_broadcast) {
|
||||
int parts = local_mem_size / simd_size;
|
||||
int reminder = local_mem_size % simd_size;
|
||||
for (int i = 0; i < parts; ++i) {
|
||||
c += " FLT4 simd_w" + std::to_string(i) +
|
||||
" = filters_loc[get_sub_group_local_id() + " +
|
||||
std::to_string(i * simd_size) + "];\n";
|
||||
}
|
||||
if (reminder) {
|
||||
c += " FLT4 simd_w" + std::to_string(parts) + ";\n";
|
||||
c += " if (simd_id < " + std::to_string(reminder) + ") {\n";
|
||||
c += " simd_w" + std::to_string(parts) +
|
||||
" = filters_loc[get_sub_group_local_id() + " +
|
||||
std::to_string(parts * simd_size) + "];\n";
|
||||
c += " }\n";
|
||||
}
|
||||
} else { // GLOBAL_MEM/CONSTANT_MEM
|
||||
c += " weights_cache = filters_loc;\n";
|
||||
}
|
||||
@ -580,9 +635,7 @@ std::string GenerateConv(
|
||||
conv_core(i * block_size.z * 4);
|
||||
c += " s += 1;\n";
|
||||
}
|
||||
c += " filters_loc += " +
|
||||
std::to_string(block_size.z * 4 * conv_params.src_depth_loop_size) +
|
||||
";\n";
|
||||
c += " filters_loc += " + std::to_string(local_mem_size) + ";\n";
|
||||
c += " } while (s < src_size.z);\n";
|
||||
if (!is1x1) {
|
||||
c += " };\n";
|
||||
@ -597,10 +650,10 @@ std::string GenerateConv(
|
||||
c += GenerateUploadByThreads("weights_cache", "biases", "Z", "lid",
|
||||
total_work_items, block_size.z);
|
||||
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
|
||||
} else { // GLOBAL_MEM/CONSTANT_MEM
|
||||
} else {
|
||||
c += " weights_cache = biases + Z;\n";
|
||||
}
|
||||
if (need_local_mem) {
|
||||
if (late_oob_check) {
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) {\n";
|
||||
c += " return;\n";
|
||||
c += " }\n";
|
||||
|
@ -64,6 +64,11 @@ class ConvPowerVR : public GPUOperation {
|
||||
LOCAL_MEM_BY_THREADS,
|
||||
GLOBAL_MEM,
|
||||
CONSTANT_MEM,
|
||||
PRIVATE_MEM_SIMD8_BROADCAST,
|
||||
PRIVATE_MEM_SIMD16_BROADCAST,
|
||||
PRIVATE_MEM_SIMD32_BROADCAST,
|
||||
PRIVATE_MEM_SIMD64_BROADCAST,
|
||||
PRIVATE_MEM_SIMD128_BROADCAST,
|
||||
};
|
||||
|
||||
struct ConvParams {
|
||||
|
Loading…
Reference in New Issue
Block a user