Support to pass in a list of supported node indices directly to GraphPartitionHelper.
PiperOrigin-RevId: 302365857 Change-Id: I3debf09a5cc033b07e56b08e8345d548fb556cb7
This commit is contained in:
parent
7116a21f17
commit
9278bbfc24
@ -18,9 +18,7 @@ limitations under the License.
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/common.h"
|
|
||||||
#include "tensorflow/lite/context_util.h"
|
#include "tensorflow/lite/context_util.h"
|
||||||
#include "tensorflow/lite/util.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace delegates {
|
namespace delegates {
|
||||||
@ -98,6 +96,8 @@ GraphPartitionHelper::GetFirstNLargestPartitions(
|
|||||||
|
|
||||||
TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
|
TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
|
||||||
std::set<std::string>* unsupported_nodes_info) {
|
std::set<std::string>* unsupported_nodes_info) {
|
||||||
|
if (!is_node_supported_fn_) return kTfLiteOk;
|
||||||
|
|
||||||
TfLiteIntArray* execution_plan = nullptr;
|
TfLiteIntArray* execution_plan = nullptr;
|
||||||
auto status = context_->GetExecutionPlan(context_, &execution_plan);
|
auto status = context_->GetExecutionPlan(context_, &execution_plan);
|
||||||
if (status != kTfLiteOk) {
|
if (status != kTfLiteOk) {
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace delegates {
|
namespace delegates {
|
||||||
@ -43,12 +44,17 @@ using IsNodeSupportedFn =
|
|||||||
// Note the class *needs* to be used in TfLiteDelegate::Prepare.
|
// Note the class *needs* to be used in TfLiteDelegate::Prepare.
|
||||||
class GraphPartitionHelper {
|
class GraphPartitionHelper {
|
||||||
public:
|
public:
|
||||||
// TODO(b/151152967): Support use-cases where a list of supported nodes are
|
|
||||||
// directly passed-in.
|
|
||||||
GraphPartitionHelper(TfLiteContext* context,
|
GraphPartitionHelper(TfLiteContext* context,
|
||||||
IsNodeSupportedFn is_node_supported_fn)
|
IsNodeSupportedFn is_node_supported_fn)
|
||||||
: context_(context), is_node_supported_fn_(is_node_supported_fn) {}
|
: context_(context), is_node_supported_fn_(is_node_supported_fn) {}
|
||||||
|
|
||||||
|
GraphPartitionHelper(TfLiteContext* context,
|
||||||
|
const std::vector<int>& supported_node_indices)
|
||||||
|
: context_(context),
|
||||||
|
num_total_nodes_(supported_node_indices.size()),
|
||||||
|
supported_nodes_(
|
||||||
|
ConvertVectorToTfLiteIntArray(supported_node_indices)) {}
|
||||||
|
|
||||||
virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); }
|
virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); }
|
||||||
|
|
||||||
// Partition the graph into node subsets such that each subset could be
|
// Partition the graph into node subsets such that each subset could be
|
||||||
@ -98,7 +104,7 @@ class GraphPartitionHelper {
|
|||||||
int num_total_nodes_ = 0;
|
int num_total_nodes_ = 0;
|
||||||
|
|
||||||
// Tells if a node is supported as it could be delegated.
|
// Tells if a node is supported as it could be delegated.
|
||||||
const IsNodeSupportedFn is_node_supported_fn_;
|
const IsNodeSupportedFn is_node_supported_fn_ = nullptr;
|
||||||
|
|
||||||
// Contains an array of supported node indices.
|
// Contains an array of supported node indices.
|
||||||
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
|
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
|
||||||
|
@ -223,6 +223,33 @@ TEST(GraphPartitionHelper, CheckPrepareErrors) {
|
|||||||
EXPECT_EQ(kTfLiteError, helper.Partition(nullptr));
|
EXPECT_EQ(kTfLiteError, helper.Partition(nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(GraphPartitionHelper, CheckPartitionsWithSupportedNodeList) {
|
||||||
|
// The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}.
|
||||||
|
// So, we simply create a list of supported nodes as {0,1,2,...,8,9}
|
||||||
|
MockTfLiteContext mocked_context;
|
||||||
|
std::vector<int> supported_nodes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||||
|
GraphPartitionHelper helper(&mocked_context, supported_nodes);
|
||||||
|
EXPECT_EQ(kTfLiteOk, helper.Partition(nullptr));
|
||||||
|
EXPECT_EQ(10, helper.num_total_nodes());
|
||||||
|
EXPECT_EQ(4, helper.num_partitions());
|
||||||
|
|
||||||
|
auto partitions = helper.GetFirstNLargestPartitions(1, 0);
|
||||||
|
EXPECT_EQ(1, partitions.size());
|
||||||
|
auto nodes = GetNodesToReplaceFromPartitions(partitions);
|
||||||
|
EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8}));
|
||||||
|
|
||||||
|
// Get the largest partition but requiring at least 5 nodes, so empty result.
|
||||||
|
partitions = helper.GetFirstNLargestPartitions(1, 5);
|
||||||
|
EXPECT_TRUE(partitions.empty());
|
||||||
|
|
||||||
|
partitions = helper.GetFirstNLargestPartitions(10, 3);
|
||||||
|
EXPECT_EQ(2, partitions.size());
|
||||||
|
EXPECT_EQ(4, partitions[0]->nodes_to_replace->size);
|
||||||
|
EXPECT_EQ(3, partitions[1]->nodes_to_replace->size);
|
||||||
|
nodes = GetNodesToReplaceFromPartitions(partitions);
|
||||||
|
EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace delegates
|
} // namespace delegates
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
Loading…
Reference in New Issue
Block a user