Improved selection of block size for ConvBuffer1x1

PiperOrigin-RevId: 300460634
Change-Id: I1c5ef87614c84a1f638d9f48225f411dfa79fc96
This commit is contained in:
Raman Sarokin 2020-03-11 19:42:28 -07:00 committed by TensorFlower Gardener
parent dba1d5fa93
commit a4f2960e0e
4 changed files with 124 additions and 10 deletions

View File

@ -297,10 +297,19 @@ bool MaliInfo::IsMidgard() const {
return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
}
bool MaliInfo::IsBifrost() const {
bool MaliInfo::IsBifrostGen1() const {
return gpu_version == MaliGPU::G31 || gpu_version == MaliGPU::G51 ||
gpu_version == MaliGPU::G71 || gpu_version == MaliGPU::G52 ||
gpu_version == MaliGPU::G72 || gpu_version == MaliGPU::G76;
gpu_version == MaliGPU::G71;
}
// Returns true when gpu_version is one of the models this file groups as
// Bifrost generation 2 (G52 or G72).
bool MaliInfo::IsBifrostGen2() const {
  switch (gpu_version) {
    case MaliGPU::G52:
    case MaliGPU::G72:
      return true;
    default:
      return false;
  }
}
// Returns true when gpu_version is the only model this file groups as
// Bifrost generation 3 (G76).
bool MaliInfo::IsBifrostGen3() const {
  const bool is_g76 = (gpu_version == MaliGPU::G76);
  return is_g76;
}
// Returns true when the GPU belongs to any of the Bifrost generation groups
// checked by IsBifrostGen1(), IsBifrostGen2() and IsBifrostGen3().
bool MaliInfo::IsBifrost() const {
  // Checked in generation order; the first match wins.
  if (IsBifrostGen1()) {
    return true;
  }
  if (IsBifrostGen2()) {
    return true;
  }
  return IsBifrostGen3();
}
bool MaliInfo::IsValhall() const {

View File

@ -94,6 +94,9 @@ struct MaliInfo {
bool IsMaliT7xx() const;
bool IsMaliT8xx() const;
bool IsMidgard() const;
bool IsBifrostGen1() const;
bool IsBifrostGen2() const;
bool IsBifrostGen3() const;
bool IsBifrost() const;
bool IsValhall() const;
};

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
#include <array>
#include <cfloat>
#include <string>
#include <utility>
@ -201,6 +202,74 @@ std::string GenerateConvBuffer1x1(
return c;
}
// task_size as amount of FLT4 processed elements.
int GetRecommendedBlockSizeForConv(const CLDevice& device,
const OperationDef& definition,
int task_size) {
const float task_size_per_cu =
task_size / static_cast<float>(device.GetInfo().compute_units_count);
int block_size = 1;
float threshold_1 = FLT_MAX;
float threshold_2 = FLT_MAX;
float threshold_4 = FLT_MAX;
if (!device.IsMali()) {
return 1;
}
MaliInfo mali_info = device.GetInfo().mali_info;
switch (definition.precision) {
case CalculationsPrecision::F16:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 4.0f;
threshold_4 = 256.0f * 8.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 256.0f * 2.0f;
threshold_2 = 256.0f * 8.0f;
threshold_4 = 256.0f * 16.0f;
} else if (mali_info.IsBifrostGen3()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 6.0f;
threshold_4 = 256.0f * 16.0f;
}
break;
case CalculationsPrecision::F32_F16:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 3.0f;
threshold_4 = 256.0f * 32.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 256.0f * 2.0f;
threshold_2 = 256.0f * 8.0f;
} else if (mali_info.IsBifrostGen3()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 8.0f;
}
break;
case CalculationsPrecision::F32:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 4.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 128.0f;
threshold_2 = 256.0f * 4.0f;
} else if (mali_info.IsBifrostGen3()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 12.0f;
}
break;
}
if (task_size_per_cu <= threshold_1) {
block_size = 1;
} else if (task_size_per_cu <= threshold_2) {
block_size = 2;
} else if (task_size_per_cu <= threshold_4) {
block_size = 4;
} else {
block_size = 8;
}
return block_size;
}
ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
const OperationDef& definition,
const BHWC& shape, int src_depth,
@ -211,15 +280,46 @@ ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
if (!device.IsMali()) {
return conv_params;
}
const int width = shape.w * shape.b;
if (width % 2 == 0) {
conv_params.element_size = 8;
bool can_use_flt8 = (shape.w * shape.b) % 2 == 0 &&
definition.precision != CalculationsPrecision::F32;
bool is_midgard = device.IsMali() && device.GetInfo().mali_info.IsMidgard();
if (is_midgard) {
if (can_use_flt8) {
conv_params.element_size = 8;
}
if (definition.precision == CalculationsPrecision::F16 || !can_use_flt8) {
conv_params.block_size.x = 2;
}
return conv_params;
}
if (device.GetInfo().compute_units_count <= 4) {
if (definition.precision == CalculationsPrecision::F16) {
conv_params.block_size.x *= 2;
int task_size = shape.w * shape.b * shape.h * dst_depth;
int block_size =
GetRecommendedBlockSizeForConv(device, definition, task_size);
if (!can_use_flt8 && block_size > 4) {
block_size = 4;
}
if (can_use_flt8 && block_size >= 2) {
conv_params.element_size = 8;
block_size /= 2;
}
if (block_size == 4) {
conv_params.block_size.x = 2;
if (definition.precision == CalculationsPrecision::F32 && dst_depth < 32) {
conv_params.block_size.y = 2;
} else {
conv_params.block_size.z = 2;
}
} else if (block_size == 2) {
if (dst_depth >= 32) {
conv_params.block_size.z = 2;
} else {
conv_params.block_size.x = 2;
}
}
return conv_params;
}

View File

@ -324,7 +324,9 @@ ConvolutionTransposed::ConvolutionTransposed(
}
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
if (dst_depth == 1 || dst_depth == 3) {
block_size_.y *= block_size_.z;
if (!device.IsMali()) {
block_size_.y *= block_size_.z;
}
block_size_.z = 1;
}
}