From 9278bbfc249233190c3901e9e6922d4077372899 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Sun, 22 Mar 2020 22:55:34 -0700 Subject: [PATCH] Support to pass in a list of supported node indices directly to GraphPartitionHelper. PiperOrigin-RevId: 302365857 Change-Id: I3debf09a5cc033b07e56b08e8345d548fb556cb7 --- tensorflow/lite/delegates/utils.cc | 4 ++-- tensorflow/lite/delegates/utils.h | 12 ++++++++--- tensorflow/lite/delegates/utils_test.cc | 27 +++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/delegates/utils.cc b/tensorflow/lite/delegates/utils.cc index 75839d53560..fba8bec39a5 100644 --- a/tensorflow/lite/delegates/utils.cc +++ b/tensorflow/lite/delegates/utils.cc @@ -18,9 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" -#include "tensorflow/lite/util.h" namespace tflite { namespace delegates { @@ -98,6 +96,8 @@ GraphPartitionHelper::GetFirstNLargestPartitions( TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes( std::set* unsupported_nodes_info) { + if (!is_node_supported_fn_) return kTfLiteOk; + TfLiteIntArray* execution_plan = nullptr; auto status = context_->GetExecutionPlan(context_, &execution_plan); if (status != kTfLiteOk) { diff --git a/tensorflow/lite/delegates/utils.h b/tensorflow/lite/delegates/utils.h index f894cae30fd..d6d22c4efa2 100644 --- a/tensorflow/lite/delegates/utils.h +++ b/tensorflow/lite/delegates/utils.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace delegates { @@ -43,12 +44,17 @@ using IsNodeSupportedFn = // Note the class *needs* to be used in TfLiteDelegate::Prepare. class GraphPartitionHelper { public: - // TODO(b/151152967): Support use-cases where a list of supported nodes are - // directly passed-in. GraphPartitionHelper(TfLiteContext* context, IsNodeSupportedFn is_node_supported_fn) : context_(context), is_node_supported_fn_(is_node_supported_fn) {} + GraphPartitionHelper(TfLiteContext* context, + const std::vector& supported_node_indices) + : context_(context), + num_total_nodes_(supported_node_indices.size()), + supported_nodes_( + ConvertVectorToTfLiteIntArray(supported_node_indices)) {} + virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); } // Partition the graph into node subsets such that each subset could be @@ -98,7 +104,7 @@ class GraphPartitionHelper { int num_total_nodes_ = 0; // Tells if a node is supported as it could be delegated. - const IsNodeSupportedFn is_node_supported_fn_; + const IsNodeSupportedFn is_node_supported_fn_ = nullptr; // Contains an array of supported node indices. TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory diff --git a/tensorflow/lite/delegates/utils_test.cc b/tensorflow/lite/delegates/utils_test.cc index a67778fee1f..5d308a0b546 100644 --- a/tensorflow/lite/delegates/utils_test.cc +++ b/tensorflow/lite/delegates/utils_test.cc @@ -223,6 +223,33 @@ TEST(GraphPartitionHelper, CheckPrepareErrors) { EXPECT_EQ(kTfLiteError, helper.Partition(nullptr)); } +TEST(GraphPartitionHelper, CheckPartitionsWithSupportedNodeList) { + // The mocked TfLiteContext has 4 partitions: {1}, {0,3,7,8}, {2,4,9}, {5,6}. + // So, we simply create a list of supported nodes as {0,1,2,...,8,9} + MockTfLiteContext mocked_context; + std::vector supported_nodes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + GraphPartitionHelper helper(&mocked_context, supported_nodes); + EXPECT_EQ(kTfLiteOk, helper.Partition(nullptr)); + EXPECT_EQ(10, helper.num_total_nodes()); + EXPECT_EQ(4, helper.num_partitions()); + + auto partitions = helper.GetFirstNLargestPartitions(1, 0); + EXPECT_EQ(1, partitions.size()); + auto nodes = GetNodesToReplaceFromPartitions(partitions); + EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8})); + + // Get the largest partition but requiring at least 5 nodes, so empty result. + partitions = helper.GetFirstNLargestPartitions(1, 5); + EXPECT_TRUE(partitions.empty()); + + partitions = helper.GetFirstNLargestPartitions(10, 3); + EXPECT_EQ(2, partitions.size()); + EXPECT_EQ(4, partitions[0]->nodes_to_replace->size); + EXPECT_EQ(3, partitions[1]->nodes_to_replace->size); + nodes = GetNodesToReplaceFromPartitions(partitions); + EXPECT_THAT(nodes, testing::ElementsAreArray({0, 3, 7, 8, 2, 4, 9})); +} + } // namespace } // namespace delegates } // namespace tflite