Replace usage of GetFirstNLargestPartitions
with GetNodesOfFirstNLargestPartitions
This removes common logic to iterate through partitions, and flattening supported nodes into one vector. PiperOrigin-RevId: 313746666 Change-Id: I703bea87cac0ea0ffe25d5a8e11e052465e15f34
This commit is contained in:
parent
b2cc6c66a8
commit
4d5f0144c7
@ -4188,8 +4188,7 @@ TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
|
|||||||
auto* delegate_data = static_cast<Data*>(delegate->data_);
|
auto* delegate_data = static_cast<Data*>(delegate->data_);
|
||||||
// The first entry in the array is the element count
|
// The first entry in the array is the element count
|
||||||
|
|
||||||
auto supported_nodes_int_array =
|
auto supported_nodes_int_array = BuildTfLiteIntArray(supported_nodes);
|
||||||
delegates::BuildTfLiteIntArray(supported_nodes);
|
|
||||||
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
|
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
|
||||||
context, supported_nodes_int_array.get(), params_array, num_partitions));
|
context, supported_nodes_int_array.get(), params_array, num_partitions));
|
||||||
// For each partition check if which nodes are actually supported by the
|
// For each partition check if which nodes are actually supported by the
|
||||||
@ -4222,7 +4221,7 @@ TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
|
|||||||
// We changed the set of nodes to delegate this will create a different
|
// We changed the set of nodes to delegate this will create a different
|
||||||
// partitioning layout.
|
// partitioning layout.
|
||||||
auto device_sup_nodes_int_array =
|
auto device_sup_nodes_int_array =
|
||||||
delegates::BuildTfLiteIntArray(*device_supported_nodes);
|
BuildTfLiteIntArray(*device_supported_nodes);
|
||||||
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
|
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
|
||||||
context, device_sup_nodes_int_array.get(), params_array,
|
context, device_sup_nodes_int_array.get(), params_array,
|
||||||
num_partitions));
|
num_partitions));
|
||||||
@ -4419,8 +4418,7 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
|
|||||||
&num_partitions, ¶ms_array, nnapi_errno));
|
&num_partitions, ¶ms_array, nnapi_errno));
|
||||||
} else {
|
} else {
|
||||||
nodes_to_delegate = supported_nodes;
|
nodes_to_delegate = supported_nodes;
|
||||||
auto supported_nodes_int_array =
|
auto supported_nodes_int_array = BuildTfLiteIntArray(supported_nodes);
|
||||||
delegates::BuildTfLiteIntArray(supported_nodes);
|
|
||||||
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
|
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
|
||||||
context, supported_nodes_int_array.get(), ¶ms_array,
|
context, supported_nodes_int_array.get(), ¶ms_array,
|
||||||
&num_partitions));
|
&num_partitions));
|
||||||
@ -4437,8 +4435,7 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
|
|||||||
} else {
|
} else {
|
||||||
// Request TFLite to partition the graph and make kernels
|
// Request TFLite to partition the graph and make kernels
|
||||||
// for each independent node sub set a new nnapi_delegate_kernel.
|
// for each independent node sub set a new nnapi_delegate_kernel.
|
||||||
auto nodes_to_delegate_int_array =
|
auto nodes_to_delegate_int_array = BuildTfLiteIntArray(nodes_to_delegate);
|
||||||
delegates::BuildTfLiteIntArray(nodes_to_delegate);
|
|
||||||
return context->ReplaceNodeSubsetsWithDelegateKernels(
|
return context->ReplaceNodeSubsetsWithDelegateKernels(
|
||||||
context, nnapi_delegate_kernel, nodes_to_delegate_int_array.get(),
|
context, nnapi_delegate_kernel, nodes_to_delegate_int_array.get(),
|
||||||
delegate);
|
delegate);
|
||||||
|
@ -46,14 +46,6 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
|
|
||||||
const std::vector<int>& data) {
|
|
||||||
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
|
|
||||||
TfLiteIntArrayCreate(data.size()));
|
|
||||||
std::copy(data.begin(), data.end(), result->data);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteStatus GraphPartitionHelper::Partition(
|
TfLiteStatus GraphPartitionHelper::Partition(
|
||||||
std::set<std::string>* unsupported_nodes_info) {
|
std::set<std::string>* unsupported_nodes_info) {
|
||||||
const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
|
const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
|
||||||
@ -103,6 +95,19 @@ GraphPartitionHelper::GetFirstNLargestPartitions(
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int> GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
||||||
|
int n, int min_nodes_per_partition) {
|
||||||
|
auto first_n_partitions =
|
||||||
|
GetFirstNLargestPartitions(n, min_nodes_per_partition);
|
||||||
|
std::vector<int> ops_to_replace;
|
||||||
|
for (const auto p : first_n_partitions) {
|
||||||
|
auto nodes = p->nodes_to_replace;
|
||||||
|
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
|
||||||
|
nodes->data + nodes->size);
|
||||||
|
}
|
||||||
|
return ops_to_replace;
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
|
TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
|
||||||
std::set<std::string>* unsupported_nodes_info) {
|
std::set<std::string>* unsupported_nodes_info) {
|
||||||
if (!is_node_supported_fn_) return kTfLiteOk;
|
if (!is_node_supported_fn_) return kTfLiteOk;
|
||||||
@ -155,23 +160,12 @@ TfLiteStatus FP16GraphPartitionHelper::Partition(
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitions(
|
std::vector<int>
|
||||||
int n, int min_nodes_per_partition,
|
FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
||||||
std::vector<TfLiteDelegateParams*>* partitions) {
|
int n, int min_nodes_per_partition) {
|
||||||
// We first get partitions to reduce the number of nodes to be checked in
|
std::vector<int> ops_to_replace =
|
||||||
// deciding which dequant ops could actually be replaced. And then we
|
GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
||||||
// remap input-tensor to dequant nodes' inputs and remove those
|
n, min_nodes_per_partition);
|
||||||
// to-be-reserved dequant nodes.
|
|
||||||
auto first_nps = GetFirstNLargestPartitions(n, min_nodes_per_partition);
|
|
||||||
if (partitions != nullptr) {
|
|
||||||
*partitions = first_nps;
|
|
||||||
}
|
|
||||||
std::vector<int> ops_to_replace;
|
|
||||||
for (const auto p : first_nps) {
|
|
||||||
auto nodes = p->nodes_to_replace;
|
|
||||||
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
|
|
||||||
nodes->data + nodes->size);
|
|
||||||
}
|
|
||||||
RemapInputTensors(ops_to_replace);
|
RemapInputTensors(ops_to_replace);
|
||||||
RemoveReservedDequantsFromNodes(&ops_to_replace);
|
RemoveReservedDequantsFromNodes(&ops_to_replace);
|
||||||
return ops_to_replace;
|
return ops_to_replace;
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@ -41,9 +40,6 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
|
|||||||
TfLiteTensor** new_tensor,
|
TfLiteTensor** new_tensor,
|
||||||
int* new_tensor_index);
|
int* new_tensor_index);
|
||||||
|
|
||||||
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
|
|
||||||
const std::vector<int>& data);
|
|
||||||
|
|
||||||
using IsNodeSupportedFn =
|
using IsNodeSupportedFn =
|
||||||
std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*,
|
std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*,
|
||||||
std::string* unsupported_details)>;
|
std::string* unsupported_details)>;
|
||||||
@ -76,10 +72,21 @@ class GraphPartitionHelper {
|
|||||||
// Note that partitions are ranked according to the number of nodes that
|
// Note that partitions are ranked according to the number of nodes that
|
||||||
// a partition has, and the returned TfLiteDelegateParams objects are *owned*
|
// a partition has, and the returned TfLiteDelegateParams objects are *owned*
|
||||||
// by the TfLite runtime.
|
// by the TfLite runtime.
|
||||||
|
// TODO(b/156707497): remove this and use GetNodesOfFirstNLargestPartitions
|
||||||
std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions(
|
std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions(
|
||||||
int n = std::numeric_limits<int>::max(),
|
int n = std::numeric_limits<int>::max(),
|
||||||
int min_nodes_per_partition = 0) const;
|
int min_nodes_per_partition = 0) const;
|
||||||
|
|
||||||
|
// Returns a list of node indices of all nodes from the first n largest
|
||||||
|
// partitions. If there are fewer paritions than n, all nodes will be
|
||||||
|
// returned. The partition is ranked according to the number of nodes.
|
||||||
|
std::vector<int> GetNodesOfFirstNLargestPartitions(
|
||||||
|
int n = std::numeric_limits<int>::max(),
|
||||||
|
int min_nodes_per_partition = 0) {
|
||||||
|
// Separated implementation that can be overrided, to preserve default value
|
||||||
|
return GetNodesOfFirstNLargestPartitionsImpl(n, min_nodes_per_partition);
|
||||||
|
}
|
||||||
|
|
||||||
int num_total_nodes() const { return num_total_nodes_; }
|
int num_total_nodes() const { return num_total_nodes_; }
|
||||||
int num_partitions() const { return partitions_.size(); }
|
int num_partitions() const { return partitions_.size(); }
|
||||||
|
|
||||||
@ -90,6 +97,8 @@ class GraphPartitionHelper {
|
|||||||
return is_node_supported_fn_(context, node, registration,
|
return is_node_supported_fn_(context, node, registration,
|
||||||
unsupported_details);
|
unsupported_details);
|
||||||
}
|
}
|
||||||
|
virtual std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
|
||||||
|
int n, int min_nodes_per_partition);
|
||||||
|
|
||||||
TfLiteContext* const context_ = nullptr;
|
TfLiteContext* const context_ = nullptr;
|
||||||
|
|
||||||
@ -121,9 +130,6 @@ class GraphPartitionHelper {
|
|||||||
// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in
|
// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in
|
||||||
// addition to supported nodes for the delegate, when the DEQUANTIZE node's
|
// addition to supported nodes for the delegate, when the DEQUANTIZE node's
|
||||||
// output is an input to the kernel that supports FP16 input.
|
// output is an input to the kernel that supports FP16 input.
|
||||||
// Noth that you have to use `GetNodesOfFirstNLargestPartitions` instead of
|
|
||||||
// superclass' `GetFirstNLargestPartitions` to do actual remapping of FP16
|
|
||||||
// inputs.
|
|
||||||
class FP16GraphPartitionHelper : public GraphPartitionHelper {
|
class FP16GraphPartitionHelper : public GraphPartitionHelper {
|
||||||
public:
|
public:
|
||||||
FP16GraphPartitionHelper(TfLiteContext* context,
|
FP16GraphPartitionHelper(TfLiteContext* context,
|
||||||
@ -133,20 +139,15 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper {
|
|||||||
TfLiteStatus Partition(
|
TfLiteStatus Partition(
|
||||||
std::set<std::string>* unsupported_nodes_info) override;
|
std::set<std::string>* unsupported_nodes_info) override;
|
||||||
|
|
||||||
// Returns a list of node indices of all nodes from the first n largest
|
|
||||||
// partitions. If there are fewer paritions than n, all nodes will be
|
|
||||||
// returned. The partition is ranked according to the number of nodes.
|
|
||||||
// TODO(b/156707497): Add this to superclass besides
|
|
||||||
// GetFirstNLargestPartitions (one that returns partitions instead of nodes)
|
|
||||||
std::vector<int> GetNodesOfFirstNLargestPartitions(
|
|
||||||
int n, int min_nodes_per_partition = 0,
|
|
||||||
std::vector<TfLiteDelegateParams*>* partitions = nullptr);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
|
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
|
||||||
TfLiteRegistration* registration, int node_id,
|
TfLiteRegistration* registration, int node_id,
|
||||||
std::string* unsupported_details) override;
|
std::string* unsupported_details) override;
|
||||||
|
|
||||||
|
// This will remap input tensors by removing FP16 to FP32 dequantized tensors.
|
||||||
|
std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
|
||||||
|
int n, int min_nodes_per_partition) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true.
|
// Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true.
|
||||||
// When it's not a dequant op, remap its inputs to the inputs of the preceding
|
// When it's not a dequant op, remap its inputs to the inputs of the preceding
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/builtin_ops.h"
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/context_util.h"
|
#include "tensorflow/lite/context_util.h"
|
||||||
#include "tensorflow/lite/delegates/utils.h"
|
#include "tensorflow/lite/delegates/utils.h"
|
||||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||||
@ -86,31 +87,19 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context,
|
|||||||
delegates::GraphPartitionHelper helper(context, node_supported_fn);
|
delegates::GraphPartitionHelper helper(context, node_supported_fn);
|
||||||
TF_LITE_ENSURE_STATUS(helper.Partition(nullptr));
|
TF_LITE_ENSURE_STATUS(helper.Partition(nullptr));
|
||||||
|
|
||||||
const auto delegate_partitions = helper.GetFirstNLargestPartitions();
|
std::vector<int> supported_nodes = helper.GetNodesOfFirstNLargestPartitions();
|
||||||
|
|
||||||
// To avoid creating a new TfLiteIntArray and free it later, we reserve one
|
|
||||||
// element to represent TfLiteIntArray.size which is the 1st element of
|
|
||||||
// TfLiteIntArray C struct.
|
|
||||||
std::vector<int> supported_nodes(1);
|
|
||||||
for (const auto partition : delegate_partitions) {
|
|
||||||
auto* nodes = partition->nodes_to_replace;
|
|
||||||
supported_nodes.insert(supported_nodes.end(), nodes->data,
|
|
||||||
nodes->data + nodes->size);
|
|
||||||
}
|
|
||||||
// Set first element to the number of nodes to replace.
|
|
||||||
supported_nodes[0] = supported_nodes.size() - 1;
|
|
||||||
|
|
||||||
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
|
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
|
||||||
"%s delegate: %d nodes delegated out of %d nodes with "
|
"%s delegate: %d nodes delegated out of %d nodes with "
|
||||||
"%d partitions.\n",
|
"%d partitions.\n",
|
||||||
delegate->name(), supported_nodes[0],
|
delegate->name(), supported_nodes.size(),
|
||||||
helper.num_total_nodes(), delegate_partitions.size());
|
helper.num_total_nodes(), helper.num_partitions());
|
||||||
TfLiteRegistration delegate_kernel_registration =
|
TfLiteRegistration delegate_kernel_registration =
|
||||||
GetDelegateKernelRegistration(delegate);
|
GetDelegateKernelRegistration(delegate);
|
||||||
|
|
||||||
return context->ReplaceNodeSubsetsWithDelegateKernels(
|
return context->ReplaceNodeSubsetsWithDelegateKernels(
|
||||||
context, delegate_kernel_registration,
|
context, delegate_kernel_registration,
|
||||||
reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()), base_delegate);
|
BuildTfLiteIntArray(supported_nodes).get(), base_delegate);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -233,15 +233,15 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
|||||||
delegates::FP16GraphPartitionHelper partition_helper(context, node_supported_fn);
|
delegates::FP16GraphPartitionHelper partition_helper(context, node_supported_fn);
|
||||||
TF_LITE_ENSURE_STATUS(partition_helper.Partition(nullptr));
|
TF_LITE_ENSURE_STATUS(partition_helper.Partition(nullptr));
|
||||||
|
|
||||||
std::vector<TfLiteDelegateParams*> partitions;
|
|
||||||
std::vector<int> delegated_nodes = partition_helper.GetNodesOfFirstNLargestPartitions(
|
std::vector<int> delegated_nodes = partition_helper.GetNodesOfFirstNLargestPartitions(
|
||||||
params->max_delegated_partitions, params->min_nodes_per_partition, &partitions);
|
params->max_delegated_partitions, params->min_nodes_per_partition);
|
||||||
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
|
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
|
||||||
"CoreML delegate: %d nodes delegated out of %d nodes, "
|
"CoreML delegate: %d nodes delegated out of %d nodes, "
|
||||||
"with %d partitions.\n",
|
"with %d partitions.\n",
|
||||||
delegated_nodes.size(), partition_helper.num_total_nodes(), partitions.size());
|
delegated_nodes.size(), partition_helper.num_total_nodes(),
|
||||||
|
partition_helper.num_partitions());
|
||||||
return context->ReplaceNodeSubsetsWithDelegateKernels(
|
return context->ReplaceNodeSubsetsWithDelegateKernels(
|
||||||
context, GetCoreMlKernelRegistration(), delegates::BuildTfLiteIntArray(delegated_nodes).get(),
|
context, GetCoreMlKernelRegistration(), BuildTfLiteIntArray(delegated_nodes).get(),
|
||||||
delegate);
|
delegate);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,23 +157,12 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
|||||||
|
|
||||||
TfLiteHexagonDelegateOptions* params =
|
TfLiteHexagonDelegateOptions* params =
|
||||||
static_cast<TfLiteHexagonDelegateOptions*>(delegate->data_);
|
static_cast<TfLiteHexagonDelegateOptions*>(delegate->data_);
|
||||||
const auto delegate_partitions = helper.GetFirstNLargestPartitions(
|
std::vector<int> supported_nodes = helper.GetNodesOfFirstNLargestPartitions(
|
||||||
params->max_delegated_partitions, params->min_nodes_per_partition);
|
params->max_delegated_partitions, params->min_nodes_per_partition);
|
||||||
|
|
||||||
// To avoid creating a new TfLiteIntArray and free it later, we reserve one
|
|
||||||
// element to represent TfLiteIntArray.size which is the 1st element of
|
|
||||||
// TfLiteIntArray C struct.
|
|
||||||
std::vector<int> supported_nodes(1);
|
|
||||||
for (const auto partition : delegate_partitions) {
|
|
||||||
auto* nodes = partition->nodes_to_replace;
|
|
||||||
supported_nodes.insert(supported_nodes.end(), nodes->data,
|
|
||||||
nodes->data + nodes->size);
|
|
||||||
}
|
|
||||||
// Set first element to the number of nodes to replace.
|
|
||||||
supported_nodes[0] = supported_nodes.size() - 1;
|
|
||||||
auto* hexagon_delegate = static_cast<HexagonDelegate*>(delegate);
|
auto* hexagon_delegate = static_cast<HexagonDelegate*>(delegate);
|
||||||
// Make sure dynamic batch is requested on fully delegated graph only.
|
// Make sure dynamic batch is requested on fully delegated graph only.
|
||||||
if (supported_nodes[0] != helper.num_total_nodes() &&
|
if (supported_nodes.size() != helper.num_total_nodes() &&
|
||||||
hexagon_delegate != nullptr &&
|
hexagon_delegate != nullptr &&
|
||||||
hexagon_delegate->params()->enable_dynamic_batch_size) {
|
hexagon_delegate->params()->enable_dynamic_batch_size) {
|
||||||
TF_LITE_KERNEL_LOG(
|
TF_LITE_KERNEL_LOG(
|
||||||
@ -183,12 +172,12 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
|||||||
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
|
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
|
||||||
"Hexagon delegate: %d nodes delegated out of %d nodes with "
|
"Hexagon delegate: %d nodes delegated out of %d nodes with "
|
||||||
"%d partitions.\n",
|
"%d partitions.\n",
|
||||||
supported_nodes[0], helper.num_total_nodes(),
|
supported_nodes.size(), helper.num_total_nodes(),
|
||||||
delegate_partitions.size());
|
helper.num_partitions());
|
||||||
|
|
||||||
return context->ReplaceNodeSubsetsWithDelegateKernels(
|
return context->ReplaceNodeSubsetsWithDelegateKernels(
|
||||||
context, GetHexagonKernelRegistration(),
|
context, GetHexagonKernelRegistration(),
|
||||||
reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()), delegate);
|
BuildTfLiteIntArray(supported_nodes).get(), delegate);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteDelegate* CreateDelegate(const TfLiteHexagonDelegateOptions* params) {
|
TfLiteDelegate* CreateDelegate(const TfLiteHexagonDelegateOptions* params) {
|
||||||
|
@ -38,6 +38,14 @@ bool IsFlexOp(const char* custom_name) {
|
|||||||
strlen(kFlexCustomCodePrefix)) == 0;
|
strlen(kFlexCustomCodePrefix)) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
|
||||||
|
const std::vector<int>& data) {
|
||||||
|
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
|
||||||
|
TfLiteIntArrayCreate(data.size()));
|
||||||
|
std::copy(data.begin(), data.end(), result->data);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
|
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
|
||||||
return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()),
|
return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()),
|
||||||
input.data());
|
input.data());
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_UTIL_H_
|
#ifndef TENSORFLOW_LITE_UTIL_H_
|
||||||
#define TENSORFLOW_LITE_UTIL_H_
|
#define TENSORFLOW_LITE_UTIL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -60,6 +61,11 @@ struct TfLiteIntArrayDeleter {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Helper for Building TfLiteIntArray that is wrapped in a unique_ptr,
|
||||||
|
// So that it is automatically freed when it goes out of the scope.
|
||||||
|
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
|
||||||
|
const std::vector<int>& data);
|
||||||
|
|
||||||
// Populates the size in bytes of a type into `bytes`. Returns kTfLiteOk for
|
// Populates the size in bytes of a type into `bytes`. Returns kTfLiteOk for
|
||||||
// valid types, and kTfLiteError otherwise.
|
// valid types, and kTfLiteError otherwise.
|
||||||
TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
|
TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
|
||||||
|
Loading…
Reference in New Issue
Block a user