Replace usage of GetFirstNLargestPartitions with GetNodesOfFirstNLargestPartitions

This removes duplicated logic for iterating through partitions and flattening the supported nodes into one vector.

PiperOrigin-RevId: 313746666
Change-Id: I703bea87cac0ea0ffe25d5a8e11e052465e15f34
This commit is contained in:
Taehee Jeong 2020-05-29 02:31:43 -07:00 committed by TensorFlower Gardener
parent b2cc6c66a8
commit 4d5f0144c7
8 changed files with 68 additions and 84 deletions

View File

@ -4188,8 +4188,7 @@ TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
auto* delegate_data = static_cast<Data*>(delegate->data_); auto* delegate_data = static_cast<Data*>(delegate->data_);
// The first entry in the array is the element count // The first entry in the array is the element count
auto supported_nodes_int_array = auto supported_nodes_int_array = BuildTfLiteIntArray(supported_nodes);
delegates::BuildTfLiteIntArray(supported_nodes);
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning( TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
context, supported_nodes_int_array.get(), params_array, num_partitions)); context, supported_nodes_int_array.get(), params_array, num_partitions));
// For each partition check if which nodes are actually supported by the // For each partition check if which nodes are actually supported by the
@ -4222,7 +4221,7 @@ TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
// We changed the set of nodes to delegate this will create a different // We changed the set of nodes to delegate this will create a different
// partitioning layout. // partitioning layout.
auto device_sup_nodes_int_array = auto device_sup_nodes_int_array =
delegates::BuildTfLiteIntArray(*device_supported_nodes); BuildTfLiteIntArray(*device_supported_nodes);
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning( TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
context, device_sup_nodes_int_array.get(), params_array, context, device_sup_nodes_int_array.get(), params_array,
num_partitions)); num_partitions));
@ -4419,8 +4418,7 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
&num_partitions, &params_array, nnapi_errno)); &num_partitions, &params_array, nnapi_errno));
} else { } else {
nodes_to_delegate = supported_nodes; nodes_to_delegate = supported_nodes;
auto supported_nodes_int_array = auto supported_nodes_int_array = BuildTfLiteIntArray(supported_nodes);
delegates::BuildTfLiteIntArray(supported_nodes);
TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning( TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
context, supported_nodes_int_array.get(), &params_array, context, supported_nodes_int_array.get(), &params_array,
&num_partitions)); &num_partitions));
@ -4437,8 +4435,7 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
} else { } else {
// Request TFLite to partition the graph and make kernels // Request TFLite to partition the graph and make kernels
// for each independent node sub set a new nnapi_delegate_kernel. // for each independent node sub set a new nnapi_delegate_kernel.
auto nodes_to_delegate_int_array = auto nodes_to_delegate_int_array = BuildTfLiteIntArray(nodes_to_delegate);
delegates::BuildTfLiteIntArray(nodes_to_delegate);
return context->ReplaceNodeSubsetsWithDelegateKernels( return context->ReplaceNodeSubsetsWithDelegateKernels(
context, nnapi_delegate_kernel, nodes_to_delegate_int_array.get(), context, nnapi_delegate_kernel, nodes_to_delegate_int_array.get(),
delegate); delegate);

View File

@ -46,14 +46,6 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
return kTfLiteOk; return kTfLiteOk;
} }
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
const std::vector<int>& data) {
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
TfLiteIntArrayCreate(data.size()));
std::copy(data.begin(), data.end(), result->data);
return result;
}
TfLiteStatus GraphPartitionHelper::Partition( TfLiteStatus GraphPartitionHelper::Partition(
std::set<std::string>* unsupported_nodes_info) { std::set<std::string>* unsupported_nodes_info) {
const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info); const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
@ -103,6 +95,19 @@ GraphPartitionHelper::GetFirstNLargestPartitions(
return results; return results;
} }
std::vector<int> GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
int n, int min_nodes_per_partition) {
auto first_n_partitions =
GetFirstNLargestPartitions(n, min_nodes_per_partition);
std::vector<int> ops_to_replace;
for (const auto p : first_n_partitions) {
auto nodes = p->nodes_to_replace;
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
nodes->data + nodes->size);
}
return ops_to_replace;
}
TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes( TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
std::set<std::string>* unsupported_nodes_info) { std::set<std::string>* unsupported_nodes_info) {
if (!is_node_supported_fn_) return kTfLiteOk; if (!is_node_supported_fn_) return kTfLiteOk;
@ -155,23 +160,12 @@ TfLiteStatus FP16GraphPartitionHelper::Partition(
return status; return status;
} }
std::vector<int> FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitions( std::vector<int>
int n, int min_nodes_per_partition, FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
std::vector<TfLiteDelegateParams*>* partitions) { int n, int min_nodes_per_partition) {
// We first get partitions to reduce the number of nodes to be checked in std::vector<int> ops_to_replace =
// deciding which dequant ops could actually be replaced. And then we GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
// remap input-tensor to dequant nodes' inputs and remove those n, min_nodes_per_partition);
// to-be-reserved dequant nodes.
auto first_nps = GetFirstNLargestPartitions(n, min_nodes_per_partition);
if (partitions != nullptr) {
*partitions = first_nps;
}
std::vector<int> ops_to_replace;
for (const auto p : first_nps) {
auto nodes = p->nodes_to_replace;
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
nodes->data + nodes->size);
}
RemapInputTensors(ops_to_replace); RemapInputTensors(ops_to_replace);
RemoveReservedDequantsFromNodes(&ops_to_replace); RemoveReservedDequantsFromNodes(&ops_to_replace);
return ops_to_replace; return ops_to_replace;

