Improve GPU delegate graph partitioning logic
1. Allow offloading computation starting from intermediate nodes of a model in the GPU delegate; when there are multiple partitions, offload only the single largest one.
2. Unify GetOpsToReplace and GetOpsToReplaceFromGraphWithDequantize into the same function.

RELNOTES: [TFLite] Allow GPU acceleration starting with internal nodes
PiperOrigin-RevId: 300889638
Change-Id: I09da4043532a95a9efb4f167c8af4243c3889a45
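At a high level, the change reduces GetOpsToReplace to: partition the graph with a node-support predicate, then hand back only the nodes of the single largest supported partition. A condensed sketch of that flow, using the types introduced in this change (simplified; error reporting omitted):

    // Condensed sketch of the new selection flow; see the full diff below.
    TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
      IsNodeSupportedFn node_supported_fn =
          [=](TfLiteContext* context, TfLiteNode* node,
              TfLiteRegistration* registration) -> Status {
        return IsSupported(context, node, registration);  // simplified check
      };
      GraphWithDequantPartitionHelper partition_helper(context,
                                                       node_supported_fn);
      std::set<std::string> unsupported_nodes_info;
      if (!partition_helper.Partition(&unsupported_nodes_info).ok()) {
        return nullptr;
      }
      // Offload only the single largest partition; '1' could later be tuned.
      std::vector<int> ops_to_replace =
          partition_helper.GetNodesOfFirstNLargestPartitions(1);
      return ConvertVectorToTfLiteIntArray(ops_to_replace);
    }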
This commit is contained in:
parent 01641aee30
commit 3492a0f355
				| @ -416,6 +416,7 @@ cc_library( | ||||
|     hdrs = ["util.h"], | ||||
|     copts = TFLITE_DEFAULT_COPTS + tflite_copts(), | ||||
|     deps = [ | ||||
|         ":kernel_api", | ||||
|         "//tensorflow/lite/c:common", | ||||
|         "//tensorflow/lite/schema:schema_fbs", | ||||
|     ], | ||||
| @ -432,6 +433,7 @@ cc_test( | ||||
|     deps = [ | ||||
|         ":util", | ||||
|         "//tensorflow/lite/c:common", | ||||
|         "//tensorflow/lite/schema:schema_fbs", | ||||
|         "@com_google_googletest//:gtest", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
| @ -20,9 +20,12 @@ limitations under the License. | ||||
| #include <algorithm> | ||||
| #include <cstdint> | ||||
| #include <cstring> | ||||
| #include <list> | ||||
| #include <memory> | ||||
| #include <set> | ||||
| #include <string> | ||||
| #include <unordered_map> | ||||
| #include <utility> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include <fp16.h> | ||||
| @ -36,6 +39,7 @@ limitations under the License. | ||||
| #include "tensorflow/lite/c/builtin_op_data.h" | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/context.h" | ||||
| #include "tensorflow/lite/context_util.h" | ||||
| #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" | ||||
| #include "tensorflow/lite/delegates/gpu/common/data_type.h" | ||||
| #include "tensorflow/lite/delegates/gpu/common/model.h" | ||||
| @ -2746,96 +2750,309 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser( | ||||
|   return absl::make_unique<UnsupportedOperationParser>(); | ||||
| } | ||||
| 
 | ||||
| // A utility class to remove Dequantize ops in the graph. | ||||
| class DequantizeOpRemover { | ||||
| Status GetNodeAndRegistration(TfLiteContext* context, int node_id, | ||||
|                               TfLiteNode** tflite_node, | ||||
|                               TfLiteRegistration** registration) { | ||||
|   if (context->GetNodeAndRegistration(context, node_id, tflite_node, | ||||
|                                       registration) != kTfLiteOk) { | ||||
|     return InvalidArgumentError(absl::StrCat( | ||||
|         "Couldn't get node and registration info for op: ", node_id)); | ||||
|   } | ||||
|   return OkStatus(); | ||||
| } | ||||
| 
 | ||||
| using IsNodeSupportedFn = | ||||
|     std::function<Status(TfLiteContext*, TfLiteNode*, TfLiteRegistration*)>; | ||||
| 
 | ||||
| // A utility class to help with model graph partitioning and to decide which | ||||
| // partition(s) to offload to the GPU. | ||||
| // TODO(b/151152967): move the following to lite/delegates/utils | ||||
| class GraphPartitionHelper { | ||||
|  public: | ||||
|   bool MayAddNode(const TfLiteNode& node, int32_t op_code, int node_id, | ||||
|                   const TfLiteTensor* tensors) { | ||||
|     if (op_code == kTfLiteBuiltinDequantize && | ||||
|         tensors[node.inputs->data[0]].type == TfLiteType::kTfLiteFloat16) { | ||||
|       dequant_nodes_[node.outputs->data[0]] = {node_id, node.inputs->data[0]}; | ||||
|       input_tensors_.insert(node.inputs->data[0]); | ||||
|       return true; | ||||
|     } | ||||
|     return false; | ||||
|   GraphPartitionHelper(TfLiteContext* context, | ||||
|                        IsNodeSupportedFn is_node_supported_fn) | ||||
|       : is_node_supported_fn_(is_node_supported_fn), context_(context) {} | ||||
| 
 | ||||
|   virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); } | ||||
| 
 | ||||
|   // Partitions the graph into multiple subgraphs, each of which is in | ||||
|   // dependency order with the others. | ||||
|   virtual Status Partition(std::set<std::string>* unsupported_nodes_info) { | ||||
|     RETURN_IF_ERROR(PrepareSupportedNodes(unsupported_nodes_info)); | ||||
| 
 | ||||
|     TfLiteDelegateParams* partition_params_array_ = nullptr; | ||||
|     int num_partitions_ = 0; | ||||
|     if (context_->PreviewDelegatePartitioning(context_, supported_nodes_, | ||||
|                                               &partition_params_array_, | ||||
|                                               &num_partitions_) != kTfLiteOk) { | ||||
|       return InvalidArgumentError("Unable to preview delegate partition."); | ||||
|     } | ||||
| 
 | ||||
|   // Remaps inputs of 'node' to the inputs of the preceding dequant node, if | ||||
|   // there is one. | ||||
|   // inputs_from_dequant: records the node id of a dequantize node whose | ||||
|   // output tensor is one of the input tensors of this 'node'. | ||||
|   // orig_inputs: records the original input tensor ids of this node if any | ||||
|   // input is remapped. | ||||
|   void MayRemapInputTensors(TfLiteNode* node, | ||||
|                             std::vector<int>* inputs_from_dequant, | ||||
|                             std::vector<int>* orig_inputs) const { | ||||
|     inputs_from_dequant->clear(); | ||||
|     orig_inputs->clear(); | ||||
|     if (dequant_nodes_.empty()) return; | ||||
| 
 | ||||
|     TfLiteIntArray* inputs = node->inputs; | ||||
|     orig_inputs->reserve(inputs->size); | ||||
|     // Fix the node's inputs (i.e. prune out the preceding dequantize node)
 | ||||
|     // in order to test if it is supported on the GPU.
 | ||||
|     for (int j = 0; j < inputs->size; ++j) { | ||||
|       const int input_tid = inputs->data[j]; | ||||
|       orig_inputs->push_back(input_tid); | ||||
|       const auto it = dequant_nodes_.find(input_tid); | ||||
|       if (it != dequant_nodes_.end()) { | ||||
|         inputs_from_dequant->push_back(it->second.node_id); | ||||
|         // Remap inputs of this node to the inputs of the preceding dequant.
 | ||||
|         inputs->data[j] = it->second.input_tensor_id; | ||||
|       } | ||||
|     } | ||||
|     for (int i = 0; i < num_partitions_; ++i) { | ||||
|       partitions_.push_back(partition_params_array_ + i); | ||||
|     } | ||||
| 
 | ||||
|   // May restore inputs of 'node' to 'orig_inputs' if there are inputs from | ||||
|   // dequant nodes (i.e. denoted by 'inputs_from_dequant'). We will also mark | ||||
|   // such dequantize nodes to be preserved. | ||||
|   void MayRestoreInputTensors(TfLiteNode* node, | ||||
|                               const std::vector<int>& inputs_from_dequant, | ||||
|                               const std::vector<int>& orig_inputs) { | ||||
|     if (inputs_from_dequant.empty()) return; | ||||
| 
 | ||||
|     for (int j = 0; j < node->inputs->size; ++j) { | ||||
|       node->inputs->data[j] = orig_inputs[j]; | ||||
|     } | ||||
|     // Mark those dequantize nodes to be preserved in the graph. | ||||
|     dequant_nodes_to_save_.insert(dequant_nodes_to_save_.end(), | ||||
|                                   inputs_from_dequant.begin(), | ||||
|                                   inputs_from_dequant.end()); | ||||
|     return OkStatus(); | ||||
|   } | ||||
| 
 | ||||
|   void RemovePreservedNodesFrom(std::vector<int>* nodes) const { | ||||
|     for (const int nid : dequant_nodes_to_save_) { | ||||
|       auto it = std::find(nodes->begin(), nodes->end(), nid); | ||||
|       if (it != nodes->end()) nodes->erase(it); | ||||
|   // Returns the first n largest partitions or all if #partitions is less than
 | ||||
|   // 'n'. Note that partitions are ranked according to the number of nodes that
 | ||||
|   // a partition has, and the returned TfLiteDelegateParams objects are *owned*
 | ||||
|   // by the TfLite runtime.
 | ||||
|   std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions(int n) { | ||||
|     const int total = num_partitions(); | ||||
|     // We only sort partitions according to their sizes if necessary.
 | ||||
|     if (n < total) { | ||||
|       partitions_.sort(CompareTwoPartitions); | ||||
|     } | ||||
|     std::vector<TfLiteDelegateParams*> results; | ||||
|     auto p_it = partitions_.begin(); | ||||
|     for (int i = 0; i < std::min(total, n); ++i, ++p_it) { | ||||
|       results.push_back(*p_it); | ||||
|     } | ||||
|     return results; | ||||
|   } | ||||
| 
 | ||||
|   int num_total_nodes() const { return num_total_nodes_; } | ||||
|   int num_partitions() const { return partitions_.size(); } | ||||
| 
 | ||||
|  private: | ||||
|   static bool CompareTwoPartitions(TfLiteDelegateParams* left, | ||||
|                                    TfLiteDelegateParams* right) { | ||||
|     // Reverse sort
 | ||||
|     return left->nodes_to_replace->size > right->nodes_to_replace->size; | ||||
|   } | ||||
| 
 | ||||
|   Status PrepareSupportedNodes( | ||||
|       std::set<std::string>* unsupported_nodes_info = nullptr) { | ||||
|     TfLiteIntArray* execution_plan = nullptr; | ||||
|     if (context_->GetExecutionPlan(context_, &execution_plan) != kTfLiteOk) { | ||||
|       return InvalidArgumentError("Unable to get graph execution plan."); | ||||
|     } | ||||
| 
 | ||||
|     num_total_nodes_ = execution_plan->size; | ||||
|     supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_); | ||||
|     supported_nodes_->size = 0; | ||||
|     for (int node_id : TfLiteIntArrayView(execution_plan)) { | ||||
|       TfLiteNode* node; | ||||
|       TfLiteRegistration* registration; | ||||
|       auto status = | ||||
|           GetNodeAndRegistration(context_, node_id, &node, ®istration); | ||||
|       if (!status.ok()) { | ||||
|         supported_nodes_->size = 0; | ||||
|         return status; | ||||
|       } | ||||
| 
 | ||||
|       status = IsNodeSupported(context_, node, registration, node_id); | ||||
|       if (status.ok()) { | ||||
|         supported_nodes_->data[supported_nodes_->size++] = node_id; | ||||
|       } else if (unsupported_nodes_info) { | ||||
|         unsupported_nodes_info->insert( | ||||
|             absl::StrCat(GetOpNameByRegistration(*registration), ": ", | ||||
|                          status.error_message())); | ||||
|       } | ||||
|     } | ||||
|     return OkStatus(); | ||||
|   } | ||||
| 
 | ||||
|   // The total number of nodes passed in for partitioning (i.e. the | ||||
|   // execution_plan size). | ||||
|   int num_total_nodes_ = 0; | ||||
| 
 | ||||
|   // Tells whether a node is replaceable.
 | ||||
|   const IsNodeSupportedFn is_node_supported_fn_; | ||||
|   TfLiteIntArray* supported_nodes_;  // owns the memory
 | ||||
| 
 | ||||
|  protected: | ||||
|   virtual Status IsNodeSupported(TfLiteContext* context, TfLiteNode* node, | ||||
|                                  TfLiteRegistration* registration, | ||||
|                                  int node_id) { | ||||
|     return is_node_supported_fn_(context, node, registration); | ||||
|   } | ||||
| 
 | ||||
|   TfLiteContext* const context_ = nullptr; | ||||
| 
 | ||||
|   // Doesn't own the memory of each TfLiteDelegateParams object as it's
 | ||||
|   // managed by the TfLite runtime itself. See
 | ||||
|   // TfLiteContext::PreviewDelegatePartitioning for details.
 | ||||
|   std::list<TfLiteDelegateParams*> partitions_; | ||||
| }; | ||||
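For orientation, a minimal usage sketch of the helper above, assuming an illustrative accept-all predicate (the real delegate supplies its GPU capability check instead):

    // Usage sketch; 'accept_all' is an illustrative predicate, not the real
    // GPU support check.
    IsNodeSupportedFn accept_all = [](TfLiteContext*, TfLiteNode*,
                                      TfLiteRegistration*) { return OkStatus(); };
    GraphPartitionHelper helper(context, accept_all);
    std::set<std::string> unsupported_nodes_info;
    RETURN_IF_ERROR(helper.Partition(&unsupported_nodes_info));
    // Partitions are ranked by node count; take the single largest one.
    std::vector<TfLiteDelegateParams*> largest =
        helper.GetFirstNLargestPartitions(/*n=*/1);

Note that the returned TfLiteDelegateParams pointers stay owned by the TfLite runtime, so the caller only reads them.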
| 
 | ||||
| class GraphWithDequantPartitionHelper : public GraphPartitionHelper { | ||||
|  public: | ||||
|   GraphWithDequantPartitionHelper(TfLiteContext* context, | ||||
|                                   IsNodeSupportedFn is_node_supported_fn) | ||||
|       : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} | ||||
| 
 | ||||
|   Status Partition(std::set<std::string>* unsupported_nodes_info) override { | ||||
|     auto status = GraphPartitionHelper::Partition(unsupported_nodes_info); | ||||
|     // Clean up those partitions that have a single dequant op. Note that | ||||
|     // those removed dequant ops have to be preserved in the graph and should | ||||
|     // not be delegated. | ||||
|     RemoveSingleDequantNodePartitions(); | ||||
|     return status; | ||||
|   } | ||||
| 
 | ||||
|   // Returns a list of node indices of all nodes from the first n largest | ||||
|   // partitions. If there are fewer partitions than n, all nodes will be | ||||
|   // returned. Partitions are ranked according to their number of nodes. | ||||
|   std::vector<int> GetNodesOfFirstNLargestPartitions(int n) { | ||||
|     // We first get partitions to reduce the number of nodes to be checked in | ||||
|     // deciding which dequant ops could actually be replaced. Then we remap | ||||
|     // each node's input tensors to the dequant nodes' inputs and remove those | ||||
|     // to-be-preserved dequant nodes. | ||||
|     auto first_nps = GetFirstNLargestPartitions(n); | ||||
|     std::vector<int> ops_to_replace; | ||||
|     for (const auto p : first_nps) { | ||||
|       auto nodes = p->nodes_to_replace; | ||||
|       ops_to_replace.insert(ops_to_replace.end(), nodes->data, | ||||
|                             nodes->data + nodes->size); | ||||
|     } | ||||
|     RemapInputTensors(ops_to_replace); | ||||
|     RemoveReservedDequantsFromNodes(&ops_to_replace); | ||||
|     return ops_to_replace; | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   Status IsNodeSupported(TfLiteContext* context, TfLiteNode* node, | ||||
|                          TfLiteRegistration* registration, | ||||
|                          int node_id) override { | ||||
|     // If we need to handle dequant nodes, we have to remap input tensors of
 | ||||
|     // this node if some of them come from a dequant node before testing if
 | ||||
|     // the node is supported.
 | ||||
|     std::vector<int> orig_inputs; | ||||
|     if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node, | ||||
|                                    &orig_inputs)) { | ||||
|       // We have a dequant op here. Note that we return an Ok status because a | ||||
|       // dequant node is first added as supported. Later, this dequant node | ||||
|       // will be removed if it has to be preserved in the graph, which happens | ||||
|       // when its immediate downstream nodes cannot be supported. | ||||
|       return OkStatus(); | ||||
|     } | ||||
|     const auto status = GraphPartitionHelper::IsNodeSupported( | ||||
|         context, node, registration, node_id); | ||||
|     RestoreToOrigInputTensors(node, orig_inputs); | ||||
|     return status; | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   struct NodeInfo { | ||||
|     int node_id; | ||||
|     int input_tensor_id; | ||||
|   }; | ||||
| 
 | ||||
|   // A map recording dequantize nodes of this graph. A dequantize node is
 | ||||
|   // identified by the output tensor id. The value is a NodeInfo.
 | ||||
|   std::unordered_map<int, NodeInfo> dequant_nodes_; | ||||
| 
 | ||||
|   // A set of input tensor ids of dequantize nodes.
 | ||||
|   std::set<int> input_tensors_; | ||||
| 
 | ||||
|   // The node ids of dequantize nodes that have to be preserved in the graph. | ||||
|   std::vector<int> dequant_nodes_to_save_; | ||||
| }; | ||||
| }  // namespace
 | ||||
| 
 | ||||
| Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, | ||||
|                                       TensorRef<BHWC>* tensor_ref) { | ||||
|   tensor_ref->type = ToDataType(tflite_tensor.type); | ||||
|   return ExtractTensorShape(tflite_tensor, &tensor_ref->shape); | ||||
|   // Record 'node' if it is a dequant op (i.e. an fp16 one here) and return | ||||
|   // true. When it's not a dequant op, remap its inputs to the inputs of the | ||||
|   // preceding dequant node if there is one, and return false. 'orig_inputs' | ||||
|   // records the original input tensor ids of this node if any input is | ||||
|   // remapped. | ||||
|   bool RecordAndRemapInputTensors(int32_t op_code, int node_id, | ||||
|                                   TfLiteNode* node, | ||||
|                                   std::vector<int>* orig_inputs) { | ||||
|     orig_inputs->clear(); | ||||
|     // Record the dequant node.
 | ||||
|     if (op_code == kTfLiteBuiltinDequantize && | ||||
|         context_->tensors[node->inputs->data[0]].type == | ||||
|             TfLiteType::kTfLiteFloat16) { | ||||
|       dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0]; | ||||
|       return true; | ||||
|     } | ||||
|     // If no dequant nodes have been recorded so far, there is nothing to | ||||
|     // remap for this node's inputs. | ||||
|     if (dequant_nodes_.empty()) return false; | ||||
|     RemapInputTensors(node, orig_inputs); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   // Restore inputs of 'node' to 'orig_inputs' only if the two sizes match. | ||||
|   void RestoreToOrigInputTensors(TfLiteNode* node, | ||||
|                                  const std::vector<int>& orig_inputs) { | ||||
|     if (node->inputs->size != orig_inputs.size()) return; | ||||
|     for (int j = 0; j < node->inputs->size; ++j) { | ||||
|       node->inputs->data[j] = orig_inputs[j]; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Remap input tensors of every node in 'nodes' (i.e. node indices) if some of
 | ||||
|   // them are from dequant ops.
 | ||||
|   void RemapInputTensors(const std::vector<int>& nodes) const { | ||||
|     for (int node_id : nodes) { | ||||
|       TfLiteNode* node; | ||||
|       TfLiteRegistration* registration; | ||||
|       GetNodeAndRegistration(context_, node_id, &node, ®istration) | ||||
|           .IgnoreError(); | ||||
|       RemapInputTensors(node, nullptr /* orig_inputs*/); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   void RemoveSingleDequantNodePartitions() { | ||||
|     auto it = partitions_.begin(); | ||||
|     while (it != partitions_.end()) { | ||||
|       auto p = *it; | ||||
|       if (p->nodes_to_replace->size != 1) { | ||||
|         ++it; | ||||
|         continue; | ||||
|       } | ||||
|       int node_id = p->nodes_to_replace->data[0]; | ||||
|       TfLiteNode* node = nullptr; | ||||
|       TfLiteRegistration* registration = nullptr; | ||||
|       GetNodeAndRegistration(context_, node_id, &node, ®istration) | ||||
|           .IgnoreError(); | ||||
|       if (registration->builtin_code != kTfLiteBuiltinDequantize) { | ||||
|         ++it; | ||||
|         continue; | ||||
|       } | ||||
|       // Note such dequant nodes have to be preserved in the graph as dequant
 | ||||
|       // ops are not actually supported in the GPU delegate.
 | ||||
|       dequant_nodes_to_save_.insert(node_id); | ||||
|       it = partitions_.erase(it); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   void RemoveReservedDequantsFromNodes(std::vector<int>* nodes) { | ||||
|     if (dequant_nodes_to_save_.empty()) return; | ||||
|     auto it = nodes->begin(); | ||||
|     while (it != nodes->end()) { | ||||
|       if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) { | ||||
|         ++it; | ||||
|         continue; | ||||
|       } | ||||
|       it = nodes->erase(it); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Remap input tensors of a single 'node' if some of them come from a | ||||
|   // dequant op. If 'orig_inputs' isn't nullptr, it records the original input | ||||
|   // tensor ids of this node if any input is remapped. | ||||
|   void RemapInputTensors(TfLiteNode* node, | ||||
|                          std::vector<int>* orig_inputs) const { | ||||
|     TfLiteIntArray* inputs = node->inputs; | ||||
|     auto inputs_view = TfLiteIntArrayView(inputs); | ||||
|     // Prepopulate 'orig_inputs' first and clear it if there's no input from a
 | ||||
|     // dequant op.
 | ||||
|     if (orig_inputs) { | ||||
|       orig_inputs->clear(); | ||||
|       orig_inputs->reserve(inputs->size); | ||||
|       for (auto tid : inputs_view) { | ||||
|         orig_inputs->push_back(tid); | ||||
|       } | ||||
|     } | ||||
|     // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
 | ||||
|     // order to test if it is supported.
 | ||||
|     bool is_remapped = false; | ||||
|     for (int j = 0; j < inputs->size; ++j) { | ||||
|       const int input_tid = inputs->data[j]; | ||||
|       const auto it = dequant_nodes_.find(input_tid); | ||||
|       if (it != dequant_nodes_.end()) { | ||||
|         inputs->data[j] = it->second; | ||||
|         is_remapped = true; | ||||
|       } | ||||
|     } | ||||
|     if (!is_remapped && orig_inputs) orig_inputs->clear(); | ||||
|   } | ||||
| 
 | ||||
|   // A map recording dequantize nodes' input/output tensors of this selected | ||||
|   // graph. The key is the output tensor id, and the value is the input tensor | ||||
|   // id. | ||||
|   std::unordered_map<int, int> dequant_nodes_; | ||||
| 
 | ||||
|   // A set of dequant nodes (stored as node indices) that have to be preserved | ||||
|   // in the graph. | ||||
|   std::set<int> dequant_nodes_to_save_; | ||||
| }; | ||||
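The dequant handling above amounts to a tensor-id substitution. A toy, self-contained illustration with hypothetical ids (a Dequant producing t1 from fp16 t0, feeding an Add that reads {t1, t2}):

    #include <unordered_map>
    #include <vector>

    int main() {
      // Dequant: t0 (fp16) -> t1 (fp32); Add reads {t1, t2}.
      std::unordered_map<int, int> dequant_nodes = {{1, 0}};  // output -> input
      std::vector<int> add_inputs = {1, 2};
      for (int& tid : add_inputs) {
        auto it = dequant_nodes.find(tid);
        if (it != dequant_nodes.end()) tid = it->second;
      }
      // add_inputs is now {0, 2}: Add consumes the fp16 tensor directly, so
      // the Dequant node can be pruned from the delegated partition.
      return 0;
    }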
| 
 | ||||
| Status IsSupported(const TfLiteContext* context, TfLiteNode* node, | ||||
|                    const TfLiteRegistration* registration) { | ||||
| @ -2856,158 +3073,63 @@ bool IsAllFloatTensors(const TfLiteContext* context, | ||||
|   return true; | ||||
| } | ||||
| 
 | ||||
| std::string GetOpNameByRegistration(const TfLiteRegistration* registration) { | ||||
|   auto op = registration->builtin_code; | ||||
|   std::string result = | ||||
|       EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op)); | ||||
|   if (op == kTfLiteBuiltinCustom) { | ||||
|     result += " " + std::string(registration->custom_name); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| }  // namespace
 | ||||
| 
 | ||||
| Status GetNodeAndRegistration(TfLiteContext* context, int node_id, | ||||
|                               TfLiteNode** tflite_node, | ||||
|                               TfLiteRegistration** registration) { | ||||
|   if (context->GetNodeAndRegistration(context, node_id, tflite_node, | ||||
|                                       registration) != kTfLiteOk) { | ||||
|     return InvalidArgumentError(absl::StrCat( | ||||
|         "Couldn't get node and registration info for op: ", node_id)); | ||||
|   } | ||||
|   return OkStatus(); | ||||
| } | ||||
| 
 | ||||
| TfLiteIntArray* GetOpsToReplaceFromGraphWithDequantize(TfLiteContext* context) { | ||||
|   TfLiteIntArray* execution_plan = nullptr; | ||||
|   if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) { | ||||
|     context->ReportError(context, "Unable to get graph execution plan."); | ||||
|     return nullptr; | ||||
|   } | ||||
|   std::set<std::string> errors; | ||||
|   std::vector<int> ops_to_replace; | ||||
|   DequantizeOpRemover dequant_remover; | ||||
| 
 | ||||
|   for (int i = 0; i < execution_plan->size; ++i) { | ||||
|     const int node_id = execution_plan->data[i]; | ||||
|     TfLiteNode* node = nullptr; | ||||
|     TfLiteRegistration* registration = nullptr; | ||||
|     auto status = | ||||
|         GetNodeAndRegistration(context, node_id, &node, ®istration); | ||||
|     if (!status.ok()) { | ||||
|       context->ReportError(context, status.error_message().c_str()); | ||||
|       return nullptr; | ||||
|     } | ||||
| 
 | ||||
|     if (dequant_remover.MayAddNode(*node, registration->builtin_code, node_id, | ||||
|                                    context->tensors)) { | ||||
|       // For now, add the node to the list of ops to replace.
 | ||||
|       ops_to_replace.push_back(node_id); | ||||
|       continue; | ||||
|     } | ||||
| 
 | ||||
|     // Record the node id of a dequantize node whose output tensor is one of | ||||
|     // the input tensors of this node. | ||||
|     std::vector<int> inputs_from_dequant; | ||||
|     // Original input tensor ids of this node.
 | ||||
|     std::vector<int> orig_inputs; | ||||
|     dequant_remover.MayRemapInputTensors(node, &inputs_from_dequant, | ||||
|                                          &orig_inputs); | ||||
| 
 | ||||
|     status = IsSupported(context, node, registration); | ||||
|     if (status.ok() && | ||||
|         // TODO(eignasheva): resolve sub operation support for metal delegate
 | ||||
|         // registration->builtin_code != kTfLiteBuiltinSub &&
 | ||||
|         IsAllFloatTensors(context, node->inputs) && | ||||
|         IsAllFloatTensors(context, node->outputs) && errors.empty()) { | ||||
|       // Node is supported and there were no previous errors.
 | ||||
|       ops_to_replace.push_back(node_id); | ||||
|     } else { | ||||
|       // The node is not replaceable; record an error message. | ||||
|       errors.insert(absl::StrCat(GetOpNameByRegistration(registration), ": ", | ||||
|                                  status.error_message())); | ||||
|       dequant_remover.MayRestoreInputTensors(node, inputs_from_dequant, | ||||
|                                              orig_inputs); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   if (!errors.empty()) { | ||||
|     std::string unsupported = absl::StrJoin(errors, "\n"); | ||||
|     std::string error_message = | ||||
|         "Next operations are not supported by GPU delegate:\n" + unsupported + | ||||
|         "\nFirst " + std::to_string(ops_to_replace.size()) + | ||||
|         " operations will run on the GPU, and the remaining " + | ||||
|         std::to_string(execution_plan->size - ops_to_replace.size()) + | ||||
|         " on the CPU."; | ||||
|     context->ReportError(context, error_message.c_str()); | ||||
|   } | ||||
| 
 | ||||
|   dequant_remover.RemovePreservedNodesFrom(&ops_to_replace); | ||||
|   return ConvertVectorToTfLiteIntArray(ops_to_replace); | ||||
| Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, | ||||
|                                       TensorRef<BHWC>* tensor_ref) { | ||||
|   tensor_ref->type = ToDataType(tflite_tensor.type); | ||||
|   return ExtractTensorShape(tflite_tensor, &tensor_ref->shape); | ||||
| } | ||||
| 
 | ||||
| // TODO(impjdi): Check number of input/output tensors and their dimensions.
 | ||||
| // TODO(impjdi): Check ops' parameters.
 | ||||
| TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { | ||||
|   TfLiteIntArray* execution_plan = nullptr; | ||||
|   if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) { | ||||
|     context->ReportError(context, "Unable to get graph execution plan."); | ||||
|   IsNodeSupportedFn node_supported_fn = | ||||
|       [=](TfLiteContext* context, TfLiteNode* node, | ||||
|           TfLiteRegistration* registration) -> Status { | ||||
|     RETURN_IF_ERROR(IsSupported(context, node, registration)); | ||||
|     return (IsAllFloatTensors(context, node->inputs) && | ||||
|             IsAllFloatTensors(context, node->outputs)) | ||||
|                ? OkStatus() | ||||
|                : FailedPreconditionError( | ||||
|                      "OP is supported, but tensor type isn't matched!"); | ||||
|   }; | ||||
| 
 | ||||
|   GraphWithDequantPartitionHelper partition_helper(context, node_supported_fn); | ||||
|   std::set<std::string> unsupported_nodes_info; | ||||
|   auto status = partition_helper.Partition(&unsupported_nodes_info); | ||||
|   if (!status.ok()) { | ||||
|     TF_LITE_KERNEL_LOG(context, status.error_message().c_str()); | ||||
|     return nullptr; | ||||
|   } | ||||
| 
 | ||||
|   // Dispatch to another function if the graph has Dequantize nodes. | ||||
|   for (int i = 0; i < execution_plan->size; ++i) { | ||||
|     const int node_id = execution_plan->data[i]; | ||||
|     TfLiteNode* node = nullptr; | ||||
|     TfLiteRegistration* registration = nullptr; | ||||
|     auto status = | ||||
|         GetNodeAndRegistration(context, node_id, &node, ®istration); | ||||
|     if (!status.ok()) { | ||||
|       context->ReportError(context, status.error_message().c_str()); | ||||
|       return nullptr; | ||||
|     } | ||||
|     if (registration->builtin_code == kTfLiteBuiltinDequantize && | ||||
|         context->tensors[node->inputs->data[0]].type == | ||||
|             TfLiteType::kTfLiteFloat16) { | ||||
|       return GetOpsToReplaceFromGraphWithDequantize(context); | ||||
|     } | ||||
|   } | ||||
|   // We simply get the single largest partition, but we could later explore | ||||
|   // whether getting more partitions could lead to better performance, e.g. by | ||||
|   // parameterizing '1' here. | ||||
|   std::vector<int> ops_to_replace = | ||||
|       partition_helper.GetNodesOfFirstNLargestPartitions(1); | ||||
| 
 | ||||
|   // No Dequantize nodes. Iterate through graph and find ops to replace.
 | ||||
|   TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size); | ||||
|   subgraph->size = 0; | ||||
|   std::set<std::string> errors; | ||||
|   for (int i = 0; i < execution_plan->size; ++i) { | ||||
|     const int node_id = execution_plan->data[i]; | ||||
|     TfLiteNode* node; | ||||
|     TfLiteRegistration* registration; | ||||
|     auto status = | ||||
|         GetNodeAndRegistration(context, node_id, &node, ®istration); | ||||
|     if (!status.ok()) { | ||||
|       context->ReportError(context, status.error_message().c_str()); | ||||
|       return nullptr; | ||||
|     } | ||||
|     status = IsSupported(context, node, registration); | ||||
|     if (status.ok() && | ||||
|         // TODO(eignasheva): resolve sub operation support for metal delegate
 | ||||
|         // registration->builtin_code != kTfLiteBuiltinSub &&
 | ||||
|         IsAllFloatTensors(context, node->inputs) && | ||||
|         IsAllFloatTensors(context, node->outputs)) { | ||||
|       if (errors.empty()) subgraph->data[subgraph->size++] = node_id; | ||||
|   if (!unsupported_nodes_info.empty()) { | ||||
|     std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n"); | ||||
|     std::string error_message = absl::StrCat( | ||||
|         "Following operations are not supported by GPU delegate:\n", | ||||
|         unsupported, "\n"); | ||||
|     if (!ops_to_replace.empty()) { | ||||
|       absl::StrAppendFormat( | ||||
|           &error_message, | ||||
|           "%d operations will run on the GPU (first node: " | ||||
|           "%d, last node: %d), and the remaining %d", | ||||
|           ops_to_replace.size(), ops_to_replace.front(), ops_to_replace.back(), | ||||
|           partition_helper.num_total_nodes() - ops_to_replace.size()); | ||||
|     } else { | ||||
|       errors.insert(absl::StrCat(GetOpNameByRegistration(registration), ": ", | ||||
|                                  status.error_message())); | ||||
|       absl::StrAppend(&error_message, | ||||
|                       "No operations will run on the GPU, and all ", | ||||
|                       partition_helper.num_total_nodes()); | ||||
|     } | ||||
|     absl::StrAppend(&error_message, " operations will run on the CPU."); | ||||
|     TF_LITE_KERNEL_LOG(context, error_message.c_str()); | ||||
|   } | ||||
|   if (!errors.empty()) { | ||||
|     std::string unsupported = absl::StrJoin(errors, "\n"); | ||||
|     std::string error_message = | ||||
|         "Next operations are not supported by GPU delegate:\n" + unsupported + | ||||
|         "\nFirst " + std::to_string(subgraph->size) + | ||||
|         " operations will run on the GPU, and the remaining " + | ||||
|         std::to_string(execution_plan->size - subgraph->size) + " on the CPU."; | ||||
|     context->ReportError(context, error_message.c_str()); | ||||
|   } | ||||
|   return subgraph; | ||||
|   return ConvertVectorToTfLiteIntArray(ops_to_replace); | ||||
| } | ||||
| 
 | ||||
| Status BuildModel(TfLiteContext* context, | ||||
| @ -3047,7 +3169,7 @@ Status BuildModel(TfLiteContext* context, | ||||
|     const auto status = | ||||
|         operations[i]->Parse(tflite_node, registration, graph, &reader); | ||||
|     if (!status.ok()) { | ||||
|       return InternalError(absl::StrCat(GetOpNameByRegistration(registration), | ||||
|       return InternalError(absl::StrCat(GetOpNameByRegistration(*registration), | ||||
|                                         ": ", status.error_message())); | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -120,9 +120,52 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) { | ||||
|   EXPECT_FALSE(status.ok()); | ||||
| } | ||||
| 
 | ||||
| class InterpreterFp16 { | ||||
| class DelegatedInterpreter { | ||||
|  public: | ||||
|   explicit InterpreterFp16(TfLiteBuiltinOperator op) { | ||||
|   explicit DelegatedInterpreter(int num_nodes) { | ||||
|     exec_plan_ = TfLiteIntArrayCreate(num_nodes); | ||||
|   } | ||||
|   virtual ~DelegatedInterpreter() { | ||||
|     TfLiteIntArrayFree(exec_plan_); | ||||
|     for (auto params : delegate_params_) { | ||||
|       TfLiteIntArrayFree(params.nodes_to_replace); | ||||
|       TfLiteIntArrayFree(params.input_tensors); | ||||
|       TfLiteIntArrayFree(params.output_tensors); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Get the TfLiteContext to be mocked for swapping out functions that have | ||||
|   // to be called inside a delegate (i.e. in delegate kernel mode). | ||||
|   TfLiteContext* context() { return interpreter_.primary_subgraph().context(); } | ||||
| 
 | ||||
|   std::vector<std::pair<TfLiteNode, TfLiteRegistration>>& | ||||
|   nodes_and_registration() { | ||||
|     return interpreter_.primary_subgraph().nodes_and_registration(); | ||||
|   } | ||||
| 
 | ||||
|   TfLiteIntArray* exec_plan() const { return exec_plan_; } | ||||
|   TfLiteDelegateParams* add_delegate_params() { | ||||
|     delegate_params_.push_back(TfLiteDelegateParams()); | ||||
|     return &delegate_params_.back(); | ||||
|   } | ||||
|   TfLiteDelegateParams* delegate_params() { return &delegate_params_.front(); } | ||||
|   int num_delegate_params() { return delegate_params_.size(); } | ||||
| 
 | ||||
|  protected: | ||||
|   Interpreter interpreter_; | ||||
| 
 | ||||
|  private: | ||||
|   // The manually-set execution plan for this delegated interpreter.
 | ||||
|   TfLiteIntArray* exec_plan_; | ||||
| 
 | ||||
|   // The TfLiteDelegateParams objects that are manually populated inside the | ||||
|   // mocked TfLiteContext::PreviewDelegatePartitioning. | ||||
|   std::vector<TfLiteDelegateParams> delegate_params_; | ||||
| }; | ||||
| 
 | ||||
| class InterpreterFp16 : public DelegatedInterpreter { | ||||
|  public: | ||||
|   explicit InterpreterFp16(TfLiteBuiltinOperator op) : DelegatedInterpreter(3) { | ||||
|     void* builtin_data = malloc(sizeof(int)); | ||||
|     EXPECT_EQ(interpreter_.AddTensors(5), kTfLiteOk); | ||||
|     EXPECT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk); | ||||
| @ -187,22 +230,23 @@ class InterpreterFp16 { | ||||
|             3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false), | ||||
|         kTfLiteOk); | ||||
| 
 | ||||
|     exec_plan_ = TfLiteIntArrayCreate(3); | ||||
|     exec_plan_->data[0] = 0; | ||||
|     exec_plan_->data[1] = 1; | ||||
|     exec_plan_->data[2] = 2; | ||||
|     exec_plan()->data[0] = 0; | ||||
|     exec_plan()->data[1] = 1; | ||||
|     exec_plan()->data[2] = 2; | ||||
|   } | ||||
| 
 | ||||
|   ~InterpreterFp16() { TfLiteIntArrayFree(exec_plan_); } | ||||
| 
 | ||||
|   Subgraph* GetSubgraph() { return interpreter_.subgraph(0); } | ||||
|   TfLiteIntArray* exec_plan() const { return exec_plan_; } | ||||
| 
 | ||||
|  private: | ||||
|   Interpreter interpreter_; | ||||
|   TfLiteIntArray* exec_plan_; | ||||
| }; | ||||
| 
 | ||||
| // **NOTE**: we have several interpreter instances created at global scope to | ||||
| // test *exactly* the GetOpsToReplace function alone, and not the sequence of | ||||
| // function calls that includes GetOpsToReplace when calling | ||||
| // ModifyGraphWithDelegate. A TfLiteContext is needed to test GetOpsToReplace, | ||||
| // but TfLiteContexts intentionally make it difficult to call certain functions | ||||
| // in a non-delegate context (see tensorflow/lite/subgraph/subgraph.cc for | ||||
| // details). We create our own GetExecutionPlan, GetNodeAndRegistration and | ||||
| // PreviewDelegatePartitioning lambdas inside each test, but we can't use local | ||||
| // captures without changing the function signature. Therefore, this test data | ||||
| // lives at global scope in order to be accessible inside the lambdas. | ||||
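The capture restriction mentioned above is the standard C++ rule that only capture-less lambdas convert to plain function pointers, which is what the TfLiteContext hooks are. A small illustration, with 'g_plan' and 'InstallExecutionPlanStub' as hypothetical names:

    #include "tensorflow/lite/c/common.h"

    TfLiteIntArray* g_plan = nullptr;  // hypothetical global test data

    void InstallExecutionPlanStub(TfLiteContext* context) {
      // OK: a capture-less lambda converts to the C function pointer type
      // TfLiteStatus (*)(TfLiteContext*, TfLiteIntArray**).
      context->GetExecutionPlan = [](TfLiteContext*, TfLiteIntArray** plan) {
        *plan = g_plan;  // reachable only because it lives at global scope
        return kTfLiteOk;
      };
      // A capturing lambda such as [&](...) { ... } has no conversion to a
      // plain function pointer and would not compile here.
    }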
| 
 | ||||
| InterpreterFp16* interpreter_fp16_add_op = | ||||
|     new InterpreterFp16(kTfLiteBuiltinAdd); | ||||
| 
 | ||||
| @ -218,7 +262,8 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { | ||||
|   //   t0 (FP16) --> Add -> t4
 | ||||
|   //   t2 (FP16) --/
 | ||||
|   //
 | ||||
|   TfLiteContext* context = interpreter_fp16_add_op->GetSubgraph()->context(); | ||||
|   TfLiteContext* context = interpreter_fp16_add_op->context(); | ||||
| 
 | ||||
|   // These functions are meant to be called inside delegates. Swap out
 | ||||
|   // for similar functions to permit direct calling of GetOpsToReplace.
 | ||||
|   context->GetExecutionPlan = [](struct TfLiteContext* context, | ||||
| @ -229,12 +274,30 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { | ||||
|   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index, | ||||
|                                        TfLiteNode** node, | ||||
|                                        TfLiteRegistration** registration) { | ||||
|     auto& node_and_reg = interpreter_fp16_add_op->GetSubgraph() | ||||
|                              ->nodes_and_registration()[node_index]; | ||||
|     auto& node_and_reg = | ||||
|         interpreter_fp16_add_op->nodes_and_registration()[node_index]; | ||||
|     *node = &node_and_reg.first; | ||||
|     *registration = &node_and_reg.second; | ||||
|     return kTfLiteOk; | ||||
|   }; | ||||
|   context->PreviewDelegatePartitioning = | ||||
|       [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, | ||||
|          TfLiteDelegateParams** partition_params_array, int* num_partitions) { | ||||
|         auto params = interpreter_fp16_add_op->add_delegate_params(); | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(3); | ||||
|         params->nodes_to_replace->data[0] = 0; | ||||
|         params->nodes_to_replace->data[1] = 1; | ||||
|         params->nodes_to_replace->data[2] = 2; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(2); | ||||
|         params->input_tensors->data[0] = 0; | ||||
|         params->input_tensors->data[1] = 2; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 4; | ||||
| 
 | ||||
|         *partition_params_array = interpreter_fp16_add_op->delegate_params(); | ||||
|         *num_partitions = interpreter_fp16_add_op->num_delegate_params(); | ||||
|         return kTfLiteOk; | ||||
|       }; | ||||
| 
 | ||||
|   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); | ||||
| 
 | ||||
| @ -251,17 +314,6 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { | ||||
|   TfLiteIntArrayFree(ops_to_replace); | ||||
| } | ||||
| 
 | ||||
| // This interpreter instance is created at global scope to test *exactly*
 | ||||
| // the GetOpsToReplace function alone, and not the sequence of function calls
 | ||||
| // that includes GetOpsToReplace when calling ModifyGraphWithDelegate.
 | ||||
| // A TfLiteContext is needed to test GetOpsToReplace, but TfLiteContexts
 | ||||
| // intentionally make it difficult to call certain functions in a
 | ||||
| // non-delegate context (see tensorflow/lite/subgraph/subgraph.cc for | ||||
| // details). We create our own GetExecutionPlan and GetNodeAndRegistration | ||||
| // lambdas inside each test, but we can't use local captures without changing | ||||
| // the function signature. Therefore, this test data lives at global scope | ||||
| // in order to be accessible inside the lambda. | ||||
| 
 | ||||
| InterpreterFp16* interpreter_fp16_gt_op = | ||||
|     new InterpreterFp16(kTfLiteBuiltinGreater); | ||||
| 
 | ||||
| @ -274,7 +326,7 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) { | ||||
|   // Because there is no GPU equivalent for the Greater op, we don't prune
 | ||||
|   // the Dequantize nodes.
 | ||||
| 
 | ||||
|   TfLiteContext* context = interpreter_fp16_gt_op->GetSubgraph()->context(); | ||||
|   TfLiteContext* context = interpreter_fp16_gt_op->context(); | ||||
|   // These functions are meant to be called inside delegates. Swap out
 | ||||
|   // for similar functions to permit direct calling of GetOpsToReplace.
 | ||||
|   context->GetExecutionPlan = [](struct TfLiteContext* context, | ||||
| @ -285,12 +337,37 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) { | ||||
|   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index, | ||||
|                                        TfLiteNode** node, | ||||
|                                        TfLiteRegistration** registration) { | ||||
|     auto& node_and_reg = interpreter_fp16_gt_op->GetSubgraph() | ||||
|                              ->nodes_and_registration()[node_index]; | ||||
|     auto& node_and_reg = | ||||
|         interpreter_fp16_gt_op->nodes_and_registration()[node_index]; | ||||
|     *node = &node_and_reg.first; | ||||
|     *registration = &node_and_reg.second; | ||||
|     return kTfLiteOk; | ||||
|   }; | ||||
|   context->PreviewDelegatePartitioning = | ||||
|       [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, | ||||
|          TfLiteDelegateParams** partition_params_array, int* num_partitions) { | ||||
|         auto params = interpreter_fp16_gt_op->add_delegate_params(); | ||||
|         // First partition for DequantNode (t0->t1)
 | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(1); | ||||
|         params->nodes_to_replace->data[0] = 0; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->input_tensors->data[0] = 0; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 1; | ||||
| 
 | ||||
|         // Second partition for DequantNode (t2->t3) | ||||
|         params = interpreter_fp16_gt_op->add_delegate_params(); | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(1); | ||||
|         params->nodes_to_replace->data[0] = 1; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->input_tensors->data[0] = 2; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 3; | ||||
| 
 | ||||
|         *partition_params_array = interpreter_fp16_gt_op->delegate_params(); | ||||
|         *num_partitions = interpreter_fp16_gt_op->num_delegate_params(); | ||||
|         return kTfLiteOk; | ||||
|       }; | ||||
| 
 | ||||
|   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); | ||||
| 
 | ||||
| @ -309,9 +386,9 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) { | ||||
|   TfLiteIntArrayFree(ops_to_replace); | ||||
| } | ||||
| 
 | ||||
| class InterpreterFp32 { | ||||
| class InterpreterFp32 : public DelegatedInterpreter { | ||||
|  public: | ||||
|   InterpreterFp32() { | ||||
|   InterpreterFp32() : DelegatedInterpreter(2) { | ||||
|     void* builtin_data = malloc(sizeof(int)); | ||||
|     EXPECT_EQ(interpreter_.AddTensors(4), kTfLiteOk); | ||||
|     EXPECT_EQ(interpreter_.SetInputs({0, 2}), kTfLiteOk); | ||||
| @ -363,34 +440,24 @@ class InterpreterFp32 { | ||||
|         interpreter_.SetTensorParametersReadWrite( | ||||
|             2, TfLiteType::kTfLiteFloat32, "t2", dims, quantization, false), | ||||
|         kTfLiteOk); | ||||
|     exec_plan_ = TfLiteIntArrayCreate(2); | ||||
|     exec_plan_->data[0] = 0; | ||||
|     exec_plan_->data[1] = 1; | ||||
| 
 | ||||
|     exec_plan()->data[0] = 0; | ||||
|     exec_plan()->data[1] = 1; | ||||
|   } | ||||
| 
 | ||||
|   ~InterpreterFp32() { TfLiteIntArrayFree(exec_plan_); } | ||||
| 
 | ||||
|   Subgraph* GetSubgraph() { return interpreter_.subgraph(0); } | ||||
|   TfLiteIntArray* exec_plan() const { return exec_plan_; } | ||||
| 
 | ||||
|  private: | ||||
|   Interpreter interpreter_; | ||||
|   TfLiteIntArray* exec_plan_; | ||||
| }; | ||||
| 
 | ||||
| InterpreterFp32* interpreter_fp32 = new InterpreterFp32(); | ||||
| 
 | ||||
| TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) { | ||||
|   // A graph with a Dequant node with uint8 input
 | ||||
|   // is not pruned. The delegate will attempt to replace it
 | ||||
|   // with a GPU op, but this op is currently not supported on
 | ||||
|   // the GPU. Therefore, the Dequant op and all downstream ops
 | ||||
|   // will be scheduled to run on the CPU.
 | ||||
|   // A graph with a Dequant node with uint8 input is not pruned, as this op | ||||
|   // is currently not supported on the GPU. Therefore, the Dequant op will be | ||||
|   // scheduled to run on the CPU while the remaining supported Add op runs on | ||||
|   // the GPU. | ||||
|   //
 | ||||
|   //   t0 (uint8) --> Dequant --> t1 (FP32) --> Add -> t3
 | ||||
|   //                              t2 (FP32) --/
 | ||||
|   //
 | ||||
|   TfLiteContext* context = interpreter_fp32->GetSubgraph()->context(); | ||||
|   TfLiteContext* context = interpreter_fp32->context(); | ||||
| 
 | ||||
|   // These functions are meant to be called inside delegates. Swap out
 | ||||
|   // for similar functions to permit direct calling of GetOpsToReplace.
 | ||||
| @ -402,24 +469,43 @@ TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) { | ||||
|   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index, | ||||
|                                        TfLiteNode** node, | ||||
|                                        TfLiteRegistration** registration) { | ||||
|     auto& node_and_reg = | ||||
|         interpreter_fp32->GetSubgraph()->nodes_and_registration()[node_index]; | ||||
|     auto& node_and_reg = interpreter_fp32->nodes_and_registration()[node_index]; | ||||
|     *node = &node_and_reg.first; | ||||
|     *registration = &node_and_reg.second; | ||||
|     return kTfLiteOk; | ||||
|   }; | ||||
|   context->PreviewDelegatePartitioning = | ||||
|       [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, | ||||
|          TfLiteDelegateParams** partition_params_array, int* num_partitions) { | ||||
|         auto params = interpreter_fp32->add_delegate_params(); | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(1); | ||||
|         params->nodes_to_replace->data[0] = 1; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(2); | ||||
|         params->input_tensors->data[0] = 1; | ||||
|         params->input_tensors->data[1] = 2; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 3; | ||||
| 
 | ||||
|         *partition_params_array = interpreter_fp32->delegate_params(); | ||||
|         *num_partitions = interpreter_fp32->num_delegate_params(); | ||||
|         return kTfLiteOk; | ||||
|       }; | ||||
| 
 | ||||
|   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); | ||||
| 
 | ||||
|   // No ops are run on the GPU, since the Dequant op is not pruned and must run
 | ||||
|   // on the CPU.
 | ||||
|   EXPECT_EQ(ops_to_replace->size, 0); | ||||
|   // As the Dequant op is not pruned and the ADD op can run on the GPU, we | ||||
|   // have 1 partition. | ||||
|   EXPECT_EQ(ops_to_replace->size, 1); | ||||
|   // ADD at index 1.
 | ||||
|   EXPECT_EQ(1, ops_to_replace->data[0]); | ||||
| 
 | ||||
|   TfLiteIntArrayFree(ops_to_replace); | ||||
| } | ||||
| 
 | ||||
| class InterpreterMultiNode { | ||||
| class InterpreterMultiNode : public DelegatedInterpreter { | ||||
|  public: | ||||
|   explicit InterpreterMultiNode(bool add_op_first = true) { | ||||
|   explicit InterpreterMultiNode(bool add_op_first = true) | ||||
|       : DelegatedInterpreter(5) { | ||||
|     void* builtin_data = malloc(sizeof(int)); | ||||
|     EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk); | ||||
|     EXPECT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk); | ||||
| @ -552,38 +638,29 @@ class InterpreterMultiNode { | ||||
|         interpreter_.SetTensorParametersReadWrite( | ||||
|             7, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false), | ||||
|         kTfLiteOk); | ||||
|     exec_plan_ = TfLiteIntArrayCreate(5); | ||||
|     exec_plan_->data[0] = 0; | ||||
|     exec_plan_->data[1] = 1; | ||||
|     exec_plan_->data[2] = 2; | ||||
|     exec_plan_->data[3] = 3; | ||||
|     exec_plan_->data[4] = 4; | ||||
| 
 | ||||
|     exec_plan()->data[0] = 0; | ||||
|     exec_plan()->data[1] = 1; | ||||
|     exec_plan()->data[2] = 2; | ||||
|     exec_plan()->data[3] = 3; | ||||
|     exec_plan()->data[4] = 4; | ||||
|   } | ||||
| 
 | ||||
|   ~InterpreterMultiNode() { TfLiteIntArrayFree(exec_plan_); } | ||||
| 
 | ||||
|   Subgraph* GetSubgraph() { return interpreter_.subgraph(0); } | ||||
|   TfLiteIntArray* exec_plan() const { return exec_plan_; } | ||||
| 
 | ||||
|  private: | ||||
|   Interpreter interpreter_; | ||||
|   TfLiteIntArray* exec_plan_; | ||||
| }; | ||||
| 
 | ||||
| InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode(); | ||||
| 
 | ||||
| TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequants) { | ||||
| TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) { | ||||
|   // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
 | ||||
|   // 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
 | ||||
|   //   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
 | ||||
|   //   t1 (FP16) --> Dequant --> t4 (FP32) --/
 | ||||
|   //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(4) -> t6
 | ||||
|   //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
 | ||||
|   //                                          --\ | ||||
|   //   t2 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7 | ||||
|   //   t2 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(3) -> t7 | ||||
|   //
 | ||||
|   //  OpsToReplace should replace the 'Add' op and the Dequant outputting | ||||
|   //  t5, but leave the other Dequant nodes because 'Greater' must run | ||||
|   //  on the CPU. | ||||
|   TfLiteContext* context = interpreter_mn->GetSubgraph()->context(); | ||||
|   TfLiteContext* context = interpreter_mn->context(); | ||||
| 
 | ||||
|   // These functions are meant to be called inside delegates. Swap out
 | ||||
|   // for similar functions to permit direct calling of GetOpsToReplace.
 | ||||
| @ -595,12 +672,48 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequants) { | ||||
|   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index, | ||||
|                                        TfLiteNode** node, | ||||
|                                        TfLiteRegistration** registration) { | ||||
|     auto& node_and_reg = | ||||
|         interpreter_mn->GetSubgraph()->nodes_and_registration()[node_index]; | ||||
|     auto& node_and_reg = interpreter_mn->nodes_and_registration()[node_index]; | ||||
|     *node = &node_and_reg.first; | ||||
|     *registration = &node_and_reg.second; | ||||
|     return kTfLiteOk; | ||||
|   }; | ||||
|   context->PreviewDelegatePartitioning = | ||||
|       [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, | ||||
|          TfLiteDelegateParams** partition_params_array, int* num_partitions) { | ||||
|         auto params = interpreter_mn->add_delegate_params(); | ||||
|         // First partition for DequantNode(0)
 | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(1); | ||||
|         params->nodes_to_replace->data[0] = 0; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->input_tensors->data[0] = 0; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 3; | ||||
| 
 | ||||
|         // Second partition for DequantNode(1)
 | ||||
|         params = interpreter_mn->add_delegate_params(); | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(1); | ||||
|         params->nodes_to_replace->data[0] = 1; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->input_tensors->data[0] = 1; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 4; | ||||
| 
 | ||||
|         // Third partition for DequantNode(1), DequantNode(2), ADD(3)
 | ||||
|         params = interpreter_mn->add_delegate_params(); | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(3); | ||||
|         params->nodes_to_replace->data[0] = 1; | ||||
|         params->nodes_to_replace->data[1] = 2; | ||||
|         params->nodes_to_replace->data[2] = 3; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(2); | ||||
|         params->input_tensors->data[0] = 1; | ||||
|         params->input_tensors->data[1] = 2; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 7; | ||||
| 
 | ||||
|         *partition_params_array = interpreter_mn->delegate_params(); | ||||
|         *num_partitions = interpreter_mn->num_delegate_params(); | ||||
|         return kTfLiteOk; | ||||
|       }; | ||||
| 
 | ||||
|   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); | ||||
| 
 | ||||
| @ -624,19 +737,22 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequants) { | ||||
| 
 | ||||
| InterpreterMultiNode* interpreter_mn2 = | ||||
|     new InterpreterMultiNode(/*add_op_first=*/false); | ||||
| 
 | ||||
| TEST(ModelBuilderTest, GetOpsToReplaceRestoresInputsOnErrors) { | ||||
|   // A graph with three Dequant nodes feeding two ops, 'Greater' and 'Add'.
 | ||||
| TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) { | ||||
|   // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
 | ||||
|   // 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
 | ||||
|   //   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
 | ||||
|   //   t1 (FP16) --> Dequant --> t4 (FP32) --/
 | ||||
|   //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
 | ||||
|   //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
 | ||||
|   //                                          --\ | ||||
|   //   t2 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7 | ||||
|   //   t2 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7 | ||||
|   //
 | ||||
|   // 'Greater' comes first in the execution plan though, so Add should not | ||||
|   // be scheduled to run on the GPU. Further, its inputs should remain t4 | ||||
|   // and t5. | ||||
|   TfLiteContext* context = interpreter_mn->GetSubgraph()->context(); | ||||
|   // Note: the graph dependency is exactly the same as in | ||||
|   // GetOpsToReplaceSelectsCorrectDequantsAddFirst, but the unsupported | ||||
|   // 'Greater' op appears first in the execution plan. Despite this, | ||||
|   // OpsToReplace should still replace the 'Add' op and the Dequant outputting | ||||
|   // t5, but leave the other Dequant nodes because 'Greater' must run | ||||
|   // on the CPU. | ||||
| 
 | ||||
|   TfLiteContext* context = interpreter_mn2->context(); | ||||
| 
 | ||||
|   // These functions are meant to be called inside delegates. Swap out
 | ||||
|   // for similar functions to permit direct calling of GetOpsToReplace.
 | ||||
| @ -648,27 +764,67 @@ TEST(ModelBuilderTest, GetOpsToReplaceRestoresInputsOnErrors) { | ||||
|   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index, | ||||
|                                        TfLiteNode** node, | ||||
|                                        TfLiteRegistration** registration) { | ||||
|     auto& node_and_reg = | ||||
|         interpreter_mn2->GetSubgraph()->nodes_and_registration()[node_index]; | ||||
|     auto& node_and_reg = interpreter_mn2->nodes_and_registration()[node_index]; | ||||
|     *node = &node_and_reg.first; | ||||
|     *registration = &node_and_reg.second; | ||||
|     return kTfLiteOk; | ||||
|   }; | ||||
| 
 | ||||
|   context->PreviewDelegatePartitioning = | ||||
|       [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, | ||||
|          TfLiteDelegateParams** partition_params_array, int* num_partitions) { | ||||
|         auto params = interpreter_mn2->add_delegate_params(); | ||||
|         // First partition for DequantNode(0)
 | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(1); | ||||
|         params->nodes_to_replace->data[0] = 0; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->input_tensors->data[0] = 0; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 3; | ||||
| 
 | ||||
|         // Second partition for DequantNode(1)
 | ||||
|         params = interpreter_mn2->add_delegate_params(); | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(1); | ||||
|         params->nodes_to_replace->data[0] = 1; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->input_tensors->data[0] = 1; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 4; | ||||
| 
 | ||||
|         // Third partition for DequantNode(1), DequantNode(2), ADD(4)
 | ||||
|         params = interpreter_mn2->add_delegate_params(); | ||||
|         params->nodes_to_replace = TfLiteIntArrayCreate(3); | ||||
|         params->nodes_to_replace->data[0] = 1; | ||||
|         params->nodes_to_replace->data[1] = 2; | ||||
|         params->nodes_to_replace->data[2] = 4; | ||||
|         params->input_tensors = TfLiteIntArrayCreate(2); | ||||
|         params->input_tensors->data[0] = 1; | ||||
|         params->input_tensors->data[1] = 2; | ||||
|         params->output_tensors = TfLiteIntArrayCreate(1); | ||||
|         params->output_tensors->data[0] = 7; | ||||
| 
 | ||||
|         *partition_params_array = interpreter_mn2->delegate_params(); | ||||
|         *num_partitions = interpreter_mn2->num_delegate_params(); | ||||
|         return kTfLiteOk; | ||||
|       }; | ||||
| 
 | ||||
|   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); | ||||
| 
 | ||||
|   // Verify that no ops will be replaced.
 | ||||
|   EXPECT_EQ(ops_to_replace->size, 0); | ||||
|   EXPECT_EQ(ops_to_replace->size, 2); | ||||
|   // Op at index 2 is the Dequant op (t2 -> t5). | ||||
|   EXPECT_EQ(ops_to_replace->data[0], 2); | ||||
|   // Op at index 4 is the Add op.
 | ||||
|   EXPECT_EQ(ops_to_replace->data[1], 4); | ||||
| 
 | ||||
|   TfLiteNode* node = nullptr; | ||||
|   TfLiteRegistration* registration = nullptr; | ||||
|   // Verify that Add op has fp32 inputs.
 | ||||
|   context->GetNodeAndRegistration(context, 4, &node, ®istration); | ||||
|   EXPECT_EQ(registration->builtin_code, kTfLiteBuiltinAdd); | ||||
|   // Verify that Add op has fp16 inputs.
 | ||||
|   context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node, | ||||
|                                   ®istration); | ||||
|   EXPECT_EQ(context->tensors[node->inputs->data[0]].type, | ||||
|             TfLiteType::kTfLiteFloat32); | ||||
|             TfLiteType::kTfLiteFloat16); | ||||
|   EXPECT_EQ(context->tensors[node->inputs->data[1]].type, | ||||
|             TfLiteType::kTfLiteFloat32); | ||||
|             TfLiteType::kTfLiteFloat16); | ||||
|   TfLiteIntArrayFree(ops_to_replace); | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -17,6 +17,7 @@ limitations under the License. | ||||
| #include <complex> | ||||
| #include <cstring> | ||||
| 
 | ||||
| #include "tensorflow/lite/builtin_ops.h" | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| 
 | ||||
| @ -131,4 +132,15 @@ bool IsUnresolvedCustomOp(const TfLiteRegistration& registration) { | ||||
|          registration.invoke == &UnresolvedOpInvoke; | ||||
| } | ||||
| 
 | ||||
| std::string GetOpNameByRegistration(const TfLiteRegistration& registration) { | ||||
|   auto op = registration.builtin_code; | ||||
|   std::string result = | ||||
|       EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op)); | ||||
|   if ((op == kTfLiteBuiltinCustom || op == kTfLiteBuiltinDelegate) && | ||||
|       registration.custom_name) { | ||||
|     result += " " + std::string(registration.custom_name); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| }  // namespace tflite
 | ||||
|  | ||||
| @ -21,6 +21,7 @@ limitations under the License. | ||||
| #ifndef TENSORFLOW_LITE_UTIL_H_ | ||||
| #define TENSORFLOW_LITE_UTIL_H_ | ||||
| 
 | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| @ -73,6 +74,8 @@ TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name); | ||||
| // Checks whether the provided op is an unresolved custom op.
 | ||||
| bool IsUnresolvedCustomOp(const TfLiteRegistration& registration); | ||||
| 
 | ||||
| // Returns a descriptive name for the given op TfLiteRegistration. | ||||
| std::string GetOpNameByRegistration(const TfLiteRegistration& registration); | ||||
| }  // namespace tflite
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_LITE_UTIL_H_
 | ||||
|  | ||||
| @ -20,6 +20,7 @@ limitations under the License. | ||||
| #include <gmock/gmock.h> | ||||
| #include <gtest/gtest.h> | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| namespace { | ||||
| @ -90,6 +91,32 @@ TEST(CombineHashes, TestHashOutputsDifferent) { | ||||
|   EXPECT_NE(output1, output2); | ||||
| } | ||||
| 
 | ||||
| TEST(GetOpNameByRegistration, ValidBuiltinCode) { | ||||
|   TfLiteRegistration registration; | ||||
|   registration.builtin_code = tflite::BuiltinOperator_ADD; | ||||
|   const auto op_name = GetOpNameByRegistration(registration); | ||||
|   EXPECT_EQ("ADD", op_name); | ||||
| } | ||||
| 
 | ||||
| TEST(GetOpNameByRegistration, InvalidBuiltinCode) { | ||||
|   TfLiteRegistration registration; | ||||
|   registration.builtin_code = -1; | ||||
|   const auto op_name = GetOpNameByRegistration(registration); | ||||
|   EXPECT_EQ("", op_name); | ||||
| } | ||||
| 
 | ||||
| TEST(GetOpNameByRegistration, CustomName) { | ||||
|   TfLiteRegistration registration; | ||||
|   registration.builtin_code = tflite::BuiltinOperator_CUSTOM; | ||||
|   registration.custom_name = "TestOp"; | ||||
|   auto op_name = GetOpNameByRegistration(registration); | ||||
|   EXPECT_EQ("CUSTOM TestOp", op_name); | ||||
| 
 | ||||
|   registration.builtin_code = tflite::BuiltinOperator_DELEGATE; | ||||
|   registration.custom_name = "TestDelegate"; | ||||
|   op_name = GetOpNameByRegistration(registration); | ||||
|   EXPECT_EQ("DELEGATE TestDelegate", op_name); | ||||
| } | ||||
| }  // namespace
 | ||||
| }  // namespace tflite
 | ||||
| 
 | ||||
|  | ||||