Added new ways of uploading weights (PRIVATE_MEM_SIMD*_BROADCAST upload types).

PiperOrigin-RevId: 316556676
Change-Id: I343e4f6461a26a7d921699b23ee3ccf65ecb3bee
This commit is contained in:
Raman Sarokin 2020-06-15 15:23:29 -07:00 committed by TensorFlower Gardener
parent 072c2f5d0d
commit 5d4c6e105f
2 changed files with 82 additions and 24 deletions

View File

@ -308,6 +308,41 @@ std::string GenerateConv(
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
const int local_mem_size =
conv_params.block_size.z * 4 * conv_params.src_depth_loop_size;
const bool use_simd_broadcast =
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST;
int simd_size = 1;
if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
simd_size = 8;
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST) {
simd_size = 16;
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST) {
simd_size = 32;
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD64_BROADCAST) {
simd_size = 64;
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::PRIVATE_MEM_SIMD128_BROADCAST) {
simd_size = 128;
}
bool late_oob_check = need_local_mem || use_simd_broadcast;
const std::string weights_space =
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::CONSTANT_MEM
@ -328,6 +363,10 @@ std::string GenerateConv(
std::to_string(work_group_size.y) + ", " +
std::to_string(work_group_size.z) + ")))\n";
}
if (use_simd_broadcast && device.IsIntel()) {
c += "__attribute__((intel_reqd_work_group_size(" +
std::to_string(simd_size) + ")))\n";
}
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
c += " " + weights_global_ptr + " filters_buffer, \n";
@ -355,7 +394,7 @@ std::string GenerateConv(
for (int y = 0; y < conv_params.block_size.y; ++y) {
dst_y[y] = "(Y + " + std::to_string(y) + ")";
}
if (!need_local_mem) {
if (!late_oob_check) {
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) {\n";
c += " return;\n";
c += " }\n";
@ -396,13 +435,8 @@ std::string GenerateConv(
}
if (need_local_mem) {
c += " __local " + weights_data_type + " weights_cache[" +
std::to_string(block_size.z * 4 * conv_params.src_depth_loop_size) +
"];\n";
}
if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::GLOBAL_MEM ||
conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::CONSTANT_MEM) {
std::to_string(local_mem_size) + "];\n";
} else {
c += " " + weights_global_ptr + " weights_cache;\n";
}
if (is1x1) {
@ -521,9 +555,17 @@ std::string GenerateConv(
for (int y = 0; y < block_size.y; ++y) {
for (int x = 0; x < block_size.x; ++x) {
std::string id = std::to_string(y) + std::to_string(x);
c += " r" + std::to_string(z) + id + " += weights_cache[" +
std::to_string(z * 4 + ch + shared_offset) + "] * src" + id +
"." + channels[ch] + ";\n";
std::string w_val = "weights_cache[" +
std::to_string(z * 4 + ch + shared_offset) +
"]";
if (use_simd_broadcast) {
int simd_id = (z * 4 + ch + shared_offset) / simd_size;
int thread_id = (z * 4 + ch + shared_offset) % simd_size;
w_val = "sub_group_broadcast(simd_w" + std::to_string(simd_id) +
", " + std::to_string(thread_id) + "u)";
}
c += " r" + std::to_string(z) + id + " += " + w_val +
" * src" + id + "." + channels[ch] + ";\n";
}
}
}
@ -554,17 +596,30 @@ std::string GenerateConv(
work_group_size.x * work_group_size.y * work_group_size.z;
if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) {
c +=
GenerateAsyncUpload("weights_cache", "filters_loc",
/*global_offset_name*/ "",
block_size.z * 4 * conv_params.src_depth_loop_size);
c += GenerateAsyncUpload("weights_cache", "filters_loc",
/*global_offset_name*/ "", local_mem_size);
} else if (conv_params.weights_upload_type ==
ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
c += GenerateUploadByThreads(
"weights_cache", "filters_loc",
/*global_offset_name*/ "", "lid", total_work_items,
block_size.z * 4 * conv_params.src_depth_loop_size);
c += GenerateUploadByThreads("weights_cache", "filters_loc",
/*global_offset_name*/ "", "lid",
total_work_items, local_mem_size);
} else if (use_simd_broadcast) {
int parts = local_mem_size / simd_size;
int reminder = local_mem_size % simd_size;
for (int i = 0; i < parts; ++i) {
c += " FLT4 simd_w" + std::to_string(i) +
" = filters_loc[get_sub_group_local_id() + " +
std::to_string(i * simd_size) + "];\n";
}
if (reminder) {
c += " FLT4 simd_w" + std::to_string(parts) + ";\n";
c += " if (simd_id < " + std::to_string(reminder) + ") {\n";
c += " simd_w" + std::to_string(parts) +
" = filters_loc[get_sub_group_local_id() + " +
std::to_string(parts * simd_size) + "];\n";
c += " }\n";
}
} else { // GLOBAL_MEM/CONSTANT_MEM
c += " weights_cache = filters_loc;\n";
}
@ -580,9 +635,7 @@ std::string GenerateConv(
conv_core(i * block_size.z * 4);
c += " s += 1;\n";
}
c += " filters_loc += " +
std::to_string(block_size.z * 4 * conv_params.src_depth_loop_size) +
";\n";
c += " filters_loc += " + std::to_string(local_mem_size) + ";\n";
c += " } while (s < src_size.z);\n";
if (!is1x1) {
c += " };\n";
@ -597,10 +650,10 @@ std::string GenerateConv(
c += GenerateUploadByThreads("weights_cache", "biases", "Z", "lid",
total_work_items, block_size.z);
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
} else { // GLOBAL_MEM/CONSTANT_MEM
} else {
c += " weights_cache = biases + Z;\n";
}
if (need_local_mem) {
if (late_oob_check) {
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) {\n";
c += " return;\n";
c += " }\n";

View File

@ -64,6 +64,11 @@ class ConvPowerVR : public GPUOperation {
LOCAL_MEM_BY_THREADS,
GLOBAL_MEM,
CONSTANT_MEM,
PRIVATE_MEM_SIMD8_BROADCAST,
PRIVATE_MEM_SIMD16_BROADCAST,
PRIVATE_MEM_SIMD32_BROADCAST,
PRIVATE_MEM_SIMD64_BROADCAST,
PRIVATE_MEM_SIMD128_BROADCAST,
};
struct ConvParams {