View File

@ -20,7 +20,6 @@ limitations under the License.
#include <functional> #include <functional>
#include <limits> #include <limits>
#include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
@ -41,9 +40,6 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
TfLiteTensor** new_tensor, TfLiteTensor** new_tensor,
int* new_tensor_index); int* new_tensor_index);
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
const std::vector<int>& data);
using IsNodeSupportedFn = using IsNodeSupportedFn =
std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*, std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*,
std::string* unsupported_details)>; std::string* unsupported_details)>;
@ -76,10 +72,21 @@ class GraphPartitionHelper {
// Note that partitions are ranked according to the number of nodes that // Note that partitions are ranked according to the number of nodes that
// a partition has, and the returned TfLiteDelegateParams objects are *owned* // a partition has, and the returned TfLiteDelegateParams objects are *owned*
// by the TfLite runtime. // by the TfLite runtime.
// TODO(b/156707497): remove this and use GetNodesOfFirstNLargestPartitions
std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions( std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions(
int n = std::numeric_limits<int>::max(), int n = std::numeric_limits<int>::max(),
int min_nodes_per_partition = 0) const; int min_nodes_per_partition = 0) const;
// Returns a list of node indices of all nodes from the first n largest
// partitions. If there are fewer partitions than n, all nodes will be
// returned. The partition is ranked according to the number of nodes.
std::vector<int> GetNodesOfFirstNLargestPartitions(
int n = std::numeric_limits<int>::max(),
int min_nodes_per_partition = 0) {
// Separated implementation that can be overridden, to preserve default values
return GetNodesOfFirstNLargestPartitionsImpl(n, min_nodes_per_partition);
}
int num_total_nodes() const { return num_total_nodes_; } int num_total_nodes() const { return num_total_nodes_; }
int num_partitions() const { return partitions_.size(); } int num_partitions() const { return partitions_.size(); }
@ -90,6 +97,8 @@ class GraphPartitionHelper {
return is_node_supported_fn_(context, node, registration, return is_node_supported_fn_(context, node, registration,
unsupported_details); unsupported_details);
} }
virtual std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
int n, int min_nodes_per_partition);
TfLiteContext* const context_ = nullptr; TfLiteContext* const context_ = nullptr;
@ -121,9 +130,6 @@ class GraphPartitionHelper {
// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in // While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in
// addition to supported nodes for the delegate, when the DEQUANTIZE node's // addition to supported nodes for the delegate, when the DEQUANTIZE node's
// output is an input to the kernel that supports FP16 input. // output is an input to the kernel that supports FP16 input.
// Noth that you have to use `GetNodesOfFirstNLargestPartitions` instead of
// superclass' `GetFirstNLargestPartitions` to do actual remapping of FP16
// inputs.
class FP16GraphPartitionHelper : public GraphPartitionHelper { class FP16GraphPartitionHelper : public GraphPartitionHelper {
public: public:
FP16GraphPartitionHelper(TfLiteContext* context, FP16GraphPartitionHelper(TfLiteContext* context,
@ -133,20 +139,15 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper {
TfLiteStatus Partition( TfLiteStatus Partition(
std::set<std::string>* unsupported_nodes_info) override; std::set<std::string>* unsupported_nodes_info) override;
// Returns a list of node indices of all nodes from the first n largest
// partitions. If there are fewer paritions than n, all nodes will be
// returned. The partition is ranked according to the number of nodes.
// TODO(b/156707497): Add this to superclass besides
// GetFirstNLargestPartitions (one that returns partitions instead of nodes)
std::vector<int> GetNodesOfFirstNLargestPartitions(
int n, int min_nodes_per_partition = 0,
std::vector<TfLiteDelegateParams*>* partitions = nullptr);
protected: protected:
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
TfLiteRegistration* registration, int node_id, TfLiteRegistration* registration, int node_id,
std::string* unsupported_details) override; std::string* unsupported_details) override;
// This will remap input tensors by removing FP16 to FP32 dequantized tensors.
std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
int n, int min_nodes_per_partition) override;
private: private:
// Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true. // Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true.
// When it's not a dequant op, remap its inputs to the inputs of the preceding // When it's not a dequant op, remap its inputs to the inputs of the preceding

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h" #include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/compatibility.h"
@ -86,31 +87,19 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context,
delegates::GraphPartitionHelper helper(context, node_supported_fn); delegates::GraphPartitionHelper helper(context, node_supported_fn);
TF_LITE_ENSURE_STATUS(helper.Partition(nullptr)); TF_LITE_ENSURE_STATUS(helper.Partition(nullptr));
const auto delegate_partitions = helper.GetFirstNLargestPartitions(); std::vector<int> supported_nodes = helper.GetNodesOfFirstNLargestPartitions();
// To avoid creating a new TfLiteIntArray and free it later, we reserve one
// element to represent TfLiteIntArray.size which is the 1st element of
// TfLiteIntArray C struct.
std::vector<int> supported_nodes(1);
for (const auto partition : delegate_partitions) {
auto* nodes = partition->nodes_to_replace;
supported_nodes.insert(supported_nodes.end(), nodes->data,
nodes->data + nodes->size);
}
// Set first element to the number of nodes to replace.
supported_nodes[0] = supported_nodes.size() - 1;
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO, TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
"%s delegate: %d nodes delegated out of %d nodes with " "%s delegate: %d nodes delegated out of %d nodes with "
"%d partitions.\n", "%d partitions.\n",
delegate->name(), supported_nodes[0], delegate->name(), supported_nodes.size(),
helper.num_total_nodes(), delegate_partitions.size()); helper.num_total_nodes(), helper.num_partitions());
TfLiteRegistration delegate_kernel_registration = TfLiteRegistration delegate_kernel_registration =
GetDelegateKernelRegistration(delegate); GetDelegateKernelRegistration(delegate);
return context->ReplaceNodeSubsetsWithDelegateKernels( return context->ReplaceNodeSubsetsWithDelegateKernels(
context, delegate_kernel_registration, context, delegate_kernel_registration,
reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()), base_delegate); BuildTfLiteIntArray(supported_nodes).get(), base_delegate);
} }
} // namespace } // namespace

