Adds util to select top node subsets (by length) for delegation.
PiperOrigin-RevId: 276285931 Change-Id: I1e3b6691cc73237f4b5f0f6b177b2a2b5925c5cc
This commit is contained in:
parent
02acae41bd
commit
8929466f21
43
tensorflow/lite/delegates/BUILD
Normal file
43
tensorflow/lite/delegates/BUILD
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "utils",
|
||||
srcs = ["utils.cc"],
|
||||
hdrs = ["utils.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "utils_test",
|
||||
srcs = ["utils_test.cc"],
|
||||
linkopts = tflite_linkopts(),
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
70
tensorflow/lite/delegates/utils.cc
Normal file
70
tensorflow/lite/delegates/utils.cc
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
|
||||
TfLiteStatus PruneContinuousSubsets(TfLiteContext* context,
|
||||
const int max_subsets,
|
||||
std::vector<int>* indices) {
|
||||
if (!indices) {
|
||||
context->ReportError(context, "indices cannot be nullptr");
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (indices->empty() || indices->size() < max_subsets) return kTfLiteOk;
|
||||
|
||||
// Sort indices just in case.
|
||||
std::sort(indices->begin(), indices->end());
|
||||
|
||||
// Build a vector of subsets.
|
||||
std::vector<std::vector<int>> continuous_subsets;
|
||||
int last_index = indices->at(0) - 2;
|
||||
for (const auto idx : *indices) {
|
||||
if (idx > last_index + 1) {
|
||||
continuous_subsets.emplace_back();
|
||||
}
|
||||
continuous_subsets.back().push_back(idx);
|
||||
last_index = idx;
|
||||
}
|
||||
|
||||
// Nothing to be done if number of subsets is already less than max_subsets.
|
||||
if (continuous_subsets.size() <= max_subsets) return kTfLiteOk;
|
||||
|
||||
// Sort the vector of subsets in descending order of length.
|
||||
std::sort(continuous_subsets.begin(), continuous_subsets.end(),
|
||||
[](const std::vector<int>& a, const std::vector<int>& b) {
|
||||
return a.size() > b.size();
|
||||
});
|
||||
|
||||
// Re-build indices vector from top subsets.
|
||||
indices->clear();
|
||||
for (int i = 0; i < max_subsets; ++i) {
|
||||
indices->reserve(indices->size() + continuous_subsets[i].size());
|
||||
indices->insert(indices->end(), continuous_subsets[i].begin(),
|
||||
continuous_subsets[i].end());
|
||||
}
|
||||
std::sort(indices->begin(), indices->end());
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
39
tensorflow/lite/delegates/utils.h
Normal file
39
tensorflow/lite/delegates/utils.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
|
||||
// Given a list(vector<int>) of indices, modifies it in-place to contain
|
||||
// max_subsets number of continuous subsets. Subsets are selected in descending
|
||||
// order of their length.
|
||||
// Resulting vector contains sorted list of pruned indices.
|
||||
//
|
||||
// This util can be used by delegates to avoid accepting too many node-subsets.
|
||||
TfLiteStatus PruneContinuousSubsets(TfLiteContext* context,
|
||||
const int max_subsets,
|
||||
std::vector<int>* indices);
|
||||
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_UTILS_H_
|
106
tensorflow/lite/delegates/utils_test.cc
Normal file
106
tensorflow/lite/delegates/utils_test.cc
Normal file
@ -0,0 +1,106 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/delegates/utils.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
void ReportError(TfLiteContext* context, const char* format, ...) {}
|
||||
|
||||
TEST(UtilsTest, PruneContinuousSubsets_NoSubsets) {
|
||||
TfLiteContext context;
|
||||
context.ReportError = ReportError;
|
||||
std::vector<int> original_indices = {};
|
||||
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 5, nullptr), kTfLiteError);
|
||||
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 0, &original_indices), kTfLiteOk);
|
||||
ASSERT_TRUE(original_indices.empty());
|
||||
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 2, &original_indices), kTfLiteOk);
|
||||
ASSERT_TRUE(original_indices.empty());
|
||||
}
|
||||
|
||||
TEST(UtilsTest, PruneContinuousSubsets_SingleSubset) {
|
||||
TfLiteContext context;
|
||||
std::vector<int> original_indices = {0, 1, 2, 3};
|
||||
|
||||
std::vector<int> indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 1, &indices), kTfLiteOk);
|
||||
EXPECT_THAT(indices, ElementsAreArray({0, 1, 2, 3}));
|
||||
|
||||
indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 0, &indices), kTfLiteOk);
|
||||
ASSERT_TRUE(indices.empty());
|
||||
|
||||
indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 2, &indices), kTfLiteOk);
|
||||
EXPECT_THAT(indices, ElementsAreArray({0, 1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST(UtilsTest, PruneContinuousSubsets_MultipleSubsets) {
|
||||
TfLiteContext context;
|
||||
// 5 subsets: (0, 1), (3, 4, 5), (7), (10, 11), (19).
|
||||
std::vector<int> original_indices = {0, 1, 3, 4, 5, 7, 10, 11, 19};
|
||||
|
||||
std::vector<int> indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 4, &indices), kTfLiteOk);
|
||||
EXPECT_THAT(indices, ElementsAreArray({0, 1, 3, 4, 5, 7, 10, 11}));
|
||||
|
||||
// Only the longest subset is selected.
|
||||
indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 1, &indices), kTfLiteOk);
|
||||
EXPECT_THAT(indices, ElementsAreArray({3, 4, 5}));
|
||||
|
||||
indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 0, &indices), kTfLiteOk);
|
||||
ASSERT_TRUE(indices.empty());
|
||||
|
||||
indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 1000, &indices), kTfLiteOk);
|
||||
EXPECT_THAT(indices, ElementsAreArray({0, 1, 3, 4, 5, 7, 10, 11, 19}));
|
||||
}
|
||||
|
||||
TEST(UtilsTest, PruneContinuousSubsets_UnsortedIndices) {
|
||||
TfLiteContext context;
|
||||
// 5 subsets: (0, 1), (3, 4, 5), (7), (10, 11), (19).
|
||||
std::vector<int> original_indices = {5, 7, 4, 10, 11, 19, 0, 1, 3};
|
||||
|
||||
std::vector<int> indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 4, &indices), kTfLiteOk);
|
||||
EXPECT_THAT(indices, ElementsAreArray({0, 1, 3, 4, 5, 7, 10, 11}));
|
||||
|
||||
// Only the longest subset is selected.
|
||||
indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 1, &indices), kTfLiteOk);
|
||||
EXPECT_THAT(indices, ElementsAreArray({3, 4, 5}));
|
||||
|
||||
indices = original_indices;
|
||||
ASSERT_EQ(PruneContinuousSubsets(&context, 0, &indices), kTfLiteOk);
|
||||
ASSERT_TRUE(indices.empty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
Loading…
x
Reference in New Issue
Block a user