Enabling non-trivial fast tuning for all vendors.

PiperOrigin-RevId: 297142417
Change-Id: I948e0ae458fbf59841d83d3c087c35c286daf269
This commit is contained in:
Raman Sarokin 2020-02-25 10:05:52 -08:00 committed by TensorFlower Gardener
parent 7c0ab406e3
commit b5f48a2ce0

View File

@@ -248,13 +248,8 @@ Status GetBestWorkGroup(const TuningParameters& params, const CLKernel& kernel,
const int3& grid, int3* best_work_group) {
switch (params.tuning_type) {
case TuningType::FAST:
if (params.info->vendor != Vendor::QUALCOMM) {
*best_work_group = int3(8, 4, 1);
return OkStatus();
} else {
*best_work_group = GetWorkGroup(grid, kernel.GetMaxWorkGroupSize());
return OkStatus();
}
*best_work_group = GetWorkGroup(grid, kernel.GetMaxWorkGroupSize());
return OkStatus();
case TuningType::EXHAUSTIVE:
return GetBestWorkGroupAlignedToGrid(params, kernel, grid,
best_work_group);
@@ -268,16 +263,16 @@ Status GetBestWorkGroupConv(const TuningParameters& params,
const CLKernel& kernel, const int3& grid,
int3* best_work_group) {
switch (params.tuning_type) {
case TuningType::FAST:
if (params.info->vendor != Vendor::QUALCOMM) {
*best_work_group = int3(8, 4, 1);
return OkStatus();
} else {
int max_z_size = params.info->adreno_info.gpu_version < 400 ? 16 : 64;
*best_work_group =
GetWorkGroupConv(grid, kernel.GetMaxWorkGroupSize(), max_z_size);
return OkStatus();
case TuningType::FAST: {
int max_z_size = 16;
if (params.info->vendor == Vendor::QUALCOMM) {
max_z_size = params.info->adreno_info.gpu_version < 400 ? 16 : 64;
}
max_z_size = std::min(max_z_size, params.info->max_work_group_sizes.z);
*best_work_group =
GetWorkGroupConv(grid, kernel.GetMaxWorkGroupSize(), max_z_size);
return OkStatus();
}
case TuningType::EXHAUSTIVE:
return GetBestWorkGroupAlignedToGrid(params, kernel, grid,
best_work_group);