View File

@ -233,15 +233,15 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
delegates::FP16GraphPartitionHelper partition_helper(context, node_supported_fn); delegates::FP16GraphPartitionHelper partition_helper(context, node_supported_fn);
TF_LITE_ENSURE_STATUS(partition_helper.Partition(nullptr)); TF_LITE_ENSURE_STATUS(partition_helper.Partition(nullptr));
std::vector<TfLiteDelegateParams*> partitions;
std::vector<int> delegated_nodes = partition_helper.GetNodesOfFirstNLargestPartitions( std::vector<int> delegated_nodes = partition_helper.GetNodesOfFirstNLargestPartitions(
params->max_delegated_partitions, params->min_nodes_per_partition, &partitions); params->max_delegated_partitions, params->min_nodes_per_partition);
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO, TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
"CoreML delegate: %d nodes delegated out of %d nodes, " "CoreML delegate: %d nodes delegated out of %d nodes, "
"with %d partitions.\n", "with %d partitions.\n",
delegated_nodes.size(), partition_helper.num_total_nodes(), partitions.size()); delegated_nodes.size(), partition_helper.num_total_nodes(),
partition_helper.num_partitions());
return context->ReplaceNodeSubsetsWithDelegateKernels( return context->ReplaceNodeSubsetsWithDelegateKernels(
context, GetCoreMlKernelRegistration(), delegates::BuildTfLiteIntArray(delegated_nodes).get(), context, GetCoreMlKernelRegistration(), BuildTfLiteIntArray(delegated_nodes).get(),
delegate); delegate);
} }

