TuningType moved to gpu/common/task.

PiperOrigin-RevId: 342770347
Change-Id: If0fd6174fb650002e9262fe13cb5a409989a9593
Raman Sarokin 2020-11-16 19:08:01 -08:00 committed by TensorFlower Gardener
parent f0c976a27f
commit faad89dadf
15 changed files with 47 additions and 59 deletions
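
As a minimal self-contained C++ sketch of what this change means at call sites: the CL-specific TuningParameters struct goes away, and TuningType (renamed EXHAUSTIVE/FAST -> kExhaustive/kFast and moved to gpu/common/task) is now passed to Tune() explicitly together with the GpuInfo and the profiling queue. The types below (GpuInfo, ProfilingCommandQueue, the is_mali_t6xx flag) are simplified stand-ins for illustration, not the real TFLite GPU classes.

#include <iostream>

// Illustrative stand-ins only; the real definitions live in the TFLite GPU delegate.
enum class TuningType { kExhaustive, kFast };  // previously EXHAUSTIVE / FAST
struct GpuInfo { bool is_mali_t6xx = false; };
struct ProfilingCommandQueue {};

// Old shape (removed by this commit): everything bundled in one struct.
//   struct TuningParameters {
//     ProfilingCommandQueue* queue;
//     const GpuInfo* info;
//     TuningType tuning_type = TuningType::EXHAUSTIVE;
//   };
//   absl::Status Tune(const TuningParameters& params);

// New shape: explicit arguments, no CL-specific struct.
void Tune(TuningType tuning_type, const GpuInfo& gpu_info,
          ProfilingCommandQueue* profiling_queue) {
  (void)gpu_info;
  (void)profiling_queue;  // the real code times candidate work groups on this queue
  std::cout << (tuning_type == TuningType::kFast ? "fast" : "exhaustive")
            << " tuning\n";
}

int main() {
  GpuInfo gpu_info;
  ProfilingCommandQueue queue;
  // Mirrors InferenceContext::InitFromGraph below: default to exhaustive tuning,
  // drop to fast tuning on Mali T6xx, where the profiling queue hangs in clFinish.
  TuningType tuning_type = TuningType::kExhaustive;
  if (gpu_info.is_mali_t6xx) tuning_type = TuningType::kFast;
  Tune(tuning_type, gpu_info, &queue);
  return 0;
}

Because the enum no longer pulls in cl_command_queue or device_info, the new //tensorflow/lite/delegates/gpu/common/task:tuning_type target in the BUILD hunks below is header-only with no deps.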


@@ -114,10 +114,11 @@ absl::Status ClOperation::CompileDeserialized(
*creation_context.context, *creation_context.device, &kernel_);
}
absl::Status ClOperation::Tune(const TuningParameters& params) {
absl::Status ClOperation::Tune(TuningType tuning_type, const GpuInfo& gpu_info,
ProfilingCommandQueue* profiling_queue) {
std::vector<int3> possible_work_groups;
operation_->GetPossibleKernelWorkGroups(params.tuning_type, *params.info,
kernel_.info_, &possible_work_groups);
operation_->GetPossibleKernelWorkGroups(tuning_type, gpu_info, kernel_.info_,
&possible_work_groups);
if (possible_work_groups.empty()) {
return absl::NotFoundError(
"Can not found work_group size to launch kernel");
@@ -137,8 +138,8 @@ absl::Status ClOperation::Tune(const TuningParameters& params) {
}
RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
int best_work_group_index;
RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex(
kernel_, *params.info, work_groups_count, possible_work_groups,
RETURN_IF_ERROR(profiling_queue->GetBestWorkGroupIndex(
kernel_, gpu_info, work_groups_count, possible_work_groups,
&best_work_group_index));
operation_->work_group_size_ = possible_work_groups[best_work_group_index];
operation_->work_groups_count_ = GetWorkGroupsCount(


@@ -71,7 +71,8 @@ class ClOperation {
operation_->work_group_size_);
}
absl::Status Tune(const TuningParameters& params);
absl::Status Tune(TuningType tuning_type, const GpuInfo& gpu_info,
ProfilingCommandQueue* profiling_queue);
absl::Status Compile(const CreationContext& creation_context);


@@ -183,21 +183,20 @@ absl::Status InferenceContext::InitFromGraph(
RETURN_IF_ERROR(Compile(creation_context));
RETURN_IF_ERROR(UpdateParams());
TuningParameters tuning_parameters;
tuning_parameters.queue = env->profiling_queue();
tuning_parameters.info = &env->device().info_;
TuningType tuning_type = TuningType::kExhaustive;
if (create_info.hints.Check(ModelHints::kFastTuning)) {
tuning_parameters.tuning_type = TuningType::FAST;
tuning_type = TuningType::kFast;
}
if (tuning_parameters.info->IsMali()) {
const MaliInfo& info = tuning_parameters.info->mali_info;
if (env->device().GetInfo().IsMali()) {
const MaliInfo& info = env->device().GetInfo().mali_info;
if (info.IsMaliT6xx()) {
// Mali T628 hangs forever in clFinish when used profiling queue
// TuningType::FAST does not use profiling queue.
tuning_parameters.tuning_type = TuningType::FAST;
tuning_type = TuningType::kFast;
}
}
RETURN_IF_ERROR(Tune(tuning_parameters));
RETURN_IF_ERROR(
Tune(tuning_type, env->device().GetInfo(), env->profiling_queue()));
if (serialized_model) {
for (auto& node : nodes_) {
@@ -631,9 +630,12 @@ absl::Status InferenceContext::Compile(
return absl::OkStatus();
}
absl::Status InferenceContext::Tune(const TuningParameters& tuning_parameters) {
absl::Status InferenceContext::Tune(TuningType tuning_type,
const GpuInfo& gpu_info,
ProfilingCommandQueue* profiling_queue) {
for (auto& node : nodes_) {
RETURN_IF_ERROR(node.cl_operation.Tune(tuning_parameters));
RETURN_IF_ERROR(
node.cl_operation.Tune(tuning_type, gpu_info, profiling_queue));
}
return absl::OkStatus();
}


@@ -133,7 +133,8 @@ class InferenceContext {
void BindMemoryToOperations();
absl::Status Compile(const CreationContext& creation_context);
absl::Status Tune(const TuningParameters& tuning_parameters);
absl::Status Tune(TuningType tuning_type, const GpuInfo& gpu_info,
ProfilingCommandQueue* profiling_queue);
absl::Status UpdateParams();
// performance hacks


@@ -580,7 +580,6 @@ cc_library(
hdrs = ["fully_connected.h"],
deps = [
":gpu_operation",
":tuning_parameters",
":util",
"//tensorflow/lite/delegates/gpu/cl:buffer",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
@@ -624,9 +623,9 @@ cc_library(
srcs = ["gpu_operation.cc"],
hdrs = ["gpu_operation.h"],
deps = [
":tuning_parameters",
":util",
":work_group_picking",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/cl:device_info",
"//tensorflow/lite/delegates/gpu/cl:serialization_cc_fbs",
"//tensorflow/lite/delegates/gpu/common:access_type",
@@ -638,6 +637,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
"//tensorflow/lite/delegates/gpu/common/task:gpu_tensor",
"//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
"//tensorflow/lite/delegates/gpu/common/task:tuning_type",
"@com_google_absl//absl/strings",
],
)
@@ -1204,15 +1204,6 @@ cc_test(
],
)
cc_library(
name = "tuning_parameters",
hdrs = ["tuning_parameters.h"],
deps = [
"//tensorflow/lite/delegates/gpu/cl:cl_command_queue",
"//tensorflow/lite/delegates/gpu/cl:device_info",
],
)
cc_library(
name = "resize",
srcs = ["resize.cc"],
@@ -1306,12 +1297,12 @@ cc_library(
srcs = ["work_group_picking.cc"],
hdrs = ["work_group_picking.h"],
deps = [
":tuning_parameters",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/cl:device_info",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common:workgroup_selection",
"//tensorflow/lite/delegates/gpu/common/task:tuning_type",
],
)


@@ -54,7 +54,7 @@ class OpenClConverterImpl : public TensorObjectConverter {
const int3 grid = int3(tensor->Width() * tensor->Batch(), tensor->Height(),
tensor->Slices());
std::vector<int3> work_groups;
GetPossibleWorkGroupsConv(TuningType::FAST, gpu_info_, kernel_.info_, grid,
GetPossibleWorkGroupsConv(TuningType::kFast, gpu_info_, kernel_.info_, grid,
&work_groups);
const int3 work_group_size = work_groups[0];
const int3 work_groups_count = GetWorkGroupsCount(grid, work_group_size);


@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h"
#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"


@@ -19,8 +19,8 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h"
#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {


@@ -36,7 +36,6 @@ cc_library(
"//tensorflow/lite/delegates/gpu/cl:tensor",
"//tensorflow/lite/delegates/gpu/cl:texture2d",
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
"//tensorflow/lite/delegates/gpu/cl/kernels:tuning_parameters",
"//tensorflow/lite/delegates/gpu/cl/kernels:util",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:operations",


@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h"
#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"


@@ -285,11 +285,11 @@ void Winograd4x4To36::GetPossibleKernelWorkGroups(
TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
switch (tuning_type) {
case TuningType::EXHAUSTIVE:
case TuningType::kExhaustive:
GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
work_groups);
return;
case TuningType::FAST:
case TuningType::kFast:
default:
work_groups->push_back(SelectBestWorkGroup(kernel_info));
return;
@@ -481,11 +481,11 @@ void Winograd36To4x4::GetPossibleKernelWorkGroups(
TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
switch (tuning_type) {
case TuningType::EXHAUSTIVE:
case TuningType::kExhaustive:
GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
work_groups);
return;
case TuningType::FAST:
case TuningType::kFast:
default:
work_groups->push_back(SelectBestWorkGroup(kernel_info));
return;


@@ -252,11 +252,11 @@ void GetPossibleWorkGroups(TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info, const int3& grid,
std::vector<int3>* work_groups) {
switch (tuning_type) {
case TuningType::FAST:
case TuningType::kFast:
work_groups->push_back(
GetWorkGroup(grid, kernel_info.max_work_group_size));
return;
case TuningType::EXHAUSTIVE: {
case TuningType::kExhaustive: {
GetWorkGroupsAlignedToGrid(gpu_info, kernel_info, grid, work_groups);
return;
}
@@ -270,7 +270,7 @@ void GetPossibleWorkGroupsConv(TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info, const int3& grid,
std::vector<int3>* work_groups) {
switch (tuning_type) {
case TuningType::FAST: {
case TuningType::kFast: {
int max_z_size = 16;
if (gpu_info.IsAdreno()) {
max_z_size = gpu_info.adreno_info.IsAdreno3xx() ? 16 : 64;
@@ -280,7 +280,7 @@ void GetPossibleWorkGroupsConv(TuningType tuning_type, const GpuInfo& gpu_info,
GetWorkGroupConv(grid, kernel_info.max_work_group_size, max_z_size));
return;
}
case TuningType::EXHAUSTIVE: {
case TuningType::kExhaustive: {
GetWorkGroupsAlignedToGrid(gpu_info, kernel_info, grid, work_groups);
return;
}


@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h"
#include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/workgroup_selection.h"


@@ -100,6 +100,11 @@ cc_library(
],
)
cc_library(
name = "tuning_type",
hdrs = ["tuning_type.h"],
)
cc_library(
name = "util",
srcs = ["util.cc"],


@@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,26 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TUNING_PARAMETERS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TUNING_PARAMETERS_H_
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TUNING_TYPE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TUNING_TYPE_H_
namespace tflite {
namespace gpu {
namespace cl {
enum class TuningType { EXHAUSTIVE, FAST };
enum class TuningType { kExhaustive, kFast };
struct TuningParameters {
ProfilingCommandQueue* queue;
const GpuInfo* info;
TuningType tuning_type = TuningType::EXHAUSTIVE;
};
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TUNING_PARAMETERS_H_
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TUNING_TYPE_H_