Improved selection of block size for ConvBuffer1x1
PiperOrigin-RevId: 300460634 Change-Id: I1c5ef87614c84a1f638d9f48225f411dfa79fc96
This commit is contained in:
parent
dba1d5fa93
commit
a4f2960e0e
@ -297,10 +297,19 @@ bool MaliInfo::IsMidgard() const {
|
|||||||
return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
|
return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MaliInfo::IsBifrost() const {
|
bool MaliInfo::IsBifrostGen1() const {
|
||||||
return gpu_version == MaliGPU::G31 || gpu_version == MaliGPU::G51 ||
|
return gpu_version == MaliGPU::G31 || gpu_version == MaliGPU::G51 ||
|
||||||
gpu_version == MaliGPU::G71 || gpu_version == MaliGPU::G52 ||
|
gpu_version == MaliGPU::G71;
|
||||||
gpu_version == MaliGPU::G72 || gpu_version == MaliGPU::G76;
|
}
|
||||||
|
|
||||||
|
bool MaliInfo::IsBifrostGen2() const {
|
||||||
|
return gpu_version == MaliGPU::G52 || gpu_version == MaliGPU::G72;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGPU::G76; }
|
||||||
|
|
||||||
|
bool MaliInfo::IsBifrost() const {
|
||||||
|
return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MaliInfo::IsValhall() const {
|
bool MaliInfo::IsValhall() const {
|
||||||
|
@ -94,6 +94,9 @@ struct MaliInfo {
|
|||||||
bool IsMaliT7xx() const;
|
bool IsMaliT7xx() const;
|
||||||
bool IsMaliT8xx() const;
|
bool IsMaliT8xx() const;
|
||||||
bool IsMidgard() const;
|
bool IsMidgard() const;
|
||||||
|
bool IsBifrostGen1() const;
|
||||||
|
bool IsBifrostGen2() const;
|
||||||
|
bool IsBifrostGen3() const;
|
||||||
bool IsBifrost() const;
|
bool IsBifrost() const;
|
||||||
bool IsValhall() const;
|
bool IsValhall() const;
|
||||||
};
|
};
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <cfloat>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
@ -201,6 +202,74 @@ std::string GenerateConvBuffer1x1(
|
|||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// task_size as amount of FLT4 processed elements.
|
||||||
|
int GetRecommendedBlockSizeForConv(const CLDevice& device,
|
||||||
|
const OperationDef& definition,
|
||||||
|
int task_size) {
|
||||||
|
const float task_size_per_cu =
|
||||||
|
task_size / static_cast<float>(device.GetInfo().compute_units_count);
|
||||||
|
int block_size = 1;
|
||||||
|
float threshold_1 = FLT_MAX;
|
||||||
|
float threshold_2 = FLT_MAX;
|
||||||
|
float threshold_4 = FLT_MAX;
|
||||||
|
if (!device.IsMali()) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
MaliInfo mali_info = device.GetInfo().mali_info;
|
||||||
|
switch (definition.precision) {
|
||||||
|
case CalculationsPrecision::F16:
|
||||||
|
if (mali_info.IsBifrostGen1()) {
|
||||||
|
threshold_1 = 256.0f;
|
||||||
|
threshold_2 = 256.0f * 4.0f;
|
||||||
|
threshold_4 = 256.0f * 8.0f;
|
||||||
|
} else if (mali_info.IsBifrostGen2()) {
|
||||||
|
threshold_1 = 256.0f * 2.0f;
|
||||||
|
threshold_2 = 256.0f * 8.0f;
|
||||||
|
threshold_4 = 256.0f * 16.0f;
|
||||||
|
} else if (mali_info.IsBifrostGen3()) {
|
||||||
|
threshold_1 = 256.0f;
|
||||||
|
threshold_2 = 256.0f * 6.0f;
|
||||||
|
threshold_4 = 256.0f * 16.0f;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case CalculationsPrecision::F32_F16:
|
||||||
|
if (mali_info.IsBifrostGen1()) {
|
||||||
|
threshold_1 = 256.0f;
|
||||||
|
threshold_2 = 256.0f * 3.0f;
|
||||||
|
threshold_4 = 256.0f * 32.0f;
|
||||||
|
} else if (mali_info.IsBifrostGen2()) {
|
||||||
|
threshold_1 = 256.0f * 2.0f;
|
||||||
|
threshold_2 = 256.0f * 8.0f;
|
||||||
|
} else if (mali_info.IsBifrostGen3()) {
|
||||||
|
threshold_1 = 256.0f;
|
||||||
|
threshold_2 = 256.0f * 8.0f;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case CalculationsPrecision::F32:
|
||||||
|
if (mali_info.IsBifrostGen1()) {
|
||||||
|
threshold_1 = 256.0f;
|
||||||
|
threshold_2 = 256.0f * 4.0f;
|
||||||
|
} else if (mali_info.IsBifrostGen2()) {
|
||||||
|
threshold_1 = 128.0f;
|
||||||
|
threshold_2 = 256.0f * 4.0f;
|
||||||
|
} else if (mali_info.IsBifrostGen3()) {
|
||||||
|
threshold_1 = 256.0f;
|
||||||
|
threshold_2 = 256.0f * 12.0f;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (task_size_per_cu <= threshold_1) {
|
||||||
|
block_size = 1;
|
||||||
|
} else if (task_size_per_cu <= threshold_2) {
|
||||||
|
block_size = 2;
|
||||||
|
} else if (task_size_per_cu <= threshold_4) {
|
||||||
|
block_size = 4;
|
||||||
|
} else {
|
||||||
|
block_size = 8;
|
||||||
|
}
|
||||||
|
return block_size;
|
||||||
|
}
|
||||||
|
|
||||||
ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
|
ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const BHWC& shape, int src_depth,
|
const BHWC& shape, int src_depth,
|
||||||
@ -211,15 +280,46 @@ ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
|
|||||||
if (!device.IsMali()) {
|
if (!device.IsMali()) {
|
||||||
return conv_params;
|
return conv_params;
|
||||||
}
|
}
|
||||||
const int width = shape.w * shape.b;
|
bool can_use_flt8 = (shape.w * shape.b) % 2 == 0 &&
|
||||||
if (width % 2 == 0) {
|
definition.precision != CalculationsPrecision::F32;
|
||||||
conv_params.element_size = 8;
|
bool is_midgard = device.IsMali() && device.GetInfo().mali_info.IsMidgard();
|
||||||
|
if (is_midgard) {
|
||||||
|
if (can_use_flt8) {
|
||||||
|
conv_params.element_size = 8;
|
||||||
|
}
|
||||||
|
if (definition.precision == CalculationsPrecision::F16 || !can_use_flt8) {
|
||||||
|
conv_params.block_size.x = 2;
|
||||||
|
}
|
||||||
|
return conv_params;
|
||||||
}
|
}
|
||||||
if (device.GetInfo().compute_units_count <= 4) {
|
|
||||||
if (definition.precision == CalculationsPrecision::F16) {
|
int task_size = shape.w * shape.b * shape.h * dst_depth;
|
||||||
conv_params.block_size.x *= 2;
|
int block_size =
|
||||||
|
GetRecommendedBlockSizeForConv(device, definition, task_size);
|
||||||
|
|
||||||
|
if (!can_use_flt8 && block_size > 4) {
|
||||||
|
block_size = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (can_use_flt8 && block_size >= 2) {
|
||||||
|
conv_params.element_size = 8;
|
||||||
|
block_size /= 2;
|
||||||
|
}
|
||||||
|
if (block_size == 4) {
|
||||||
|
conv_params.block_size.x = 2;
|
||||||
|
if (definition.precision == CalculationsPrecision::F32 && dst_depth < 32) {
|
||||||
|
conv_params.block_size.y = 2;
|
||||||
|
} else {
|
||||||
|
conv_params.block_size.z = 2;
|
||||||
|
}
|
||||||
|
} else if (block_size == 2) {
|
||||||
|
if (dst_depth >= 32) {
|
||||||
|
conv_params.block_size.z = 2;
|
||||||
|
} else {
|
||||||
|
conv_params.block_size.x = 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return conv_params;
|
return conv_params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -324,7 +324,9 @@ ConvolutionTransposed::ConvolutionTransposed(
|
|||||||
}
|
}
|
||||||
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
|
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
|
||||||
if (dst_depth == 1 || dst_depth == 3) {
|
if (dst_depth == 1 || dst_depth == 3) {
|
||||||
block_size_.y *= block_size_.z;
|
if (!device.IsMali()) {
|
||||||
|
block_size_.y *= block_size_.z;
|
||||||
|
}
|
||||||
block_size_.z = 1;
|
block_size_.z = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user