From 1912ef16d67af82aff8a18a44cd555a919145046 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Wed, 22 Apr 2020 04:10:11 -0700 Subject: [PATCH] Add an option in GPU delegate to parameterize the #partitions to delegate. The default value of this parameter is 1, same w/ the current behavior. PiperOrigin-RevId: 307788075 Change-Id: I26bb65fcf049e82cc46e88f818b07ae245fbb1cc --- .../delegates/gpu/common/model_builder.cc | 15 +- .../lite/delegates/gpu/common/model_builder.h | 6 +- .../gpu/common/model_builder_test.cc | 181 ++++++++++++++++++ tensorflow/lite/delegates/gpu/delegate.cc | 88 +++++---- tensorflow/lite/delegates/gpu/delegate.h | 5 + .../tools/delegates/gpu_delegate_provider.cc | 2 + 6 files changed, 254 insertions(+), 43 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 00193bb0a68..19c0e59011f 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -2602,7 +2602,8 @@ bool IsAllAllowedTensors(TfLiteContext* context, const TfLiteIntArray* array, // TODO(impjdi): Check number of input/output tensors and their dimensions. // TODO(impjdi): Check ops' parameters. -TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops) { +TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops, + int max_delegated_partitions) { delegates::IsNodeSupportedFn node_supported_fn = [=](TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration, @@ -2633,11 +2634,11 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops) { return TfLiteIntArrayCreate(0); } - // We simply get 1st largest partition, but we could later explore whether - // getting more partitions could lead to better performance, i.e. by - // parameterizing '1' here. + // By default, we simply get 1st largest partition as 'max_delegate_partions' + // is set to 1 by default. std::vector ops_to_replace = - partition_helper.GetNodesOfFirstNLargestPartitions(1); + partition_helper.GetNodesOfFirstNLargestPartitions( + max_delegated_partitions); if (!unsupported_nodes_info.empty()) { std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n"); @@ -2647,9 +2648,7 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops) { if (!ops_to_replace.empty()) { absl::StrAppend( &error_message, ops_to_replace.size(), - " operations will run on the GPU (first node: ", - ops_to_replace.front(), ", last node: ", ops_to_replace.back(), - "), and the remaining ", + " operations will run on the GPU, and the remaining ", partition_helper.num_total_nodes() - ops_to_replace.size()); } else { absl::StrAppend(&error_message, diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h index 4b2a2f51db3..1e5016d86b6 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -29,8 +29,12 @@ namespace gpu { // Validates which operations are supported and returns array of operations to // replace with GPU kernels. The caller must free the pointer on TfLiteIntArray. +// 'max_delegated_partitions' limits the maximum number of partitions to +// delegate as a graph could possibly have multiple partitions (each partition +// consists of a subset of ops) to be replaced. 
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, - bool allow_quant_ops = false); + bool allow_quant_ops = false, + int max_delegated_partitions = 1); // Extracts TFLite delegate execution plan from the input TFLite context and // converts it into generic graph format. diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc index 7b12f46453d..f0525e5e2c9 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc @@ -502,6 +502,187 @@ TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) { TfLiteIntArrayFree(ops_to_replace); } +class Interpreter2Fp32 : public DelegatedInterpreter { + public: + Interpreter2Fp32() : DelegatedInterpreter(4) { + void* builtin_data = malloc(sizeof(int)); + EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk); + EXPECT_EQ(interpreter_.SetInputs({0, 2, 4, 6}), kTfLiteOk); + EXPECT_EQ(interpreter_.SetOutputs({7}), kTfLiteOk); + + // Add a Dequantize Node with uint8 input. + const TfLiteRegistration reg_dequant = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/nullptr, + /*profiling_string=*/nullptr, + kTfLiteBuiltinDequantize}; + EXPECT_EQ(interpreter_.AddNodeWithParameters( + /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr, + /*init_data_size=*/0, /*builtin_data=*/nullptr, + /*registration=*/®_dequant), + kTfLiteOk); + + // Add an ADD node that GPU delegate can parse. + const TfLiteRegistration reg_add0 = { + [](TfLiteContext* context, const char* buffer, size_t length) { + return reinterpret_cast(new int(1)); + }, + [](TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); + }, + nullptr, + nullptr, + nullptr, + kTfLiteBuiltinAdd}; + EXPECT_EQ(interpreter_.AddNodeWithParameters( + /*inputs=*/{1, 2}, /*outputs=*/{3}, /*init_data=*/nullptr, + /*init_data_size=*/0, + /*builtin_data=*/builtin_data, + /*registration=*/®_add0), + kTfLiteOk); + + // Add a Pack Node that GPU delegate doesn't support + const TfLiteRegistration reg_pack = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/nullptr, + /*profiling_string=*/nullptr, + kTfLiteBuiltinPack}; + EXPECT_EQ(interpreter_.AddNodeWithParameters( + /*inputs=*/{3, 4}, /*outputs=*/{5}, /*init_data=*/nullptr, + /*init_data_size=*/0, /*builtin_data=*/nullptr, + /*registration=*/®_pack), + kTfLiteOk); + + const TfLiteRegistration reg_add1 = { + [](TfLiteContext* context, const char* buffer, size_t length) { + return reinterpret_cast(new int[2]); + }, + [](TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); + }, + nullptr, + nullptr, + nullptr, + kTfLiteBuiltinAdd}; + EXPECT_EQ(interpreter_.AddNodeWithParameters( + /*inputs=*/{5, 6}, /*outputs=*/{7}, /*init_data=*/nullptr, + /*init_data_size=*/0, + /*builtin_data=*/builtin_data, + /*registration=*/®_add1), + kTfLiteOk); + + std::vector dims = {1}; + TfLiteQuantization quantization; + quantization.type = kTfLiteNoQuantization; + EXPECT_EQ(interpreter_.SetTensorParametersReadWrite( + 0, TfLiteType::kTfLiteUInt8, "t0", dims, quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 2, TfLiteType::kTfLiteFloat32, "t2", dims, quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 3, TfLiteType::kTfLiteFloat32, 
"t3", dims, quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 4, TfLiteType::kTfLiteFloat32, "t4", dims, quantization, false), + kTfLiteOk); + + dims.push_back(2); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 5, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 6, TfLiteType::kTfLiteFloat32, "t6", dims, quantization, false), + kTfLiteOk); + + exec_plan()->data[0] = 0; + exec_plan()->data[1] = 1; + exec_plan()->data[2] = 2; + exec_plan()->data[3] = 3; + } +}; + +Interpreter2Fp32* interpreter2_fp32 = new Interpreter2Fp32(); + +TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) { + // A graph with a Dequant node with uint8 input, a Pack node are not pruned. + // As these ops are currently not supported on the GPU, they will be scheduled + // to run on the CPU while the remaining supported op Add on the GPU. + // + // t0 (uint8) -> Dequant(0) -> t1 (FP32) -> Add(1) -> t3 (FP32) -> PACK (2) + // t2 (FP32) -/ t4 (FP32) -/ + // PACK (2) -> t5 (FP32) -> Add(3) -> t7 + // -> t6 (FP32) -/ + // + TfLiteContext* context = interpreter2_fp32->context(); + + // These functions are meant to be called inside delegates. Swap out + // for similar functions to permit direct calling of GetOpsToReplace. + context->GetExecutionPlan = [](struct TfLiteContext* context, + TfLiteIntArray** execution_plan) { + *execution_plan = interpreter2_fp32->exec_plan(); + return kTfLiteOk; + }; + context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index, + TfLiteNode** node, + TfLiteRegistration** registration) { + auto& node_and_reg = + interpreter2_fp32->nodes_and_registration()[node_index]; + *node = &node_and_reg.first; + *registration = &node_and_reg.second; + return kTfLiteOk; + }; + context->PreviewDelegatePartitioning = + [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, + TfLiteDelegateParams** partition_params_array, int* num_partitions) { + auto params = interpreter2_fp32->add_delegate_params(); + params->nodes_to_replace = TfLiteIntArrayCreate(1); + params->nodes_to_replace->data[0] = 1; + params->input_tensors = TfLiteIntArrayCreate(2); + params->input_tensors->data[0] = 1; + params->input_tensors->data[1] = 2; + params->output_tensors = TfLiteIntArrayCreate(1); + params->output_tensors->data[0] = 3; + + params = interpreter2_fp32->add_delegate_params(); + params->nodes_to_replace = TfLiteIntArrayCreate(1); + params->nodes_to_replace->data[0] = 3; + params->input_tensors = TfLiteIntArrayCreate(2); + params->input_tensors->data[0] = 5; + params->input_tensors->data[1] = 6; + params->output_tensors = TfLiteIntArrayCreate(1); + params->output_tensors->data[0] = 7; + + *partition_params_array = interpreter2_fp32->delegate_params(); + *num_partitions = interpreter2_fp32->num_delegate_params(); + return kTfLiteOk; + }; + + TfLiteIntArray* ops_to_replace = GetOpsToReplace( + context, /*allow_quant_ops=*/false, /*max_delegated_partitions*/ 2); + + // As the Dequant op is not pruned and the ADD op could run on GPU, we have + // 2 partitions. + EXPECT_EQ(ops_to_replace->size, 2); + // ADD at index 1. + EXPECT_EQ(1, ops_to_replace->data[0]); + // ADD at index 3. 
+ EXPECT_EQ(3, ops_to_replace->data[1]); + + TfLiteIntArrayFree(ops_to_replace); +} + class InterpreterMultiNode : public DelegatedInterpreter { public: explicit InterpreterMultiNode(bool add_op_first = true) diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index d58c03e8877..58da8862937 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -70,17 +70,25 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate); class Delegate { public: - explicit Delegate(const TfLiteGpuDelegateOptionsV2* options) { + explicit Delegate(const TfLiteGpuDelegateOptionsV2* options) + : num_delegate_kernels_(0) { options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default(); + if (options_.max_delegated_partitions <= 0) { + options_.max_delegated_partitions = 1; + } } TfLiteDelegate* tflite_delegate() { return &delegate_; } const TfLiteGpuDelegateOptionsV2& options() const { return options_; } - bool IsQuantOpsAllowed() { + bool IsQuantOpsAllowed() const { return options_.experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT; } + int MaxDelegatedPartitions() const { + return options_.max_delegated_partitions; + } + int num_delegate_kernels() const { return num_delegate_kernels_; } private: TfLiteDelegate delegate_ = { @@ -93,13 +101,18 @@ class Delegate { }; TfLiteGpuDelegateOptionsV2 options_; + int num_delegate_kernels_ = 0; + + friend class DelegateKernel; }; // Represent the execution of a subset of nodes on GPU. class DelegateKernel { public: - explicit DelegateKernel(const TfLiteGpuDelegateOptionsV2& options) - : options_(options) {} + explicit DelegateKernel(Delegate* delegate) : delegate_(delegate) { + ++delegate_->num_delegate_kernels_; + } + ~DelegateKernel() { --delegate_->num_delegate_kernels_; } absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { @@ -115,11 +128,11 @@ class DelegateKernel { std::unique_ptr builder; bool graph_is_destroyed; - if (options_.experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_CL_ONLY) { + const int experimental_flags = delegate_->options().experimental_flags; + if (experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_CL_ONLY) { RETURN_IF_ERROR( InitializeOpenClApi(&graph, &builder, &graph_is_destroyed)); - } else if (options_.experimental_flags & - TFLITE_GPU_EXPERIMENTAL_FLAGS_GL_ONLY) { + } else if (experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_GL_ONLY) { RETURN_IF_ERROR(InitializeOpenGlApi(&graph, &builder)); } else { // By default, we try CL first & fall back to GL if that fails. @@ -241,8 +254,7 @@ class DelegateKernel { std::vector* input_refs, std::vector* output_refs) { quant_conversion_map_.clear(); - if (options_.experimental_flags & - TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT) { + if (delegate_->IsQuantOpsAllowed()) { RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph, &quant_conversion_map_)); } else { @@ -337,22 +349,23 @@ class DelegateKernel { cl::InferenceEnvironmentProperties properties; RETURN_IF_ERROR(cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties)); + auto delegate_options = delegate_->options(); cl::InferenceOptions options; // If is_precision_loss_allowed == -1, then just use priorities instead // of paying attention to is_precision_loss_allowed value. 
- if (options_.is_precision_loss_allowed == -1) { - options.priority1 = ToPriority(options_.inference_priority1); - options.priority2 = ToPriority(options_.inference_priority2); - options.priority3 = ToPriority(options_.inference_priority3); + if (delegate_options.is_precision_loss_allowed == -1) { + options.priority1 = ToPriority(delegate_options.inference_priority1); + options.priority2 = ToPriority(delegate_options.inference_priority2); + options.priority3 = ToPriority(delegate_options.inference_priority3); } else { // Users set is_precision_loss_allowed explicitly, thus use it explicitly. - if (options_.is_precision_loss_allowed == 0) { + if (delegate_options.is_precision_loss_allowed == 0) { options.priority1 = InferencePriority::MAX_PRECISION; } else { options.priority1 = InferencePriority::MIN_LATENCY; } } - options.usage = ToUsage(options_.inference_preference); + options.usage = ToUsage(delegate_options.inference_preference); *graph_is_destroyed = true; RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder( options, std::move(*graph), builder)); @@ -367,11 +380,12 @@ class DelegateKernel { gl::InferenceEnvironmentProperties properties; RETURN_IF_ERROR( NewInferenceEnvironment(env_options, &gl_environment_, &properties)); + auto delegate_options = delegate_->options(); gl::InferenceOptions options; - options.usage = ToUsage(options_.inference_preference); - options.priority1 = ToPriority(options_.inference_priority1); - options.priority2 = ToPriority(options_.inference_priority2); - options.priority3 = ToPriority(options_.inference_priority3); + options.usage = ToUsage(delegate_options.inference_preference); + options.priority1 = ToPriority(delegate_options.inference_priority1); + options.priority2 = ToPriority(delegate_options.inference_priority2); + options.priority3 = ToPriority(delegate_options.inference_priority3); RETURN_IF_ERROR(gl_environment_->NewInferenceBuilder(std::move(*graph), options, builder)); enforce_same_thread_ = true; @@ -380,9 +394,8 @@ class DelegateKernel { return absl::OkStatus(); } - // Shared across all DelegateKernel instances, passed by the Delegate - // instance. - const TfLiteGpuDelegateOptionsV2& options_; + // The Delegate instance that's shared across all DelegateKernel instances. + Delegate* const delegate_; // doesn't own the memory. std::unique_ptr cl_environment_; std::unique_ptr gl_environment_; std::unique_ptr runner_; @@ -414,7 +427,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { // Everything below should happen in prepare function call, but TFLite // for whatever reason forbids that. 
auto gpu_delegate_kernel = - absl::make_unique(gpu_delegate->options()); + absl::make_unique(gpu_delegate); const auto status = gpu_delegate_kernel->Prepare(context, params); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Init: %s", @@ -463,10 +476,15 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { "TfLiteGpuDelegateV2", // .custom_name 1, // .version }; - TfLiteIntArray* ops_to_replace = GetOpsToReplace( - context, /*allow_quant_ops=*/GetDelegate(delegate)->IsQuantOpsAllowed()); + + auto* gpu_delegate = GetDelegate(delegate); + TfLiteIntArray* ops_to_replace = + GetOpsToReplace(context, gpu_delegate->IsQuantOpsAllowed(), + gpu_delegate->MaxDelegatedPartitions()); const auto status = context->ReplaceNodeSubsetsWithDelegateKernels( context, kRegistration, ops_to_replace, delegate); + TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Created %d GPU delegate kernels.", + gpu_delegate->num_delegate_kernels()); TfLiteIntArrayFree(ops_to_replace); return status; } @@ -476,15 +494,17 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { } // namespace tflite TfLiteGpuDelegateOptionsV2 TfLiteGpuDelegateOptionsV2Default() { - TfLiteGpuDelegateOptionsV2 options; - // set it to -1 to detect whether it was later adjusted. - options.is_precision_loss_allowed = -1; - options.inference_preference = - TFLITE_GPU_INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER; - options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION; - options.inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO; - options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO; - options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE; + TfLiteGpuDelegateOptionsV2 options = { + // set it to -1 to detect whether it was later adjusted. + .is_precision_loss_allowed = -1, + .inference_preference = + TFLITE_GPU_INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER, + .inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION, + .inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO, + .inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO, + .experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE, + .max_delegated_partitions = 1, + }; return options; } diff --git a/tensorflow/lite/delegates/gpu/delegate.h b/tensorflow/lite/delegates/gpu/delegate.h index 95b0c3fdd2b..f03392d9a3c 100644 --- a/tensorflow/lite/delegates/gpu/delegate.h +++ b/tensorflow/lite/delegates/gpu/delegate.h @@ -109,6 +109,11 @@ typedef struct { // Bitmask flags. See the comments in TfLiteGpuExperimentalFlags. int64_t experimental_flags; + + // A graph could have multiple partitions that can be delegated to the GPU. + // This limits the maximum number of partitions to be delegated. By default, + // it's set to 1 in TfLiteGpuDelegateOptionsV2Default(). 
+ int32_t max_delegated_partitions; } TfLiteGpuDelegateOptionsV2; // Populates TfLiteGpuDelegateOptionsV2 as follows: diff --git a/tensorflow/lite/tools/delegates/gpu_delegate_provider.cc b/tensorflow/lite/tools/delegates/gpu_delegate_provider.cc index 75476223ff5..db1f32b2282 100644 --- a/tensorflow/lite/tools/delegates/gpu_delegate_provider.cc +++ b/tensorflow/lite/tools/delegates/gpu_delegate_provider.cc @@ -129,6 +129,8 @@ TfLiteDelegatePtr GpuDelegateProvider::CreateTfLiteDelegate( gpu_opts.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_GL_ONLY; } } + gpu_opts.max_delegated_partitions = + params.Get("max_delegated_partitions"); delegate = evaluation::CreateGPUDelegate(&gpu_opts); #elif defined(REAL_IPHONE_DEVICE) TFLGpuDelegateOptions gpu_opts = {0};
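
For reference, a minimal caller-side sketch of how the new option would be used; this is not part of the patch, and the helper name and the choice of two partitions are illustrative only:

#include "tensorflow/lite/delegates/gpu/delegate.h"
#include "tensorflow/lite/interpreter.h"

// Hypothetical helper: applies the GPU delegate to an existing interpreter,
// allowing up to two delegated partitions instead of the default single one.
TfLiteStatus ApplyGpuDelegate(tflite::Interpreter* interpreter) {
  TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
  options.max_delegated_partitions = 2;  // field introduced by this patch
  TfLiteDelegate* delegate = TfLiteGpuDelegateV2Create(&options);
  // In real code the delegate must outlive the interpreter and be released
  // with TfLiteGpuDelegateV2Delete() after the interpreter is destroyed.
  return interpreter->ModifyGraphWithDelegate(delegate);
}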
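
Conceptually, GetNodesOfFirstNLargestPartitions amounts to ordering the candidate partitions by node count and concatenating the nodes of the first N. The sketch below is a simplified standalone illustration of that idea, not the actual GraphPartitionHelper implementation (which operates on TfLiteDelegateParams):

#include <algorithm>
#include <vector>

// Simplified illustration: each inner vector holds the node indices of one
// candidate partition reported by the partitioning preview.
std::vector<int> FirstNLargestPartitions(
    std::vector<std::vector<int>> partitions, int n) {
  std::sort(partitions.begin(), partitions.end(),
            [](const std::vector<int>& a, const std::vector<int>& b) {
              return a.size() > b.size();  // larger partitions first
            });
  std::vector<int> nodes;
  const int limit = std::min(n, static_cast<int>(partitions.size()));
  for (int i = 0; i < limit; ++i) {
    nodes.insert(nodes.end(), partitions[i].begin(), partitions[i].end());
  }
  return nodes;
}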
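
With the gpu_delegate_provider.cc change, the same limit should also be reachable from the TFLite tooling side through the existing "max_delegated_partitions" tool parameter (for example together with --use_gpu=true when running benchmark_model), assuming that parameter is registered elsewhere in the delegate providers. When more than one partition is delegated, DelegatePrepare now reports the resulting count via the "Created %d GPU delegate kernels." info log added above.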