View File

@ -157,23 +157,12 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TfLiteHexagonDelegateOptions* params = TfLiteHexagonDelegateOptions* params =
static_cast<TfLiteHexagonDelegateOptions*>(delegate->data_); static_cast<TfLiteHexagonDelegateOptions*>(delegate->data_);
const auto delegate_partitions = helper.GetFirstNLargestPartitions( std::vector<int> supported_nodes = helper.GetNodesOfFirstNLargestPartitions(
params->max_delegated_partitions, params->min_nodes_per_partition); params->max_delegated_partitions, params->min_nodes_per_partition);
// To avoid creating a new TfLiteIntArray and free it later, we reserve one
// element to represent TfLiteIntArray.size which is the 1st element of
// TfLiteIntArray C struct.
std::vector<int> supported_nodes(1);
for (const auto partition : delegate_partitions) {
auto* nodes = partition->nodes_to_replace;
supported_nodes.insert(supported_nodes.end(), nodes->data,
nodes->data + nodes->size);
}
// Set first element to the number of nodes to replace.
supported_nodes[0] = supported_nodes.size() - 1;
auto* hexagon_delegate = static_cast<HexagonDelegate*>(delegate); auto* hexagon_delegate = static_cast<HexagonDelegate*>(delegate);
// Make sure dynamic batch is requested on fully delegated graph only. // Make sure dynamic batch is requested on fully delegated graph only.
if (supported_nodes[0] != helper.num_total_nodes() && if (supported_nodes.size() != helper.num_total_nodes() &&
hexagon_delegate != nullptr && hexagon_delegate != nullptr &&
hexagon_delegate->params()->enable_dynamic_batch_size) { hexagon_delegate->params()->enable_dynamic_batch_size) {
TF_LITE_KERNEL_LOG( TF_LITE_KERNEL_LOG(
@ -183,12 +172,12 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO, TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
"Hexagon delegate: %d nodes delegated out of %d nodes with " "Hexagon delegate: %d nodes delegated out of %d nodes with "
"%d partitions.\n", "%d partitions.\n",
supported_nodes[0], helper.num_total_nodes(), supported_nodes.size(), helper.num_total_nodes(),
delegate_partitions.size()); helper.num_partitions());
return context->ReplaceNodeSubsetsWithDelegateKernels( return context->ReplaceNodeSubsetsWithDelegateKernels(
context, GetHexagonKernelRegistration(), context, GetHexagonKernelRegistration(),
reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()), delegate); BuildTfLiteIntArray(supported_nodes).get(), delegate);
} }
TfLiteDelegate* CreateDelegate(const TfLiteHexagonDelegateOptions* params) { TfLiteDelegate* CreateDelegate(const TfLiteHexagonDelegateOptions* params) {

View File

@ -38,6 +38,14 @@ bool IsFlexOp(const char* custom_name) {
strlen(kFlexCustomCodePrefix)) == 0; strlen(kFlexCustomCodePrefix)) == 0;
} }
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
const std::vector<int>& data) {
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
TfLiteIntArrayCreate(data.size()));
std::copy(data.begin(), data.end(), result->data);
return result;
}
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) { TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()), return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()),
input.data()); input.data());

View File

@ -21,6 +21,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_UTIL_H_ #ifndef TENSORFLOW_LITE_UTIL_H_
#define TENSORFLOW_LITE_UTIL_H_ #define TENSORFLOW_LITE_UTIL_H_
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
@ -60,6 +61,11 @@ struct TfLiteIntArrayDeleter {
} }
}; };
// Helper for Building TfLiteIntArray that is wrapped in a unique_ptr,
// So that it is automatically freed when it goes out of the scope.
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
const std::vector<int>& data);
// Populates the size in bytes of a type into `bytes`. Returns kTfLiteOk for // Populates the size in bytes of a type into `bytes`. Returns kTfLiteOk for
// valid types, and kTfLiteError otherwise. // valid types, and kTfLiteError otherwise.
TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,