Improved performance of convolution on AMD.
PiperOrigin-RevId: 293066841 Change-Id: Ifb8927606cea81897b9dee25b70918c585ba8c0b
This commit is contained in:
parent
2b3dddff9c
commit
7a5bd40b13
@ -169,7 +169,8 @@ Status ConvPowerVR::Tune(const TuningParameters& params) {
|
|||||||
if (conv_params_.weights_upload_type ==
|
if (conv_params_.weights_upload_type ==
|
||||||
WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP ||
|
WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP ||
|
||||||
conv_params_.weights_upload_type ==
|
conv_params_.weights_upload_type ==
|
||||||
WeightsUploadType::LOCAL_MEM_BY_THREADS) {
|
WeightsUploadType::LOCAL_MEM_BY_THREADS ||
|
||||||
|
conv_params_.fixed_work_group_size) {
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
if (conv_params_.work_group_launch_order[0] == 0 &&
|
if (conv_params_.work_group_launch_order[0] == 0 &&
|
||||||
@ -212,9 +213,15 @@ std::string GenerateConvPowerVR1x1(
|
|||||||
conv_params.weights_upload_type ==
|
conv_params.weights_upload_type ==
|
||||||
ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
|
ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
|
||||||
|
|
||||||
|
const std::string weights_space =
|
||||||
|
conv_params.weights_upload_type ==
|
||||||
|
ConvPowerVR::WeightsUploadType::CONSTANT_MEM
|
||||||
|
? "__constant"
|
||||||
|
: "__global";
|
||||||
|
|
||||||
const int3 work_group_size = conv_params.work_group_size;
|
const int3 work_group_size = conv_params.work_group_size;
|
||||||
const int3 block_size = conv_params.block_size;
|
const int3 block_size = conv_params.block_size;
|
||||||
if (need_local_mem) { // we use fixed workgroup size when use local mem
|
if (conv_params.fixed_work_group_size) {
|
||||||
c += "__attribute__((reqd_work_group_size(" +
|
c += "__attribute__((reqd_work_group_size(" +
|
||||||
std::to_string(work_group_size.x) + ", " +
|
std::to_string(work_group_size.x) + ", " +
|
||||||
std::to_string(work_group_size.y) + ", " +
|
std::to_string(work_group_size.y) + ", " +
|
||||||
@ -222,8 +229,8 @@ std::string GenerateConvPowerVR1x1(
|
|||||||
}
|
}
|
||||||
c += "__kernel void main_function(\n";
|
c += "__kernel void main_function(\n";
|
||||||
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
|
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
|
||||||
c += " __global ACCUM_FLT4* filters_buffer, \n";
|
c += " " + weights_space + " ACCUM_FLT4* filters_buffer, \n";
|
||||||
c += " __global ACCUM_FLT4* biases \n";
|
c += " " + weights_space + " ACCUM_FLT4* biases \n";
|
||||||
c += GetArgsDeclaration(linked_operations);
|
c += GetArgsDeclaration(linked_operations);
|
||||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||||
if (!is1x1) {
|
if (!is1x1) {
|
||||||
@ -301,14 +308,18 @@ std::string GenerateConvPowerVR1x1(
|
|||||||
"];\n";
|
"];\n";
|
||||||
}
|
}
|
||||||
if (conv_params.weights_upload_type ==
|
if (conv_params.weights_upload_type ==
|
||||||
ConvPowerVR::WeightsUploadType::GLOBAL_MEM) {
|
ConvPowerVR::WeightsUploadType::GLOBAL_MEM ||
|
||||||
c += " __global ACCUM_FLT4* weights_cache;\n";
|
conv_params.weights_upload_type ==
|
||||||
|
ConvPowerVR::WeightsUploadType::CONSTANT_MEM) {
|
||||||
|
c += " " + weights_space + " ACCUM_FLT4* weights_cache;\n";
|
||||||
}
|
}
|
||||||
if (is1x1) {
|
if (is1x1) {
|
||||||
c += " __global ACCUM_FLT4* filters_loc = filters_buffer + Z * 4 * "
|
c += " " + weights_space +
|
||||||
|
" ACCUM_FLT4* filters_loc = filters_buffer + Z * 4 * "
|
||||||
"src_size.z;\n";
|
"src_size.z;\n";
|
||||||
} else {
|
} else {
|
||||||
c += " __global ACCUM_FLT4* filters_loc = filters_buffer + Z * 4 * "
|
c += " " + weights_space +
|
||||||
|
" ACCUM_FLT4* filters_loc = filters_buffer + Z * 4 * "
|
||||||
"src_size.z * kernel_dilation.x * kernel_dilation.y;\n";
|
"src_size.z * kernel_dilation.x * kernel_dilation.y;\n";
|
||||||
}
|
}
|
||||||
if (buffer_type) {
|
if (buffer_type) {
|
||||||
@ -445,7 +456,7 @@ std::string GenerateConvPowerVR1x1(
|
|||||||
"weights_cache", "filters_loc",
|
"weights_cache", "filters_loc",
|
||||||
/*global_offset_name*/ "", "lid", total_work_items,
|
/*global_offset_name*/ "", "lid", total_work_items,
|
||||||
block_size.z * 4 * conv_params.src_depth_loop_size);
|
block_size.z * 4 * conv_params.src_depth_loop_size);
|
||||||
} else { // GLOBAL_MEM
|
} else { // GLOBAL_MEM/CONSTANT_MEM
|
||||||
c += " weights_cache = filters_loc;\n";
|
c += " weights_cache = filters_loc;\n";
|
||||||
}
|
}
|
||||||
read_src();
|
read_src();
|
||||||
@ -477,7 +488,7 @@ std::string GenerateConvPowerVR1x1(
|
|||||||
c += GenerateUploadByThreads("weights_cache", "biases", "Z", "lid",
|
c += GenerateUploadByThreads("weights_cache", "biases", "Z", "lid",
|
||||||
total_work_items, block_size.z);
|
total_work_items, block_size.z);
|
||||||
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
|
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
|
||||||
} else { // GLOBAL_MEM
|
} else { // GLOBAL_MEM/CONSTANT_MEM
|
||||||
c += " weights_cache = biases + Z;\n";
|
c += " weights_cache = biases + Z;\n";
|
||||||
}
|
}
|
||||||
if (need_local_mem) {
|
if (need_local_mem) {
|
||||||
@ -528,6 +539,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
|||||||
conv_params.block_size = int3(1, 1, 4);
|
conv_params.block_size = int3(1, 1, 4);
|
||||||
conv_params.work_group_size = int3(8, 4, 1);
|
conv_params.work_group_size = int3(8, 4, 1);
|
||||||
conv_params.work_group_launch_order = int3(2, 0, 1);
|
conv_params.work_group_launch_order = int3(2, 0, 1);
|
||||||
|
conv_params.fixed_work_group_size = true;
|
||||||
conv_params.src_depth_loop_size = 1;
|
conv_params.src_depth_loop_size = 1;
|
||||||
conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
|
conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
|
||||||
if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
||||||
@ -547,6 +559,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
|||||||
conv_params.block_size = int3(1, 1, 4);
|
conv_params.block_size = int3(1, 1, 4);
|
||||||
conv_params.work_group_size = int3(8, 4, 1);
|
conv_params.work_group_size = int3(8, 4, 1);
|
||||||
conv_params.work_group_launch_order = int3(2, 0, 1);
|
conv_params.work_group_launch_order = int3(2, 0, 1);
|
||||||
|
conv_params.fixed_work_group_size = true;
|
||||||
conv_params.src_depth_loop_size = 1;
|
conv_params.src_depth_loop_size = 1;
|
||||||
conv_params.weights_upload_type =
|
conv_params.weights_upload_type =
|
||||||
WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
|
WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
|
||||||
@ -581,10 +594,33 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
|||||||
conv_params.block_size.x = 2;
|
conv_params.block_size.x = 2;
|
||||||
conv_params.work_group_size = int3(4, 8, 1);
|
conv_params.work_group_size = int3(4, 8, 1);
|
||||||
}
|
}
|
||||||
|
} else if (device.IsAMD()) {
|
||||||
|
conv_params.block_size = int3(2, 1, 1);
|
||||||
|
if (x_kernel_is_1 && y_kernel_is_1) {
|
||||||
|
conv_params.block_size.y = 2;
|
||||||
|
}
|
||||||
|
conv_params.work_group_size = int3(8, 4, 1);
|
||||||
|
conv_params.work_group_launch_order = int3(2, 0, 1);
|
||||||
|
conv_params.fixed_work_group_size = true;
|
||||||
|
conv_params.src_depth_loop_size = 1;
|
||||||
|
conv_params.weights_upload_type = WeightsUploadType::CONSTANT_MEM;
|
||||||
|
if (dst_depth % 8 == 0 || dst_depth >= 32) {
|
||||||
|
conv_params.block_size.z = 8;
|
||||||
|
} else if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
||||||
|
conv_params.block_size.z = 4;
|
||||||
|
} else if (dst_depth % 2 == 0 || dst_depth >= 4) {
|
||||||
|
conv_params.block_size.z = 2;
|
||||||
|
} else {
|
||||||
|
conv_params.block_size.z = 1;
|
||||||
|
}
|
||||||
|
if (src_depth % 2 == 0 && src_depth >= 16) {
|
||||||
|
conv_params.src_depth_loop_size = 2;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
conv_params.block_size = int3(1, 1, 4);
|
conv_params.block_size = int3(1, 1, 4);
|
||||||
conv_params.work_group_size = int3(8, 4, 1);
|
conv_params.work_group_size = int3(8, 4, 1);
|
||||||
conv_params.work_group_launch_order = int3(0, 1, 2);
|
conv_params.work_group_launch_order = int3(0, 1, 2);
|
||||||
|
conv_params.fixed_work_group_size = false;
|
||||||
conv_params.src_depth_loop_size = 1;
|
conv_params.src_depth_loop_size = 1;
|
||||||
conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
|
conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
|
||||||
if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
||||||
|
|||||||
@ -54,12 +54,14 @@ class ConvPowerVR : public GPUOperation {
|
|||||||
LOCAL_MEM_ASYNC_SUBGROUP, // we use it for PowerVR with workgroup size = 32
|
LOCAL_MEM_ASYNC_SUBGROUP, // we use it for PowerVR with workgroup size = 32
|
||||||
LOCAL_MEM_BY_THREADS,
|
LOCAL_MEM_BY_THREADS,
|
||||||
GLOBAL_MEM,
|
GLOBAL_MEM,
|
||||||
|
CONSTANT_MEM,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ConvParams {
|
struct ConvParams {
|
||||||
int3 block_size;
|
int3 block_size;
|
||||||
int3 work_group_size;
|
int3 work_group_size;
|
||||||
int3 work_group_launch_order;
|
int3 work_group_launch_order;
|
||||||
|
bool fixed_work_group_size;
|
||||||
int src_depth_loop_size;
|
int src_depth_loop_size;
|
||||||
WeightsUploadType weights_upload_type;
|
WeightsUploadType weights_upload_type;
|
||||||
bool x_kernel_is_1;
|
bool x_kernel_is_1;
|
||||||
|
|||||||
@ -106,6 +106,7 @@ Status SelectConvolution(const Convolution2DAttributes& attr,
|
|||||||
return SelectConvolutionAdreno(attr, dst_shape, creation_context, op_def,
|
return SelectConvolutionAdreno(attr, dst_shape, creation_context, op_def,
|
||||||
hints, ptr);
|
hints, ptr);
|
||||||
case Vendor::POWERVR:
|
case Vendor::POWERVR:
|
||||||
|
case Vendor::AMD:
|
||||||
return SelectConvolutionPowerVR(attr, creation_context, op_def, ptr);
|
return SelectConvolutionPowerVR(attr, creation_context, op_def, ptr);
|
||||||
case Vendor::NVIDIA:
|
case Vendor::NVIDIA:
|
||||||
return SelectConvolutionNVidia(attr, creation_context, op_def, ptr);
|
return SelectConvolutionNVidia(attr, creation_context, op_def, ptr);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user