Improved selection of block size for ConvBuffer1x1
PiperOrigin-RevId: 300460634 Change-Id: I1c5ef87614c84a1f638d9f48225f411dfa79fc96
This commit is contained in:
parent
dba1d5fa93
commit
a4f2960e0e
@ -297,10 +297,19 @@ bool MaliInfo::IsMidgard() const {
|
||||
return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
|
||||
}
|
||||
|
||||
bool MaliInfo::IsBifrost() const {
|
||||
bool MaliInfo::IsBifrostGen1() const {
|
||||
return gpu_version == MaliGPU::G31 || gpu_version == MaliGPU::G51 ||
|
||||
gpu_version == MaliGPU::G71 || gpu_version == MaliGPU::G52 ||
|
||||
gpu_version == MaliGPU::G72 || gpu_version == MaliGPU::G76;
|
||||
gpu_version == MaliGPU::G71;
|
||||
}
|
||||
|
||||
bool MaliInfo::IsBifrostGen2() const {
|
||||
return gpu_version == MaliGPU::G52 || gpu_version == MaliGPU::G72;
|
||||
}
|
||||
|
||||
// Returns true when gpu_version is the only GPU grouped here as Bifrost
// generation 3: G76.
bool MaliInfo::IsBifrostGen3() const {
  const bool is_g76 = (gpu_version == MaliGPU::G76);
  return is_g76;
}
|
||||
|
||||
bool MaliInfo::IsBifrost() const {
|
||||
return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3();
|
||||
}
|
||||
|
||||
bool MaliInfo::IsValhall() const {
|
||||
|
@ -94,6 +94,9 @@ struct MaliInfo {
|
||||
bool IsMaliT7xx() const;
|
||||
bool IsMaliT8xx() const;
|
||||
bool IsMidgard() const;
|
||||
bool IsBifrostGen1() const;
|
||||
bool IsBifrostGen2() const;
|
||||
bool IsBifrostGen3() const;
|
||||
bool IsBifrost() const;
|
||||
bool IsValhall() const;
|
||||
};
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
|
||||
|
||||
#include <array>
|
||||
#include <cfloat>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
@ -201,6 +202,74 @@ std::string GenerateConvBuffer1x1(
|
||||
return c;
|
||||
}
|
||||
|
||||
// task_size as amount of FLT4 processed elements.
|
||||
int GetRecommendedBlockSizeForConv(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
int task_size) {
|
||||
const float task_size_per_cu =
|
||||
task_size / static_cast<float>(device.GetInfo().compute_units_count);
|
||||
int block_size = 1;
|
||||
float threshold_1 = FLT_MAX;
|
||||
float threshold_2 = FLT_MAX;
|
||||
float threshold_4 = FLT_MAX;
|
||||
if (!device.IsMali()) {
|
||||
return 1;
|
||||
}
|
||||
MaliInfo mali_info = device.GetInfo().mali_info;
|
||||
switch (definition.precision) {
|
||||
case CalculationsPrecision::F16:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
threshold_4 = 256.0f * 8.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 256.0f * 2.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
threshold_4 = 256.0f * 16.0f;
|
||||
} else if (mali_info.IsBifrostGen3()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 6.0f;
|
||||
threshold_4 = 256.0f * 16.0f;
|
||||
}
|
||||
break;
|
||||
case CalculationsPrecision::F32_F16:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 3.0f;
|
||||
threshold_4 = 256.0f * 32.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 256.0f * 2.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
} else if (mali_info.IsBifrostGen3()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
}
|
||||
break;
|
||||
case CalculationsPrecision::F32:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 128.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
} else if (mali_info.IsBifrostGen3()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 12.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (task_size_per_cu <= threshold_1) {
|
||||
block_size = 1;
|
||||
} else if (task_size_per_cu <= threshold_2) {
|
||||
block_size = 2;
|
||||
} else if (task_size_per_cu <= threshold_4) {
|
||||
block_size = 4;
|
||||
} else {
|
||||
block_size = 8;
|
||||
}
|
||||
return block_size;
|
||||
}
|
||||
|
||||
ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
const BHWC& shape, int src_depth,
|
||||
@ -211,15 +280,46 @@ ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
|
||||
if (!device.IsMali()) {
|
||||
return conv_params;
|
||||
}
|
||||
const int width = shape.w * shape.b;
|
||||
if (width % 2 == 0) {
|
||||
conv_params.element_size = 8;
|
||||
bool can_use_flt8 = (shape.w * shape.b) % 2 == 0 &&
|
||||
definition.precision != CalculationsPrecision::F32;
|
||||
bool is_midgard = device.IsMali() && device.GetInfo().mali_info.IsMidgard();
|
||||
if (is_midgard) {
|
||||
if (can_use_flt8) {
|
||||
conv_params.element_size = 8;
|
||||
}
|
||||
if (definition.precision == CalculationsPrecision::F16 || !can_use_flt8) {
|
||||
conv_params.block_size.x = 2;
|
||||
}
|
||||
return conv_params;
|
||||
}
|
||||
if (device.GetInfo().compute_units_count <= 4) {
|
||||
if (definition.precision == CalculationsPrecision::F16) {
|
||||
conv_params.block_size.x *= 2;
|
||||
|
||||
int task_size = shape.w * shape.b * shape.h * dst_depth;
|
||||
int block_size =
|
||||
GetRecommendedBlockSizeForConv(device, definition, task_size);
|
||||
|
||||
if (!can_use_flt8 && block_size > 4) {
|
||||
block_size = 4;
|
||||
}
|
||||
|
||||
if (can_use_flt8 && block_size >= 2) {
|
||||
conv_params.element_size = 8;
|
||||
block_size /= 2;
|
||||
}
|
||||
if (block_size == 4) {
|
||||
conv_params.block_size.x = 2;
|
||||
if (definition.precision == CalculationsPrecision::F32 && dst_depth < 32) {
|
||||
conv_params.block_size.y = 2;
|
||||
} else {
|
||||
conv_params.block_size.z = 2;
|
||||
}
|
||||
} else if (block_size == 2) {
|
||||
if (dst_depth >= 32) {
|
||||
conv_params.block_size.z = 2;
|
||||
} else {
|
||||
conv_params.block_size.x = 2;
|
||||
}
|
||||
}
|
||||
|
||||
return conv_params;
|
||||
}
|
||||
|
||||
|
@ -324,7 +324,9 @@ ConvolutionTransposed::ConvolutionTransposed(
|
||||
}
|
||||
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
|
||||
if (dst_depth == 1 || dst_depth == 3) {
|
||||
block_size_.y *= block_size_.z;
|
||||
if (!device.IsMali()) {
|
||||
block_size_.y *= block_size_.z;
|
||||
}
|
||||
block_size_.z = 1;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user