Support to pass in a list of supported node indices directly to GraphPartitionHelper.

PiperOrigin-RevId: 302365857
Change-Id: I3debf09a5cc033b07e56b08e8345d548fb556cb7
This commit is contained in:
Chao Mei 2020-03-22 22:55:34 -07:00 committed by TensorFlower Gardener
parent 7116a21f17
commit 9278bbfc24
3 changed files with 38 additions and 5 deletions

View File

@ -18,9 +18,7 @@ limitations under the License.
#include <algorithm>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace delegates {
@ -98,6 +96,8 @@ GraphPartitionHelper::GetFirstNLargestPartitions(
TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
std::set<std::string>* unsupported_nodes_info) {
if (!is_node_supported_fn_) return kTfLiteOk;
TfLiteIntArray* execution_plan = nullptr;
auto status = context_->GetExecutionPlan(context_, &execution_plan);
if (status != kTfLiteOk) {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace delegates {
@ -43,12 +44,17 @@ using IsNodeSupportedFn =
// Note the class *needs* to be used in TfLiteDelegate::Prepare.
class GraphPartitionHelper {
public:
// TODO(b/151152967): Support use-cases where a list of supported nodes are
// directly passed-in.
GraphPartitionHelper(TfLiteContext* context,
IsNodeSupportedFn is_node_supported_fn)
: context_(context), is_node_supported_fn_(is_node_supported_fn) {}
GraphPartitionHelper(TfLiteContext* context,
const std::vector<int>& supported_node_indices)
: context_(context),
num_total_nodes_(supported_node_indices.size()),
supported_nodes_(
ConvertVectorToTfLiteIntArray(supported_node_indices)) {}
virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); }
// Partition the graph into node subsets such that each subset could be
@ -98,7 +104,7 @@ class GraphPartitionHelper {
int num_total_nodes_ = 0;
// Tells if a node is supported as it could be delegated.
const IsNodeSupportedFn is_node_supported_fn_;
const IsNodeSupportedFn is_node_supported_fn_ = nullptr;
// Contains an array of supported node indices.
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory

View File

@ -223,6 +223,33 @@ TEST(GraphPartitionHelper, CheckPrepareErrors) {
EXPECT_EQ(kTfLiteError, helper.Partition(nullptr));
}
TEST(GraphPartitionHelper, CheckPartitionsWithSupportedNodeList) {
// The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}.
// So, we simply create a list of supported nodes as {0,1,2,...,8,9}
MockTfLiteContext mocked_context;
std::vector<int> supported_nodes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
GraphPartitionHelper helper(&mocked_context, supported_nodes);
EXPECT_EQ(kTfLiteOk, helper.Partition(nullptr));
EXPECT_EQ(10, helper.num_total_nodes());
EXPECT_EQ(4, helper.num_partitions());
auto partitions = helper.GetFirstNLargestPartitions(1, 0);
EXPECT_EQ(1, partitions.size());
auto nodes = GetNodesToReplaceFromPartitions(partitions);
EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8}));
// Get the largest partition but requiring at least 5 nodes, so empty result.
partitions = helper.GetFirstNLargestPartitions(1, 5);
EXPECT_TRUE(partitions.empty());
partitions = helper.GetFirstNLargestPartitions(10, 3);
EXPECT_EQ(2, partitions.size());
EXPECT_EQ(4, partitions[0]->nodes_to_replace->size);
EXPECT_EQ(3, partitions[1]->nodes_to_replace->size);
nodes = GetNodesToReplaceFromPartitions(partitions);
EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9}));
}
} // namespace
} // namespace delegates
} // namespace tflite