Support to pass in a list of supported node indices directly to GraphPartitionHelper.
PiperOrigin-RevId: 302365857 Change-Id: I3debf09a5cc033b07e56b08e8345d548fb556cb7
This commit is contained in:
parent
7116a21f17
commit
9278bbfc24
@ -18,9 +18,7 @@ limitations under the License.
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
@ -98,6 +96,8 @@ GraphPartitionHelper::GetFirstNLargestPartitions(
|
||||
|
||||
TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
|
||||
std::set<std::string>* unsupported_nodes_info) {
|
||||
if (!is_node_supported_fn_) return kTfLiteOk;
|
||||
|
||||
TfLiteIntArray* execution_plan = nullptr;
|
||||
auto status = context_->GetExecutionPlan(context_, &execution_plan);
|
||||
if (status != kTfLiteOk) {
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
@ -43,12 +44,17 @@ using IsNodeSupportedFn =
|
||||
// Note the class *needs* to be used in TfLiteDelegate::Prepare.
|
||||
class GraphPartitionHelper {
|
||||
public:
|
||||
// TODO(b/151152967): Support use-cases where a list of supported nodes are
|
||||
// directly passed-in.
|
||||
GraphPartitionHelper(TfLiteContext* context,
|
||||
IsNodeSupportedFn is_node_supported_fn)
|
||||
: context_(context), is_node_supported_fn_(is_node_supported_fn) {}
|
||||
|
||||
GraphPartitionHelper(TfLiteContext* context,
|
||||
const std::vector<int>& supported_node_indices)
|
||||
: context_(context),
|
||||
num_total_nodes_(supported_node_indices.size()),
|
||||
supported_nodes_(
|
||||
ConvertVectorToTfLiteIntArray(supported_node_indices)) {}
|
||||
|
||||
virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); }
|
||||
|
||||
// Partition the graph into node subsets such that each subset could be
|
||||
@ -98,7 +104,7 @@ class GraphPartitionHelper {
|
||||
int num_total_nodes_ = 0;
|
||||
|
||||
// Tells if a node is supported as it could be delegated.
|
||||
const IsNodeSupportedFn is_node_supported_fn_;
|
||||
const IsNodeSupportedFn is_node_supported_fn_ = nullptr;
|
||||
|
||||
// Contains an array of supported node indices.
|
||||
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
|
||||
|
@ -223,6 +223,33 @@ TEST(GraphPartitionHelper, CheckPrepareErrors) {
|
||||
EXPECT_EQ(kTfLiteError, helper.Partition(nullptr));
|
||||
}
|
||||
|
||||
TEST(GraphPartitionHelper, CheckPartitionsWithSupportedNodeList) {
|
||||
// The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}.
|
||||
// So, we simply create a list of supported nodes as {0,1,2,...,8,9}
|
||||
MockTfLiteContext mocked_context;
|
||||
std::vector<int> supported_nodes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||
GraphPartitionHelper helper(&mocked_context, supported_nodes);
|
||||
EXPECT_EQ(kTfLiteOk, helper.Partition(nullptr));
|
||||
EXPECT_EQ(10, helper.num_total_nodes());
|
||||
EXPECT_EQ(4, helper.num_partitions());
|
||||
|
||||
auto partitions = helper.GetFirstNLargestPartitions(1, 0);
|
||||
EXPECT_EQ(1, partitions.size());
|
||||
auto nodes = GetNodesToReplaceFromPartitions(partitions);
|
||||
EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8}));
|
||||
|
||||
// Get the largest partition but requiring at least 5 nodes, so empty result.
|
||||
partitions = helper.GetFirstNLargestPartitions(1, 5);
|
||||
EXPECT_TRUE(partitions.empty());
|
||||
|
||||
partitions = helper.GetFirstNLargestPartitions(10, 3);
|
||||
EXPECT_EQ(2, partitions.size());
|
||||
EXPECT_EQ(4, partitions[0]->nodes_to_replace->size);
|
||||
EXPECT_EQ(3, partitions[1]->nodes_to_replace->size);
|
||||
nodes = GetNodesToReplaceFromPartitions(partitions);
|
||||
EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user