From 3492a0f35503290f7603e4b5198393621487b8a3 Mon Sep 17 00:00:00 2001
From: Chao Mei
Date: Fri, 13 Mar 2020 23:25:31 -0700
Subject: [PATCH] Improve GPU delegate graph partitioning logic

1. Allow offloading computation starting from intermediate nodes of a model
   in the GPU delegate. When there are multiple partitions, only offload the
   largest partition.
2. Unify GetOpsToReplace and GetOpsToReplaceFromGraphWithDequantize into the
   same function.

RELNOTES: [TFLite] Allow GPU acceleration starting with internal nodes
PiperOrigin-RevId: 300889638
Change-Id: I09da4043532a95a9efb4f167c8af4243c3889a45
---
 tensorflow/lite/BUILD                         |   2 +
 .../delegates/gpu/common/model_builder.cc     | 542 +++++++++++-------
 .../gpu/common/model_builder_test.cc          | 360 ++++++++----
 tensorflow/lite/util.cc                       |  12 +
 tensorflow/lite/util.h                        |   3 +
 tensorflow/lite/util_test.cc                  |  27 +
 6 files changed, 634 insertions(+), 312 deletions(-)

diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 63d580223f0..5e22b1fed5c 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -416,6 +416,7 @@ cc_library(
     hdrs = ["util.h"],
     copts = TFLITE_DEFAULT_COPTS + tflite_copts(),
     deps = [
+        ":kernel_api",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",
     ],
@@ -432,6 +433,7 @@ cc_test(
     deps = [
         ":util",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/schema:schema_fbs",
         "@com_google_googletest//:gtest",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index b2bd6f080b6..cc108ea022b 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -20,9 +20,12 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
+#include
 #include
 #include

@@ -36,6 +39,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context.h"
+#include "tensorflow/lite/context_util.h"
 #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
@@ -2746,96 +2750,309 @@ std::unique_ptr NewOperationParser(
   return absl::make_unique();
 }

-// A utility class to remove Dequantize op in the graph.
-class DequantizeOpRemover {
+Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
+                              TfLiteNode** tflite_node,
+                              TfLiteRegistration** registration) {
+  if (context->GetNodeAndRegistration(context, node_id, tflite_node,
+                                      registration) != kTfLiteOk) {
+    return InvalidArgumentError(absl::StrCat(
+        "Couldn't get node and registration info for op: ", node_id));
+  }
+  return OkStatus();
+}
+
+using IsNodeSupportedFn =
+    std::function<Status(TfLiteContext*, TfLiteNode*, TfLiteRegistration*)>;
+
+// A utility class to help partition the model graph and decide which
+// partition is to be offloaded to the GPU.
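+//
+// A rough usage sketch (illustrative only; it mirrors how GetOpsToReplace
+// below drives this class):
+//
+//   GraphPartitionHelper helper(context, is_node_supported_fn);
+//   std::set<std::string> unsupported_nodes_info;
+//   if (helper.Partition(&unsupported_nodes_info).ok()) {
+//     auto to_offload = helper.GetFirstNLargestPartitions(1);
+//   }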
+// TODO(b/151152967): move the following to lite/delegates/utils
+class GraphPartitionHelper {
  public:
-  bool MayAddNode(const TfLiteNode& node, int32_t op_code, int node_id,
-                  const TfLiteTensor* tensors) {
-    if (op_code == kTfLiteBuiltinDequantize &&
-        tensors[node.inputs->data[0]].type == TfLiteType::kTfLiteFloat16) {
-      dequant_nodes_[node.outputs->data[0]] = {node_id, node.inputs->data[0]};
-      input_tensors_.insert(node.inputs->data[0]);
-      return true;
+  GraphPartitionHelper(TfLiteContext* context,
+                       IsNodeSupportedFn is_node_supported_fn)
+      : is_node_supported_fn_(is_node_supported_fn), context_(context) {}
+
+  virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); }
+
+  // Partitions the graph into multiple subgraphs, each of which is in
+  // dependency order with the others.
+  virtual Status Partition(std::set<std::string>* unsupported_nodes_info) {
+    RETURN_IF_ERROR(PrepareSupportedNodes(unsupported_nodes_info));
+
+    TfLiteDelegateParams* partition_params_array_ = nullptr;
+    int num_partitions_ = 0;
+    if (context_->PreviewDelegatePartitioning(context_, supported_nodes_,
+                                              &partition_params_array_,
+                                              &num_partitions_) != kTfLiteOk) {
+      return InvalidArgumentError("Unable to preview delegate partition.");
     }
-    return false;
+
+    for (int i = 0; i < num_partitions_; ++i) {
+      partitions_.push_back(partition_params_array_ + i);
+    }
+
+    return OkStatus();
   }

-  // Remap inputs of 'node' to the inputs of the preceding dequant if there's a
-  // one.
-  // inputs_from_dequant: records the node id of a dequantize node whose
-  // output tensor is one of the input tensor of this 'node'.
-  // orig_inputs: records original input tensor ids of this node if any input is
-  // remapped.
-  void MayRemapInputTensors(TfLiteNode* node,
-                            std::vector<int>* inputs_from_dequant,
-                            std::vector<int>* orig_inputs) const {
-    inputs_from_dequant->clear();
-    orig_inputs->clear();
-    if (dequant_nodes_.empty()) return;
+  // Returns the first n largest partitions or all if #partitions is less than
+  // 'n'. Note that partitions are ranked according to the number of nodes that
+  // a partition has, and the returned TfLiteDelegateParams objects are *owned*
+  // by the TfLite runtime.
+  std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions(int n) {
+    const int total = num_partitions();
+    // We only sort partitions according to their sizes if necessary.
+    if (n < total) {
+      partitions_.sort(CompareTwoPartitions);
+    }
+    std::vector<TfLiteDelegateParams*> results;
+    auto p_it = partitions_.begin();
+    for (int i = 0; i < std::min(total, n); ++i, ++p_it) {
+      results.push_back(*p_it);
+    }
+    return results;
+  }

-    TfLiteIntArray* inputs = node->inputs;
-    orig_inputs->reserve(inputs->size);
-    // Fix the node's inputs (i.e. prune out the preceding dequantize node)
-    // in order to test if it is supported on the GPU.
-    for (int j = 0; j < inputs->size; ++j) {
-      const int input_tid = inputs->data[j];
-      orig_inputs->push_back(input_tid);
-      const auto it = dequant_nodes_.find(input_tid);
-      if (it != dequant_nodes_.end()) {
-        inputs_from_dequant->push_back(it->second.node_id);
-        // Remap inputs of this node to the inputs of the preceding dequant.
-        inputs->data[j] = it->second.input_tensor_id;
+  int num_total_nodes() const { return num_total_nodes_; }
+  int num_partitions() const { return partitions_.size(); }
+
+ private:
+  static bool CompareTwoPartitions(TfLiteDelegateParams* left,
+                                   TfLiteDelegateParams* right) {
+    // Reverse sort
+    return left->nodes_to_replace->size > right->nodes_to_replace->size;
+  }
+
+  Status PrepareSupportedNodes(
+      std::set<std::string>* unsupported_nodes_info = nullptr) {
+    TfLiteIntArray* execution_plan = nullptr;
+    if (context_->GetExecutionPlan(context_, &execution_plan) != kTfLiteOk) {
+      return InvalidArgumentError("Unable to get graph execution plan.");
+    }
+
+    num_total_nodes_ = execution_plan->size;
+    supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_);
+    supported_nodes_->size = 0;
+    for (int node_id : TfLiteIntArrayView(execution_plan)) {
+      TfLiteNode* node;
+      TfLiteRegistration* registration;
+      auto status =
+          GetNodeAndRegistration(context_, node_id, &node, &registration);
+      if (!status.ok()) {
+        supported_nodes_->size = 0;
+        return status;
+      }
+
+      status = IsNodeSupported(context_, node, registration, node_id);
+      if (status.ok()) {
+        supported_nodes_->data[supported_nodes_->size++] = node_id;
+      } else if (unsupported_nodes_info) {
+        unsupported_nodes_info->insert(
+            absl::StrCat(GetOpNameByRegistration(*registration), ": ",
+                         status.error_message()));
       }
     }
+    return OkStatus();
   }

-  // May restore inputs of 'node' to 'orig_inputs' if there're inputs from
-  // dequant nodes (i.e. denoted by 'inputs_from_dequant'). We will also mark
-  // such dequantize nodes to be preserved.
-  void MayRestoreInputTensors(TfLiteNode* node,
-                              const std::vector<int>& inputs_from_dequant,
-                              const std::vector<int>& orig_inputs) {
-    if (inputs_from_dequant.empty()) return;
+  // The total number of nodes passed in for partitioning (i.e. the
+  // execution_plan size).
+  int num_total_nodes_ = 0;

-    for (int j = 0; j < node->inputs->size; ++j) {
-      node->inputs->data[j] = orig_inputs[j];
-    }
-    // Mark those dequantize nodes to be preserved in the graph.
-    dequant_nodes_to_save_.insert(dequant_nodes_to_save_.end(),
-                                  inputs_from_dequant.begin(),
-                                  inputs_from_dequant.end());
+  // Tells whether a node is replaceable.
+  const IsNodeSupportedFn is_node_supported_fn_;
+  TfLiteIntArray* supported_nodes_;  // owns the memory
+
+ protected:
+  virtual Status IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
+                                 TfLiteRegistration* registration,
+                                 int node_id) {
+    return is_node_supported_fn_(context, node, registration);
   }

-  void RemovePreservedNodesFrom(std::vector<int>* nodes) const {
-    for (const int nid : dequant_nodes_to_save_) {
-      auto it = std::find(nodes->begin(), nodes->end(), nid);
-      if (it != nodes->end()) nodes->erase(it);
+  TfLiteContext* const context_ = nullptr;
+
+  // Doesn't own the memory of each TfLiteDelegateParams object as it's
+  // managed by the TfLite runtime itself. See
+  // TfLiteContext::PreviewDelegatePartitioning for details.
+  std::list<TfLiteDelegateParams*> partitions_;
+};
+
+class GraphWithDequantPartitionHelper : public GraphPartitionHelper {
+ public:
+  GraphWithDequantPartitionHelper(TfLiteContext* context,
+                                  IsNodeSupportedFn is_node_supported_fn)
+      : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
+
+  Status Partition(std::set<std::string>* unsupported_nodes_info) override {
+    auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
+    // Clean up those partitions that have a single dequant op. Note those
+    // removed dequant ops have to be preserved in the graph and should not be
+    // delegated.
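+    // (A partition made of a single dequant op would offload no real work:
+    // dequantize is not actually supported in the GPU delegate, so such a
+    // node must stay on the CPU to feed whatever consumers were not
+    // delegated.)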
+    RemoveSingleDequantNodePartitions();
+    return status;
+  }
+
+  // Returns a list of node indices of all nodes from the first n largest
+  // partitions. If there are fewer partitions than n, all nodes will be
+  // returned. Partitions are ranked according to the number of nodes they
+  // contain.
+  std::vector<int> GetNodesOfFirstNLargestPartitions(int n) {
+    // We first get partitions to reduce the number of nodes to be checked in
+    // deciding which dequant ops could actually be replaced. Then we remap
+    // input tensors to the dequant nodes' inputs and remove those
+    // to-be-preserved dequant nodes.
+    auto first_nps = GetFirstNLargestPartitions(n);
+    std::vector<int> ops_to_replace;
+    for (const auto p : first_nps) {
+      auto nodes = p->nodes_to_replace;
+      ops_to_replace.insert(ops_to_replace.end(), nodes->data,
+                            nodes->data + nodes->size);
     }
+    RemapInputTensors(ops_to_replace);
+    RemoveReservedDequantsFromNodes(&ops_to_replace);
+    return ops_to_replace;
+  }
+
+ protected:
+  Status IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
+                         TfLiteRegistration* registration,
+                         int node_id) override {
+    // If we need to handle dequant nodes, we have to remap input tensors of
+    // this node if some of them come from a dequant node before testing if
+    // the node is supported.
+    std::vector<int> orig_inputs;
+    if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
+                                   &orig_inputs)) {
+      // We have a dequant op here. Note that we return an Ok status because a
+      // dequant node is first added as supported. Later, this dequant node
+      // will be removed if it has to be preserved in the graph, which happens
+      // when its immediate downstream nodes cannot be supported.
+      return OkStatus();
+    }
+    const auto status = GraphPartitionHelper::IsNodeSupported(
+        context, node, registration, node_id);
+    RestoreToOrigInputTensors(node, orig_inputs);
+    return status;
   }

  private:
-  struct NodeInfo {
-    int node_id;
-    int input_tensor_id;
-  };
+  // Record 'node' if it is a dequant op (i.e. an fp16 one here) and return
+  // true. When it's not a dequant op, remap its inputs to the inputs of the
+  // preceding dequant if there is one, and return false. 'orig_inputs' records
+  // the original input tensor ids of this node if any input is remapped.
+  bool RecordAndRemapInputTensors(int32_t op_code, int node_id,
+                                 TfLiteNode* node,
+                                 std::vector<int>* orig_inputs) {
+    orig_inputs->clear();
+    // Record the dequant node.
+    if (op_code == kTfLiteBuiltinDequantize &&
+        context_->tensors[node->inputs->data[0]].type ==
+            TfLiteType::kTfLiteFloat16) {
+      dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0];
+      return true;
+    }
+    // If no dequant node has been recorded so far, there is nothing to remap.
+    if (dequant_nodes_.empty()) return false;
+    RemapInputTensors(node, orig_inputs);
+    return false;
+  }

-  // A map recording dequantize nodes of this graph. A dequantize node is
-  // identified by the output tensor id. The value is a NodeInfo.
-  std::unordered_map<int, NodeInfo> dequant_nodes_;
+  // Restore inputs of 'node' to 'orig_inputs' only if the two sizes match.
+  void RestoreToOrigInputTensors(TfLiteNode* node,
+                                 const std::vector<int>& orig_inputs) {
+    if (node->inputs->size != orig_inputs.size()) return;
+    for (int j = 0; j < node->inputs->size; ++j) {
+      node->inputs->data[j] = orig_inputs[j];
+    }
+  }

-  // A set of input tensor ids of dequantize nodes.
-  std::set<int> input_tensors_;
+  // Remap input tensors of every node in 'nodes' (i.e. node indices) if some
+  // of them are from dequant ops.
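+  // After remapping, a delegated node reads the fp16 tensor that feeds the
+  // dequant op directly, which is what allows the dequant op itself to be
+  // dropped from the offloaded partition.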
+  void RemapInputTensors(const std::vector<int>& nodes) const {
+    for (int node_id : nodes) {
+      TfLiteNode* node;
+      TfLiteRegistration* registration;
+      GetNodeAndRegistration(context_, node_id, &node, &registration)
+          .IgnoreError();
+      RemapInputTensors(node, nullptr /* orig_inputs */);
+    }
+  }

-  // The node ids of dequantize nodes that have to be preserved in the graph.
-  std::vector<int> dequant_nodes_to_save_;
+  void RemoveSingleDequantNodePartitions() {
+    auto it = partitions_.begin();
+    while (it != partitions_.end()) {
+      auto p = *it;
+      if (p->nodes_to_replace->size != 1) {
+        ++it;
+        continue;
+      }
+      int node_id = p->nodes_to_replace->data[0];
+      TfLiteNode* node = nullptr;
+      TfLiteRegistration* registration = nullptr;
+      GetNodeAndRegistration(context_, node_id, &node, &registration)
+          .IgnoreError();
+      if (registration->builtin_code != kTfLiteBuiltinDequantize) {
+        ++it;
+        continue;
+      }
+      // Note such dequant nodes have to be preserved in the graph as dequant
+      // ops are not actually supported in the GPU delegate.
+      dequant_nodes_to_save_.insert(node_id);
+      it = partitions_.erase(it);
+    }
+  }
+
+  void RemoveReservedDequantsFromNodes(std::vector<int>* nodes) {
+    if (dequant_nodes_to_save_.empty()) return;
+    auto it = nodes->begin();
+    while (it != nodes->end()) {
+      if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
+        ++it;
+        continue;
+      }
+      it = nodes->erase(it);
+    }
+  }
+
+  // Remap input tensors of a single 'node' if some of them come from a dequant
+  // op. If 'orig_inputs' isn't nullptr, it records the original input tensor
+  // ids of this node if any input is remapped.
+  void RemapInputTensors(TfLiteNode* node,
+                         std::vector<int>* orig_inputs) const {
+    TfLiteIntArray* inputs = node->inputs;
+    auto inputs_view = TfLiteIntArrayView(inputs);
+    // Prepopulate 'orig_inputs' first and clear it if there's no input from a
+    // dequant op.
+    if (orig_inputs) {
+      orig_inputs->clear();
+      orig_inputs->reserve(inputs->size);
+      for (auto tid : inputs_view) {
+        orig_inputs->push_back(tid);
+      }
+    }
+    // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
+    // order to test if it is supported.
+    bool is_remapped = false;
+    for (int j = 0; j < inputs->size; ++j) {
+      const int input_tid = inputs->data[j];
+      const auto it = dequant_nodes_.find(input_tid);
+      if (it != dequant_nodes_.end()) {
+        inputs->data[j] = it->second;
+        is_remapped = true;
+      }
+    }
+    if (!is_remapped && orig_inputs) orig_inputs->clear();
+  }
+
+  // A map recording dequantize nodes' input/output tensors of this selected
+  // graph. The key is the output tensor id, and the value is the input tensor
+  // id.
+  std::unordered_map<int, int> dequant_nodes_;
+
+  // A set of dequant nodes (as node indices) that have to be preserved in the
+  // graph.
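+  // (These are the dequant ops whose single-node partitions were removed
+  // above, i.e. ops whose consumers are not all offloaded to the GPU.)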
+  std::set<int> dequant_nodes_to_save_;
 };

-}  // namespace
-
-Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
-                                      TensorRef<BHWC>* tensor_ref) {
-  tensor_ref->type = ToDataType(tflite_tensor.type);
-  return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
-}

 Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
                    const TfLiteRegistration* registration) {
@@ -2856,158 +3073,63 @@ bool IsAllFloatTensors(const TfLiteContext* context,
   return true;
 }

-std::string GetOpNameByRegistration(const TfLiteRegistration* registration) {
-  auto op = registration->builtin_code;
-  std::string result =
-      EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
-  if (op == kTfLiteBuiltinCustom) {
-    result += " " + std::string(registration->custom_name);
-  }
-  return result;
-}
+}  // namespace

-Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
-                              TfLiteNode** tflite_node,
-                              TfLiteRegistration** registration) {
-  if (context->GetNodeAndRegistration(context, node_id, tflite_node,
-                                      registration) != kTfLiteOk) {
-    return InvalidArgumentError(absl::StrCat(
-        "Couldn't get node and registration info for op: ", node_id));
-  }
-  return OkStatus();
-}
-
-TfLiteIntArray* GetOpsToReplaceFromGraphWithDequantize(TfLiteContext* context) {
-  TfLiteIntArray* execution_plan = nullptr;
-  if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
-    context->ReportError(context, "Unable to get graph execution plan.");
-    return nullptr;
-  }
-  std::set<std::string> errors;
-  std::vector<int> ops_to_replace;
-  DequantizeOpRemover dequant_remover;
-
-  for (int i = 0; i < execution_plan->size; ++i) {
-    const int node_id = execution_plan->data[i];
-    TfLiteNode* node = nullptr;
-    TfLiteRegistration* registration = nullptr;
-    auto status =
-        GetNodeAndRegistration(context, node_id, &node, &registration);
-    if (!status.ok()) {
-      context->ReportError(context, status.error_message().c_str());
-      return nullptr;
-    }
-
-    if (dequant_remover.MayAddNode(*node, registration->builtin_code, node_id,
-                                   context->tensors)) {
-      // For now, add the node to the list of ops to replace.
-      ops_to_replace.push_back(node_id);
-      continue;
-    }
-
-    // Record the node id of a dequantize node whose output tensor is one of
-    // the input tensors of this node.
-    std::vector<int> inputs_from_dequant;
-    // Original input tensor ids of this node.
-    std::vector<int> orig_inputs;
-    dequant_remover.MayRemapInputTensors(node, &inputs_from_dequant,
-                                         &orig_inputs);
-
-    status = IsSupported(context, node, registration);
-    if (status.ok() &&
-        // TODO(eignasheva): resolve sub operation support for metal delegate
-        // registration->builtin_code != kTfLiteBuiltinSub &&
-        IsAllFloatTensors(context, node->inputs) &&
-        IsAllFloatTensors(context, node->outputs) && errors.empty()) {
-      // Node is supported and there were no previous errors.
-      ops_to_replace.push_back(node_id);
-    } else {
-      // The node is not replaceable, record an error message.
-      errors.insert(absl::StrCat(GetOpNameByRegistration(registration), ": ",
-                                 status.error_message()));
-      dequant_remover.MayRestoreInputTensors(node, inputs_from_dequant,
-                                             orig_inputs);
-    }
-  }
-
-  if (!errors.empty()) {
-    std::string unsupported = absl::StrJoin(errors, "\n");
-    std::string error_message =
-        "Next operations are not supported by GPU delegate:\n" + unsupported +
-        "\nFirst " + std::to_string(ops_to_replace.size()) +
-        " operations will run on the GPU, and the remaining " +
-        std::to_string(execution_plan->size - ops_to_replace.size()) +
-        " on the CPU.";
-    context->ReportError(context, error_message.c_str());
-  }
-
-  dequant_remover.RemovePreservedNodesFrom(&ops_to_replace);
-  return ConvertVectorToTfLiteIntArray(ops_to_replace);
+Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
+                                      TensorRef<BHWC>* tensor_ref) {
+  tensor_ref->type = ToDataType(tflite_tensor.type);
+  return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
 }

 // TODO(impjdi): Check number of input/output tensors and their dimensions.
 // TODO(impjdi): Check ops' parameters.
 TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
-  TfLiteIntArray* execution_plan = nullptr;
-  if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
-    context->ReportError(context, "Unable to get graph execution plan.");
+  IsNodeSupportedFn node_supported_fn =
+      [=](TfLiteContext* context, TfLiteNode* node,
+          TfLiteRegistration* registration) -> Status {
+    RETURN_IF_ERROR(IsSupported(context, node, registration));
+    return (IsAllFloatTensors(context, node->inputs) &&
+            IsAllFloatTensors(context, node->outputs))
+               ? OkStatus()
+               : FailedPreconditionError(
+                     "OP is supported, but tensor type isn't matched!");
+  };
+
+  GraphWithDequantPartitionHelper partition_helper(context, node_supported_fn);
+  std::set<std::string> unsupported_nodes_info;
+  auto status = partition_helper.Partition(&unsupported_nodes_info);
+  if (!status.ok()) {
+    TF_LITE_KERNEL_LOG(context, status.error_message().c_str());
     return nullptr;
   }

-  // Dispatch to another function if graph has Dequantize nodes.
-  for (int i = 0; i < execution_plan->size; ++i) {
-    const int node_id = execution_plan->data[i];
-    TfLiteNode* node = nullptr;
-    TfLiteRegistration* registration = nullptr;
-    auto status =
-        GetNodeAndRegistration(context, node_id, &node, &registration);
-    if (!status.ok()) {
-      context->ReportError(context, status.error_message().c_str());
-      return nullptr;
-    }
-    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
-        context->tensors[node->inputs->data[0]].type ==
-            TfLiteType::kTfLiteFloat16) {
-      return GetOpsToReplaceFromGraphWithDequantize(context);
-    }
-  }
+  // We simply get the first largest partition, but we could later explore
+  // whether getting more partitions could lead to better performance, i.e. by
+  // parameterizing '1' here.
+  std::vector<int> ops_to_replace =
+      partition_helper.GetNodesOfFirstNLargestPartitions(1);

-  // No Dequantize nodes. Iterate through graph and find ops to replace.
- TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size); - subgraph->size = 0; - std::set errors; - for (int i = 0; i < execution_plan->size; ++i) { - const int node_id = execution_plan->data[i]; - TfLiteNode* node; - TfLiteRegistration* registration; - auto status = - GetNodeAndRegistration(context, node_id, &node, ®istration); - if (!status.ok()) { - context->ReportError(context, status.error_message().c_str()); - return nullptr; - } - status = IsSupported(context, node, registration); - if (status.ok() && - // TODO(eignasheva): resolve sub operation support for metal delegate - // registration->builtin_code != kTfLiteBuiltinSub && - IsAllFloatTensors(context, node->inputs) && - IsAllFloatTensors(context, node->outputs)) { - if (errors.empty()) subgraph->data[subgraph->size++] = node_id; + if (!unsupported_nodes_info.empty()) { + std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n"); + std::string error_message = absl::StrCat( + "Following operations are not supported by GPU delegate:\n", + unsupported, "\n"); + if (!ops_to_replace.empty()) { + absl::StrAppendFormat( + &error_message, + "%d operations will run on the GPU (first node: " + "%d, last node: %d), and the remaining %d", + ops_to_replace.size(), ops_to_replace.front(), ops_to_replace.back(), + partition_helper.num_total_nodes() - ops_to_replace.size()); } else { - errors.insert(absl::StrCat(GetOpNameByRegistration(registration), ": ", - status.error_message())); + absl::StrAppend(&error_message, + "No operations will run on the GPU, and all ", + partition_helper.num_total_nodes()); } + absl::StrAppend(&error_message, " operations will run on the CPU."); + TF_LITE_KERNEL_LOG(context, error_message.c_str()); } - if (!errors.empty()) { - std::string unsupported = absl::StrJoin(errors, "\n"); - std::string error_message = - "Next operations are not supported by GPU delegate:\n" + unsupported + - "\nFirst " + std::to_string(subgraph->size) + - " operations will run on the GPU, and the remaining " + - std::to_string(execution_plan->size - subgraph->size) + " on the CPU."; - context->ReportError(context, error_message.c_str()); - } - return subgraph; + return ConvertVectorToTfLiteIntArray(ops_to_replace); } Status BuildModel(TfLiteContext* context, @@ -3047,7 +3169,7 @@ Status BuildModel(TfLiteContext* context, const auto status = operations[i]->Parse(tflite_node, registration, graph, &reader); if (!status.ok()) { - return InternalError(absl::StrCat(GetOpNameByRegistration(registration), + return InternalError(absl::StrCat(GetOpNameByRegistration(*registration), ": ", status.error_message())); } } diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc index d2851829a99..5cad4d186aa 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc @@ -120,9 +120,52 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) { EXPECT_FALSE(status.ok()); } -class InterpreterFp16 { +class DelegatedInterpreter { public: - explicit InterpreterFp16(TfLiteBuiltinOperator op) { + explicit DelegatedInterpreter(int num_nodes) { + exec_plan_ = TfLiteIntArrayCreate(num_nodes); + } + virtual ~DelegatedInterpreter() { + TfLiteIntArrayFree(exec_plan_); + for (auto params : delegate_params_) { + TfLiteIntArrayFree(params.nodes_to_replace); + TfLiteIntArrayFree(params.input_tensors); + TfLiteIntArrayFree(params.output_tensors); + } + } + + // Get the 
TfLiteContext to be mocked, for swapping out functions that have to
+  // be called inside a delegate (i.e. in delegate kernel mode).
+  TfLiteContext* context() { return interpreter_.primary_subgraph().context(); }
+
+  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
+  nodes_and_registration() {
+    return interpreter_.primary_subgraph().nodes_and_registration();
+  }
+
+  TfLiteIntArray* exec_plan() const { return exec_plan_; }
+  TfLiteDelegateParams* add_delegate_params() {
+    delegate_params_.push_back(TfLiteDelegateParams());
+    return &delegate_params_.back();
+  }
+  TfLiteDelegateParams* delegate_params() { return &delegate_params_.front(); }
+  int num_delegate_params() { return delegate_params_.size(); }
+
+ protected:
+  Interpreter interpreter_;
+
+ private:
+  // The manually-set execution plan for this delegated interpreter.
+  TfLiteIntArray* exec_plan_;
+
+  // The TfLiteDelegateParams objects that are manually populated inside the
+  // mocked TfLiteContext::PreviewDelegatePartitioning.
+  std::vector<TfLiteDelegateParams> delegate_params_;
+};
+
+class InterpreterFp16 : public DelegatedInterpreter {
  public:
-  explicit InterpreterFp16(TfLiteBuiltinOperator op) {
+  explicit InterpreterFp16(TfLiteBuiltinOperator op) : DelegatedInterpreter(3) {
     void* builtin_data = malloc(sizeof(int));
     EXPECT_EQ(interpreter_.AddTensors(5), kTfLiteOk);
     EXPECT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
@@ -187,22 +230,23 @@ class InterpreterFp16 {
         3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
     kTfLiteOk);

-    exec_plan_ = TfLiteIntArrayCreate(3);
-    exec_plan_->data[0] = 0;
-    exec_plan_->data[1] = 1;
-    exec_plan_->data[2] = 2;
+    exec_plan()->data[0] = 0;
+    exec_plan()->data[1] = 1;
+    exec_plan()->data[2] = 2;
   }
-
-  ~InterpreterFp16() { TfLiteIntArrayFree(exec_plan_); }
-
-  Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
-  TfLiteIntArray* exec_plan() const { return exec_plan_; }
-
- private:
-  Interpreter interpreter_;
-  TfLiteIntArray* exec_plan_;
 };

+// **NOTE**: we have several interpreter instances created at global scope to
+// test *exactly* the GetOpsToReplace function alone, and not the sequence of
+// function calls that includes GetOpsToReplace when calling
+// ModifyGraphWithDelegate. A TfLiteContext is needed to test GetOpsToReplace,
+// but TfLiteContexts intentionally make it difficult to call certain functions
+// in a non-delegate context (see tensorflow/lite/subgraph/subgraph.cc for
+// details). We create our own GetExecutionPlan, GetNodeAndRegistration and
+// PreviewDelegatePartitioning lambdas inside each test, but we can't use local
+// captures without changing the function signatures. Therefore, this test data
+// lives at global scope in order to be accessible inside the lambdas.
+
 InterpreterFp16* interpreter_fp16_add_op =
     new InterpreterFp16(kTfLiteBuiltinAdd);

@@ -218,7 +262,8 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
   //   t0 (FP16) --> Add -> t4
   //   t2 (FP16) --/
   //
-  TfLiteContext* context = interpreter_fp16_add_op->GetSubgraph()->context();
+  TfLiteContext* context = interpreter_fp16_add_op->context();
+
   // These functions are meant to be called inside delegates. Swap out
   // for similar functions to permit direct calling of GetOpsToReplace.
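+  // Note: the lambdas below are intentionally capture-free, since a capturing
+  // lambda cannot convert to the plain function pointers these TfLiteContext
+  // members expect; hence the test data at global scope above.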
context->GetExecutionPlan = [](struct TfLiteContext* context, @@ -229,12 +274,30 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index, TfLiteNode** node, TfLiteRegistration** registration) { - auto& node_and_reg = interpreter_fp16_add_op->GetSubgraph() - ->nodes_and_registration()[node_index]; + auto& node_and_reg = + interpreter_fp16_add_op->nodes_and_registration()[node_index]; *node = &node_and_reg.first; *registration = &node_and_reg.second; return kTfLiteOk; }; + context->PreviewDelegatePartitioning = + [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, + TfLiteDelegateParams** partition_params_array, int* num_partitions) { + auto params = interpreter_fp16_add_op->add_delegate_params(); + params->nodes_to_replace = TfLiteIntArrayCreate(3); + params->nodes_to_replace->data[0] = 0; + params->nodes_to_replace->data[1] = 1; + params->nodes_to_replace->data[2] = 2; + params->input_tensors = TfLiteIntArrayCreate(2); + params->input_tensors->data[0] = 0; + params->input_tensors->data[1] = 2; + params->output_tensors = TfLiteIntArrayCreate(1); + params->output_tensors->data[0] = 4; + + *partition_params_array = interpreter_fp16_add_op->delegate_params(); + *num_partitions = interpreter_fp16_add_op->num_delegate_params(); + return kTfLiteOk; + }; TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); @@ -251,17 +314,6 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { TfLiteIntArrayFree(ops_to_replace); } -// This interpreter instance is created at global scope to test *exactly* -// the GetOpsToReplace function alone, and not the sequence of function calls -// that includes GetOpsToReplace when calling ModifyGraphWithDelegate. -// A TfLiteContext is needed to test GetOpsToReplace, but TfLiteContexts -// intentionally make it difficult to call certain functions in a -// non-delegate context (see tensorflow/lite/subgraph/subgraph.cc for details) -// We create our own GetExecutionPlan and GetNodeAndRegistration lambdas -// inside each test, but we can't use local captures without changing the -// function signature. Therefore, this test data lives at global scope -// in order to be accessible inside the lambda. - InterpreterFp16* interpreter_fp16_gt_op = new InterpreterFp16(kTfLiteBuiltinGreater); @@ -274,7 +326,7 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) { // Because there is no GPU equivalent for the Greater op, we don't prune // the Dequantize nodes. - TfLiteContext* context = interpreter_fp16_gt_op->GetSubgraph()->context(); + TfLiteContext* context = interpreter_fp16_gt_op->context(); // These functions are meant to be called inside delegates. Swap out // for similar functions to permit direct calling of GetOpsToReplace. 
 context->GetExecutionPlan = [](struct TfLiteContext* context,
@@ -285,12 +337,37 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp16_gt_op->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
-    auto& node_and_reg = interpreter_fp16_gt_op->GetSubgraph()
-                             ->nodes_and_registration()[node_index];
+    auto& node_and_reg =
+        interpreter_fp16_gt_op->nodes_and_registration()[node_index];
    *node = &node_and_reg.first;
    *registration = &node_and_reg.second;
    return kTfLiteOk;
  };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_fp16_gt_op->add_delegate_params();
+        // First partition for DequantNode (t0->t1)
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 0;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 0;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 1;
+
+        // Second partition for DequantNode (t2->t3)
+        params = interpreter_fp16_gt_op->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 1;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 2;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 3;
+
+        *partition_params_array = interpreter_fp16_gt_op->delegate_params();
+        *num_partitions = interpreter_fp16_gt_op->num_delegate_params();
+        return kTfLiteOk;
+      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

@@ -309,9 +386,9 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
  TfLiteIntArrayFree(ops_to_replace);
 }

-class InterpreterFp32 {
+class InterpreterFp32 : public DelegatedInterpreter {
  public:
-  InterpreterFp32() {
+  InterpreterFp32() : DelegatedInterpreter(2) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(4), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 2}), kTfLiteOk);
@@ -363,34 +440,24 @@ class InterpreterFp32 {
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat32, "t2", dims, quantization, false),
        kTfLiteOk);
-    exec_plan_ = TfLiteIntArrayCreate(2);
-    exec_plan_->data[0] = 0;
-    exec_plan_->data[1] = 1;
+
+    exec_plan()->data[0] = 0;
+    exec_plan()->data[1] = 1;
  }
-
-  ~InterpreterFp32() { TfLiteIntArrayFree(exec_plan_); }
-
-  Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
-  TfLiteIntArray* exec_plan() const { return exec_plan_; }
-
- private:
-  Interpreter interpreter_;
-  TfLiteIntArray* exec_plan_;
 };

 InterpreterFp32* interpreter_fp32 = new InterpreterFp32();

 TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
-  // A graph with a Dequant node with uint8 input
-  // is not pruned. The delegate will attempt to replace it
-  // with a GPU op, but this op is currently not supported on
-  // the GPU. Therefore, the Dequant op and all downstream ops
-  // will be scheduled to run on the CPU.
+  // A graph with a Dequant node with uint8 input is not pruned, as this op is
+  // currently not supported on the GPU. The Dequant op will therefore be
+  // scheduled to run on the CPU, while the remaining supported Add op runs on
+  // the GPU.
  //
  //   t0 (uint8) --> Dequant --> t1 (FP32) --> Add -> t3
  //                              t2 (FP32) --/
  //
-  TfLiteContext* context = interpreter_fp32->GetSubgraph()->context();
+  TfLiteContext* context = interpreter_fp32->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
@@ -402,24 +469,43 @@ TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_fp32->GetSubgraph()->nodes_and_registration()[node_index];
+    auto& node_and_reg = interpreter_fp32->nodes_and_registration()[node_index];
    *node = &node_and_reg.first;
    *registration = &node_and_reg.second;
    return kTfLiteOk;
  };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_fp32->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 1;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 1;
+        params->input_tensors->data[1] = 2;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 3;
+
+        *partition_params_array = interpreter_fp32->delegate_params();
+        *num_partitions = interpreter_fp32->num_delegate_params();
+        return kTfLiteOk;
+      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

-  // No ops are run on the GPU, since the Dequant op is not pruned and must run
-  // on the CPU.
-  EXPECT_EQ(ops_to_replace->size, 0);
+  // Since the Dequant op is not pruned and the Add op can run on the GPU,
+  // there is one partition.
+  EXPECT_EQ(ops_to_replace->size, 1);
+  // The Add op is at node index 1.
+  EXPECT_EQ(1, ops_to_replace->data[0]);
+
  TfLiteIntArrayFree(ops_to_replace);
 }

-class InterpreterMultiNode {
+class InterpreterMultiNode : public DelegatedInterpreter {
  public:
-  explicit InterpreterMultiNode(bool add_op_first = true) {
+  explicit InterpreterMultiNode(bool add_op_first = true)
+      : DelegatedInterpreter(5) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
@@ -552,38 +638,29 @@ class InterpreterMultiNode {
        interpreter_.SetTensorParametersReadWrite(
            7, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
        kTfLiteOk);
-    exec_plan_ = TfLiteIntArrayCreate(5);
-    exec_plan_->data[0] = 0;
-    exec_plan_->data[1] = 1;
-    exec_plan_->data[2] = 2;
-    exec_plan_->data[3] = 3;
-    exec_plan_->data[4] = 4;
+
+    exec_plan()->data[0] = 0;
+    exec_plan()->data[1] = 1;
+    exec_plan()->data[2] = 2;
+    exec_plan()->data[3] = 3;
+    exec_plan()->data[4] = 4;
  }
-
-  ~InterpreterMultiNode() { TfLiteIntArrayFree(exec_plan_); }
-
-  Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
-  TfLiteIntArray* exec_plan() const { return exec_plan_; }
-
- private:
-  Interpreter interpreter_;
-  TfLiteIntArray* exec_plan_;
 };

 InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode();

-TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequants) {
+TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
  // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
  // 'Add' can be replaced by the GPU delegate, but 'Greater' cannot.
-  //   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
-  //   t1 (FP16) --> Dequant --> t4 (FP32) --/
-  //                                       --\
-  //   t3 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7
+  //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(4) -> t6
+  //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
+  //                                          --\
+  //   t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(3) -> t7
  //
  // OpsToReplace should replace the 'Add' op and the Dequant outputting
  // t5, but leave the other Dequant nodes because 'Greater' must run
  // on the CPU.
-  TfLiteContext* context = interpreter_mn->GetSubgraph()->context();
+  TfLiteContext* context = interpreter_mn->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
@@ -595,12 +672,48 @@
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_mn->GetSubgraph()->nodes_and_registration()[node_index];
+    auto& node_and_reg = interpreter_mn->nodes_and_registration()[node_index];
    *node = &node_and_reg.first;
    *registration = &node_and_reg.second;
    return kTfLiteOk;
  };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_mn->add_delegate_params();
+        // First partition for DequantNode(0)
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 0;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 0;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 3;
+
+        // Second partition for DequantNode(1)
+        params = interpreter_mn->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 1;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 1;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 4;
+
+        // Third partition for DequantNode(1), DequantNode(2), ADD(3)
+        params = interpreter_mn->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(3);
+        params->nodes_to_replace->data[0] = 1;
+        params->nodes_to_replace->data[1] = 2;
+        params->nodes_to_replace->data[2] = 3;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 1;
+        params->input_tensors->data[1] = 3;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 7;
+
+        *partition_params_array = interpreter_mn->delegate_params();
+        *num_partitions = interpreter_mn->num_delegate_params();
+        return kTfLiteOk;
+      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

@@ -624,19 +737,22 @@

 InterpreterMultiNode* interpreter_mn2 =
    new InterpreterMultiNode(/*add_op_first=*/false);
-
-TEST(ModelBuilderTest, GetOpsToReplaceRestoresInputsOnErrors) {
-  // A graph with three Dequant nodes feeding two ops, 'Greater' and 'Add'.
+TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
+  // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
  // 'Add' can be replaced by the GPU delegate, but 'Greater' cannot.
-  //   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
-  //   t1 (FP16) --> Dequant --> t4 (FP32) --/
-  //                                       --\
-  //   t3 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7
+  //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
+  //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
+  //                                          --\
+  //   t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
  //
-  // 'Greater' comes first in the execution plan though, so Add should not
-  // be scheduled to run on the GPU. Further, its inputs should remain t4
-  // and t5.
-  TfLiteContext* context = interpreter_mn->GetSubgraph()->context();
+  // Note: the graph dependency is exactly the same as that in
+  // GetOpsToReplaceSelectsCorrectDequantsAddFirst, but the unsupported
+  // 'Greater' op appears first in the execution plan. Despite this,
+  // OpsToReplace should still replace the 'Add' op and the Dequant outputting
+  // t5, but leave the other Dequant nodes because 'Greater' must run
+  // on the CPU.
+
+  TfLiteContext* context = interpreter_mn2->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
@@ -648,27 +764,67 @@
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_mn2->GetSubgraph()->nodes_and_registration()[node_index];
+    auto& node_and_reg = interpreter_mn2->nodes_and_registration()[node_index];
    *node = &node_and_reg.first;
    *registration = &node_and_reg.second;
    return kTfLiteOk;
  };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_mn2->add_delegate_params();
+        // First partition for DequantNode(0)
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 0;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 0;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 3;
+
+        // Second partition for DequantNode(1)
+        params = interpreter_mn2->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 1;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 1;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 4;
+
+        // Third partition for DequantNode(1), DequantNode(2), ADD(4)
+        params = interpreter_mn2->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(3);
+        params->nodes_to_replace->data[0] = 1;
+        params->nodes_to_replace->data[1] = 2;
+        params->nodes_to_replace->data[2] = 4;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 1;
+        params->input_tensors->data[1] = 3;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 7;
+
+        *partition_params_array = interpreter_mn2->delegate_params();
+        *num_partitions = interpreter_mn2->num_delegate_params();
+        return kTfLiteOk;
+      };
+
  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

-  // Verify that no ops will be replaced.
-  EXPECT_EQ(ops_to_replace->size, 0);
+  EXPECT_EQ(ops_to_replace->size, 2);
+  // Op at index 2 is the Dequant op (t3 -> t5).
+  EXPECT_EQ(ops_to_replace->data[0], 2);
+  // Op at index 4 is the Add op.
+  EXPECT_EQ(ops_to_replace->data[1], 4);

  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
-  // Verify that Add op has fp32 inputs.
-  context->GetNodeAndRegistration(context, 4, &node, &registration);
-  EXPECT_EQ(registration->builtin_code, kTfLiteBuiltinAdd);
+  // Verify that the Add op has fp16 inputs.
+  context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
+                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
-            TfLiteType::kTfLiteFloat32);
+            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
-            TfLiteType::kTfLiteFloat32);
+            TfLiteType::kTfLiteFloat16);
  TfLiteIntArrayFree(ops_to_replace);
 }

diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc
index a876eebd639..335c6773039 100644
--- a/tensorflow/lite/util.cc
+++ b/tensorflow/lite/util.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include
 #include

+#include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/schema/schema_generated.h"

@@ -131,4 +132,15 @@ bool IsUnresolvedCustomOp(const TfLiteRegistration& registration) {
   registration.invoke == &UnresolvedOpInvoke;
 }

+std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
+  auto op = registration.builtin_code;
+  std::string result =
+      EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
+  if ((op == kTfLiteBuiltinCustom || op == kTfLiteBuiltinDelegate) &&
+      registration.custom_name) {
+    result += " " + std::string(registration.custom_name);
+  }
+  return result;
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/util.h b/tensorflow/lite/util.h
index 42ce0deef96..3b042eb5986 100644
--- a/tensorflow/lite/util.h
+++ b/tensorflow/lite/util.h
@@ -21,6 +21,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_UTIL_H_
 #define TENSORFLOW_LITE_UTIL_H_

+#include
 #include

 #include "tensorflow/lite/c/common.h"
@@ -73,6 +74,8 @@ TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name);
 // Checks whether the provided op is an unresolved custom op.
 bool IsUnresolvedCustomOp(const TfLiteRegistration& registration);

+// Returns a descriptive name for the given op's TfLiteRegistration.
+std::string GetOpNameByRegistration(const TfLiteRegistration& registration);
 }  // namespace tflite

 #endif  // TENSORFLOW_LITE_UTIL_H_
diff --git a/tensorflow/lite/util_test.cc b/tensorflow/lite/util_test.cc
index b1886a7e8f5..e282431284b
@@ -20,6 +20,7 @@ limitations under the License.
#include #include #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace { @@ -90,6 +91,32 @@ TEST(CombineHashes, TestHashOutputsDifferent) { EXPECT_NE(output1, output2); } +TEST(GetOpNameByRegistration, ValidBuiltinCode) { + TfLiteRegistration registration; + registration.builtin_code = tflite::BuiltinOperator_ADD; + const auto op_name = GetOpNameByRegistration(registration); + EXPECT_EQ("ADD", op_name); +} + +TEST(GetOpNameByRegistration, InvalidBuiltinCode) { + TfLiteRegistration registration; + registration.builtin_code = -1; + const auto op_name = GetOpNameByRegistration(registration); + EXPECT_EQ("", op_name); +} + +TEST(GetOpNameByRegistration, CustomName) { + TfLiteRegistration registration; + registration.builtin_code = tflite::BuiltinOperator_CUSTOM; + registration.custom_name = "TestOp"; + auto op_name = GetOpNameByRegistration(registration); + EXPECT_EQ("CUSTOM TestOp", op_name); + + registration.builtin_code = tflite::BuiltinOperator_DELEGATE; + registration.custom_name = "TestDelegate"; + op_name = GetOpNameByRegistration(registration); + EXPECT_EQ("DELEGATE TestDelegate", op_name); +} } // namespace } // namespace tflite