Adds util to select top node subsets (by length) for delegation.
PiperOrigin-RevId: 276285931 Change-Id: I1e3b6691cc73237f4b5f0f6b177b2a2b5925c5cc
This commit is contained in:
parent
02acae41bd
commit
8929466f21
43
tensorflow/lite/delegates/BUILD
Normal file
43
tensorflow/lite/delegates/BUILD
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//visibility:public"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "utils",
|
||||||
|
srcs = ["utils.cc"],
|
||||||
|
hdrs = ["utils.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "utils_test",
|
||||||
|
srcs = ["utils_test.cc"],
|
||||||
|
linkopts = tflite_linkopts(),
|
||||||
|
linkstatic = 1,
|
||||||
|
deps = [
|
||||||
|
":utils",
|
||||||
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
70
tensorflow/lite/delegates/utils.cc
Normal file
70
tensorflow/lite/delegates/utils.cc
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/utils.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace delegates {
|
||||||
|
|
||||||
|
TfLiteStatus PruneContinuousSubsets(TfLiteContext* context,
|
||||||
|
const int max_subsets,
|
||||||
|
std::vector<int>* indices) {
|
||||||
|
if (!indices) {
|
||||||
|
context->ReportError(context, "indices cannot be nullptr");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (indices->empty() || indices->size() < max_subsets) return kTfLiteOk;
|
||||||
|
|
||||||
|
// Sort indices just in case.
|
||||||
|
std::sort(indices->begin(), indices->end());
|
||||||
|
|
||||||
|
// Build a vector of subsets.
|
||||||
|
std::vector<std::vector<int>> continuous_subsets;
|
||||||
|
int last_index = indices->at(0) - 2;
|
||||||
|
for (const auto idx : *indices) {
|
||||||
|
if (idx > last_index + 1) {
|
||||||
|
continuous_subsets.emplace_back();
|
||||||
|
}
|
||||||
|
continuous_subsets.back().push_back(idx);
|
||||||
|
last_index = idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nothing to be done if number of subsets is already less than max_subsets.
|
||||||
|
if (continuous_subsets.size() <= max_subsets) return kTfLiteOk;
|
||||||
|
|
||||||
|
// Sort the vector of subsets in descending order of length.
|
||||||
|
std::sort(continuous_subsets.begin(), continuous_subsets.end(),
|
||||||
|
[](const std::vector<int>& a, const std::vector<int>& b) {
|
||||||
|
return a.size() > b.size();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Re-build indices vector from top subsets.
|
||||||
|
indices->clear();
|
||||||
|
for (int i = 0; i < max_subsets; ++i) {
|
||||||
|
indices->reserve(indices->size() + continuous_subsets[i].size());
|
||||||
|
indices->insert(indices->end(), continuous_subsets[i].begin(),
|
||||||
|
continuous_subsets[i].end());
|
||||||
|
}
|
||||||
|
std::sort(indices->begin(), indices->end());
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace delegates
|
||||||
|
} // namespace tflite
|
39
tensorflow/lite/delegates/utils.h
Normal file
39
tensorflow/lite/delegates/utils.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace delegates {
|
||||||
|
|
||||||
|
// Given a list(vector<int>) of indices, modifies it in-place to contain
|
||||||
|
// max_subsets number of continuous subsets. Subsets are selected in descending
|
||||||
|
// order of their length.
|
||||||
|
// Resulting vector contains sorted list of pruned indices.
|
||||||
|
//
|
||||||
|
// This util can be used by delegates to avoid accepting too many node-subsets.
|
||||||
|
TfLiteStatus PruneContinuousSubsets(TfLiteContext* context,
|
||||||
|
const int max_subsets,
|
||||||
|
std::vector<int>* indices);
|
||||||
|
|
||||||
|
} // namespace delegates
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
106
tensorflow/lite/delegates/utils_test.cc
Normal file
106
tensorflow/lite/delegates/utils_test.cc
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/delegates/utils.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace delegates {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
|
||||||
|
void ReportError(TfLiteContext* context, const char* format, ...) {}
|
||||||
|
|
||||||
|
TEST(UtilsTest, PruneContinuousSubsets_NoSubsets) {
|
||||||
|
TfLiteContext context;
|
||||||
|
context.ReportError = ReportError;
|
||||||
|
std::vector<int> original_indices = {};
|
||||||
|
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 5, nullptr), kTfLiteError);
|
||||||
|
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 0, &original_indices), kTfLiteOk);
|
||||||
|
ASSERT_TRUE(original_indices.empty());
|
||||||
|
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 2, &original_indices), kTfLiteOk);
|
||||||
|
ASSERT_TRUE(original_indices.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UtilsTest, PruneContinuousSubsets_SingleSubset) {
|
||||||
|
TfLiteContext context;
|
||||||
|
std::vector<int> original_indices = {0, 1, 2, 3};
|
||||||
|
|
||||||
|
std::vector<int> indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 1, &indices), kTfLiteOk);
|
||||||
|
EXPECT_THAT(indices, ElementsAreArray({0, 1, 2, 3}));
|
||||||
|
|
||||||
|
indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 0, &indices), kTfLiteOk);
|
||||||
|
ASSERT_TRUE(indices.empty());
|
||||||
|
|
||||||
|
indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 2, &indices), kTfLiteOk);
|
||||||
|
EXPECT_THAT(indices, ElementsAreArray({0, 1, 2, 3}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UtilsTest, PruneContinuousSubsets_MultipleSubsets) {
|
||||||
|
TfLiteContext context;
|
||||||
|
// 5 subsets: (0, 1), (3, 4, 5), (7), (10, 11), (19).
|
||||||
|
std::vector<int> original_indices = {0, 1, 3, 4, 5, 7, 10, 11, 19};
|
||||||
|
|
||||||
|
std::vector<int> indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 4, &indices), kTfLiteOk);
|
||||||
|
EXPECT_THAT(indices, ElementsAreArray({0, 1, 3, 4, 5, 7, 10, 11}));
|
||||||
|
|
||||||
|
// Only the longest subset is selected.
|
||||||
|
indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 1, &indices), kTfLiteOk);
|
||||||
|
EXPECT_THAT(indices, ElementsAreArray({3, 4, 5}));
|
||||||
|
|
||||||
|
indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 0, &indices), kTfLiteOk);
|
||||||
|
ASSERT_TRUE(indices.empty());
|
||||||
|
|
||||||
|
indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 1000, &indices), kTfLiteOk);
|
||||||
|
EXPECT_THAT(indices, ElementsAreArray({0, 1, 3, 4, 5, 7, 10, 11, 19}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UtilsTest, PruneContinuousSubsets_UnsortedIndices) {
|
||||||
|
TfLiteContext context;
|
||||||
|
// 5 subsets: (0, 1), (3, 4, 5), (7), (10, 11), (19).
|
||||||
|
std::vector<int> original_indices = {5, 7, 4, 10, 11, 19, 0, 1, 3};
|
||||||
|
|
||||||
|
std::vector<int> indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 4, &indices), kTfLiteOk);
|
||||||
|
EXPECT_THAT(indices, ElementsAreArray({0, 1, 3, 4, 5, 7, 10, 11}));
|
||||||
|
|
||||||
|
// Only the longest subset is selected.
|
||||||
|
indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 1, &indices), kTfLiteOk);
|
||||||
|
EXPECT_THAT(indices, ElementsAreArray({3, 4, 5}));
|
||||||
|
|
||||||
|
indices = original_indices;
|
||||||
|
ASSERT_EQ(PruneContinuousSubsets(&context, 0, &indices), kTfLiteOk);
|
||||||
|
ASSERT_TRUE(indices.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace delegates
|
||||||
|
} // namespace tflite
|
Loading…
x
Reference in New Issue
Block a user