Adds support for assertions on acceleration to kernel tests
PiperOrigin-RevId: 266190376
This commit is contained in:
parent
062f659524
commit
0e33ce3a42
@ -38,6 +38,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "acceleration_test_util",
|
||||
testonly = 1,
|
||||
srcs = ["acceleration_test_util.cc"],
|
||||
hdrs = ["acceleration_test_util.h"],
|
||||
deps = [
|
||||
":nnapi_delegate",
|
||||
"//tensorflow/lite/kernels:acceleration_test_util_internal",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "nnapi_delegate_test",
|
||||
size = "small",
|
||||
|
38
tensorflow/lite/delegates/nnapi/acceleration_test_util.cc
Normal file
38
tensorflow/lite/delegates/nnapi/acceleration_test_util.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h"
|
||||
|
||||
#include "tensorflow/lite/kernels/acceleration_test_util_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
absl::optional<NnapiAccelerationTestParams> GetNnapiAccelerationTestParam(
|
||||
std::string test_id) {
|
||||
return GetAccelerationTestParam<NnapiAccelerationTestParams>(test_id);
|
||||
}
|
||||
|
||||
// static
|
||||
NnapiAccelerationTestParams NnapiAccelerationTestParams::ParseConfigurationLine(
|
||||
const std::string& conf_line) {
|
||||
if (conf_line.empty()) {
|
||||
return NnapiAccelerationTestParams();
|
||||
}
|
||||
|
||||
int min_sdk_version = std::stoi(conf_line);
|
||||
|
||||
return NnapiAccelerationTestParams{min_sdk_version};
|
||||
}
|
||||
|
||||
} // namespace tflite
|
51
tensorflow/lite/delegates/nnapi/acceleration_test_util.h
Normal file
51
tensorflow/lite/delegates/nnapi/acceleration_test_util.h
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_NNAPI_ACCELERATION_TEST_UTIL_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_NNAPI_ACCELERATION_TEST_UTIL_H_
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// NNAPI specific configuration for the validation whitelist.
|
||||
class NnapiAccelerationTestParams {
|
||||
public:
|
||||
static constexpr const char* const kAccelerationTestConfig = "";
|
||||
|
||||
static NnapiAccelerationTestParams ParseConfigurationLine(
|
||||
const std::string& conf_line);
|
||||
|
||||
explicit NnapiAccelerationTestParams(int min_android_sdk_version)
|
||||
: min_android_sdk_version_{min_android_sdk_version} {};
|
||||
|
||||
NnapiAccelerationTestParams()
|
||||
: min_android_sdk_version_{delegate::nnapi::kMinSdkVersionForNNAPI} {};
|
||||
|
||||
// Minimum SDK version to apply the acceleration validation to.
|
||||
int MinAndroidSdkVersion() { return min_android_sdk_version_; }
|
||||
|
||||
private:
|
||||
int min_android_sdk_version_;
|
||||
};
|
||||
|
||||
// Returns the NNAPI acceleration test configuration for the given test id.
|
||||
absl::optional<NnapiAccelerationTestParams> GetNnapiAccelerationTestParam(
|
||||
std::string test_id);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_NNAPI_ACCELERATION_TEST_UTIL_H_
|
@ -119,20 +119,65 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "acceleration_test_util",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"acceleration_test_util.cc",
|
||||
],
|
||||
hdrs = ["acceleration_test_util.h"],
|
||||
deps = [
|
||||
":acceleration_test_util_internal",
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "acceleration_test_util_internal",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"acceleration_test_util_internal.cc",
|
||||
],
|
||||
hdrs = ["acceleration_test_util_internal.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "acceleration_test_util_internal_test",
|
||||
srcs = [
|
||||
"acceleration_test_util_internal_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":acceleration_test_util_internal",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_util",
|
||||
testonly = 1,
|
||||
srcs = ["test_util.cc"],
|
||||
hdrs = ["test_util.h"],
|
||||
deps = [
|
||||
":acceleration_test_util",
|
||||
":builtin_ops",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/delegates/nnapi:acceleration_test_util",
|
||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||
"//tensorflow/lite/kernels/internal:tensor_utils",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/nnapi:nnapi_implementation",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"//tensorflow/lite/tools/optimize:quantization_utils",
|
||||
"@com_google_googletest//:gtest",
|
||||
|
44
tensorflow/lite/kernels/acceleration_test_util.cc
Normal file
44
tensorflow/lite/kernels/acceleration_test_util.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/acceleration_test_util.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cctype>
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/lite/kernels/acceleration_test_util_internal.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
std::string GetCurrentTestId() {
|
||||
const ::testing::TestInfo* const test_info =
|
||||
::testing::UnitTest::GetInstance()->current_test_info();
|
||||
|
||||
std::stringstream test_id_stream;
|
||||
|
||||
test_id_stream << test_info->test_suite_name() << "/" << test_info->name();
|
||||
|
||||
return test_id_stream.str();
|
||||
}
|
||||
|
||||
} // namespace tflite
|
28
tensorflow/lite/kernels/acceleration_test_util.h
Normal file
28
tensorflow/lite/kernels/acceleration_test_util.h
Normal file
@ -0,0 +1,28 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_H_
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Returns the test id to use to retrieve the acceleration configuration
|
||||
// in the acceleration whitelist.
|
||||
std::string GetCurrentTestId();
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_H_
|
60
tensorflow/lite/kernels/acceleration_test_util_internal.cc
Normal file
60
tensorflow/lite/kernels/acceleration_test_util_internal.cc
Normal file
@ -0,0 +1,60 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/acceleration_test_util_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
void ReadAccelerationConfig(
|
||||
const char* config,
|
||||
const std::function<void(std::string, std::string, bool)>& consumer) {
|
||||
if (config) {
|
||||
std::istringstream istream{config};
|
||||
|
||||
std::string curr_config_line;
|
||||
while (std::getline(istream, curr_config_line)) {
|
||||
// trim whitespaces
|
||||
curr_config_line.erase(
|
||||
curr_config_line.begin(),
|
||||
std::find_if_not(curr_config_line.begin(), curr_config_line.end(),
|
||||
[](int ch) { return std::isspace(ch); }));
|
||||
// skipping comments and empty lines.
|
||||
if (curr_config_line.empty() || curr_config_line.at(0) == '#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
// split in test id regexp and rest of the config.
|
||||
auto first_sep_pos =
|
||||
std::find(curr_config_line.begin(), curr_config_line.end(), ',');
|
||||
|
||||
bool is_blacklist = false;
|
||||
std::string key = curr_config_line;
|
||||
std::string value{};
|
||||
if (first_sep_pos != curr_config_line.end()) {
|
||||
key = std::string(curr_config_line.begin(), first_sep_pos);
|
||||
value = std::string(first_sep_pos + 1, curr_config_line.end());
|
||||
}
|
||||
|
||||
// Regexps starting with '-'' are blacklist ones.
|
||||
if (key[0] == '-') {
|
||||
key = key.substr(1);
|
||||
is_blacklist = true;
|
||||
}
|
||||
|
||||
consumer(key, value, is_blacklist);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
104
tensorflow/lite/kernels/acceleration_test_util_internal.h
Normal file
104
tensorflow/lite/kernels/acceleration_test_util_internal.h
Normal file
@ -0,0 +1,104 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_INTERNAL_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_INTERNAL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "re2/re2.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Reads the acceleration configuration, handles comments and empty lines and
|
||||
// the basic data conversion format (split into key, value, recognition of
|
||||
// the line being a white or black list entry) and gives the data to the
|
||||
// consumer to be inserted into the target collection.
|
||||
void ReadAccelerationConfig(
|
||||
const char* config,
|
||||
const std::function<void(std::string, std::string, bool)>& consumer);
|
||||
|
||||
template <typename T>
|
||||
class ConfigurationEntry {
|
||||
public:
|
||||
ConfigurationEntry(const std::string& test_id_rex, T test_config,
|
||||
bool is_blacklist)
|
||||
: test_id_rex_(test_id_rex),
|
||||
test_config_(test_config),
|
||||
is_blacklist_(is_blacklist) {}
|
||||
|
||||
bool Matches(const std::string& test_id) {
|
||||
return RE2::FullMatch(test_id, test_id_rex_);
|
||||
}
|
||||
bool IsBlacklistEntry() const { return is_blacklist_; }
|
||||
const T& TestConfig() const { return test_config_; }
|
||||
|
||||
const std::string& TestIdRex() const { return test_id_rex_; }
|
||||
|
||||
private:
|
||||
std::string test_id_rex_;
|
||||
T test_config_;
|
||||
bool is_blacklist_;
|
||||
};
|
||||
|
||||
// Returns the acceleration test configuration for the given test id and
|
||||
// the given acceleration configuration type.
|
||||
// The configuration type is responsible of providing the test configuration
|
||||
// and the parse function to convert configuration lines into configuration
|
||||
// objects.
|
||||
template <typename T>
|
||||
absl::optional<T> GetAccelerationTestParam(std::string test_id) {
|
||||
static std::atomic<std::vector<ConfigurationEntry<T>>*> test_config_ptr;
|
||||
|
||||
if (test_config_ptr.load() == nullptr) {
|
||||
auto config = new std::vector<ConfigurationEntry<T>>();
|
||||
|
||||
auto consumer = [&config](std::string key, std::string value_str,
|
||||
bool is_blacklist) mutable {
|
||||
T value = T::ParseConfigurationLine(value_str);
|
||||
config->push_back(ConfigurationEntry<T>(key, value, is_blacklist));
|
||||
};
|
||||
|
||||
ReadAccelerationConfig(T::kAccelerationTestConfig, consumer);
|
||||
|
||||
// Even if it has been already set, it would be just replaced with the
|
||||
// same value, just freeing the old value to avoid leaks
|
||||
auto* prev_val = test_config_ptr.exchange(config);
|
||||
delete prev_val;
|
||||
}
|
||||
|
||||
const std::vector<ConfigurationEntry<T>>* test_config =
|
||||
test_config_ptr.load();
|
||||
|
||||
const auto test_config_iter = std::find_if(
|
||||
test_config->begin(), test_config->end(),
|
||||
[&test_id](ConfigurationEntry<T> elem) { return elem.Matches(test_id); });
|
||||
if (test_config_iter != test_config->end() &&
|
||||
!test_config_iter->IsBlacklistEntry()) {
|
||||
return absl::optional<T>(test_config_iter->TestConfig());
|
||||
} else {
|
||||
return absl::optional<T>();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_INTERNAL_H_
|
201
tensorflow/lite/kernels/acceleration_test_util_internal_test.cc
Normal file
201
tensorflow/lite/kernels/acceleration_test_util_internal_test.cc
Normal file
@ -0,0 +1,201 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/acceleration_test_util_internal.h"
|
||||
|
||||
#include <optional>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
|
||||
using ::testing::Eq;
|
||||
using ::testing::Not;
|
||||
using ::testing::Test;
|
||||
|
||||
struct SimpleConfig {
|
||||
public:
|
||||
static constexpr const char* kAccelerationTestConfig =
|
||||
R"(
|
||||
#test-id,some-other-data
|
||||
test-1,data-1
|
||||
test-2,
|
||||
test-3,data-3
|
||||
test-4.*,data-4
|
||||
-test-5
|
||||
test-6
|
||||
test-7,data-7
|
||||
)";
|
||||
|
||||
static SimpleConfig ParseConfigurationLine(const std::string& conf_line) {
|
||||
return {conf_line};
|
||||
}
|
||||
|
||||
std::string value;
|
||||
};
|
||||
|
||||
class ReadAccelerationConfigTest : public ::testing::Test {
|
||||
public:
|
||||
std::unordered_map<std::string, SimpleConfig> whitelist_;
|
||||
std::unordered_map<std::string, SimpleConfig> blacklist_;
|
||||
std::function<void(std::string, std::string, bool)> consumer_ =
|
||||
[this](std::string key, std::string value, bool is_blacklist) {
|
||||
if (is_blacklist) {
|
||||
blacklist_[key] = {value};
|
||||
} else {
|
||||
whitelist_[key] = {value};
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ReadsAKeyOnlyLine) {
|
||||
ReadAccelerationConfig("key", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_.find("key"), Not(Eq(whitelist_.end())));
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ReadsABlacklistKeyOnlyLine) {
|
||||
ReadAccelerationConfig("-key", consumer_);
|
||||
|
||||
EXPECT_THAT(blacklist_.find("key"), Not(Eq(whitelist_.end())));
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ReadsAKeyValueLine) {
|
||||
ReadAccelerationConfig("key,value", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key"].value, Eq("value"));
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ReadsABlackListKeyValueLine) {
|
||||
ReadAccelerationConfig("-key,value", consumer_);
|
||||
|
||||
EXPECT_THAT(blacklist_["key"].value, Eq("value"));
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, KeysAreLeftTrimmed) {
|
||||
ReadAccelerationConfig(" key,value", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key"].value, Eq("value"));
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, BlKeysAreLeftTrimmed) {
|
||||
ReadAccelerationConfig(" -key,value", consumer_);
|
||||
|
||||
EXPECT_THAT(blacklist_["key"].value, Eq("value"));
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, IgnoresCommentedLines) {
|
||||
ReadAccelerationConfig("#key,value", consumer_);
|
||||
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, CommentCanHaveTralingBlanks) {
|
||||
ReadAccelerationConfig(" #key,value", consumer_);
|
||||
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, CommentsAreOnlyForTheFullLine) {
|
||||
ReadAccelerationConfig("key,value #comment", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key"].value, Eq("value #comment"));
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, IgnoresEmptyLines) {
|
||||
ReadAccelerationConfig("", consumer_);
|
||||
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLines) {
|
||||
ReadAccelerationConfig("key1,value1\nkey2,value2\n-key3,value3", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key1"].value, Eq("value1"));
|
||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
||||
EXPECT_THAT(blacklist_["key3"].value, Eq("value3"));
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithCommentsAndSpaces) {
|
||||
ReadAccelerationConfig("key1,value1\n#comment\n\nkey2,value2", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key1"].value, Eq("value1"));
|
||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithMissingConfigValues) {
|
||||
ReadAccelerationConfig("key1\nkey2,value2\nkey3\nkey4,value4", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key1"].value, Eq(""));
|
||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
||||
EXPECT_THAT(whitelist_["key3"].value, Eq(""));
|
||||
EXPECT_THAT(whitelist_["key4"].value, Eq("value4"));
|
||||
}
|
||||
|
||||
TEST(GetAccelerationTestParam, LoadsTestConfig) {
|
||||
const auto config_value_maybe =
|
||||
GetAccelerationTestParam<SimpleConfig>("test-3");
|
||||
ASSERT_TRUE(config_value_maybe.has_value());
|
||||
ASSERT_THAT(config_value_maybe.value().value, Eq("data-3"));
|
||||
}
|
||||
|
||||
TEST(GetAccelerationTestParam, LoadsTestConfigWithEmptyValue) {
|
||||
const auto config_value_maybe =
|
||||
GetAccelerationTestParam<SimpleConfig>("test-2");
|
||||
ASSERT_TRUE(config_value_maybe.has_value());
|
||||
ASSERT_THAT(config_value_maybe.value().value, Eq(""));
|
||||
}
|
||||
|
||||
TEST(GetAccelerationTestParam, SupportsWildcards) {
|
||||
const auto config_value_maybe =
|
||||
GetAccelerationTestParam<SimpleConfig>("test-41");
|
||||
ASSERT_TRUE(config_value_maybe.has_value());
|
||||
ASSERT_THAT(config_value_maybe.value().value, Eq("data-4"));
|
||||
}
|
||||
|
||||
TEST(GetAccelerationTestParam, SupportBlacklist) {
|
||||
const auto config_value_maybe =
|
||||
GetAccelerationTestParam<SimpleConfig>("test-5");
|
||||
ASSERT_FALSE(config_value_maybe.has_value());
|
||||
}
|
||||
|
||||
struct UnmatchedSimpleConfig {
|
||||
public:
|
||||
static constexpr const char* kAccelerationTestConfig = nullptr;
|
||||
|
||||
static UnmatchedSimpleConfig ParseConfigurationLine(
|
||||
const std::string& conf_line) {
|
||||
return {conf_line};
|
||||
}
|
||||
|
||||
std::string value;
|
||||
};
|
||||
|
||||
TEST(GetAccelerationTestParam, ReturnEmptyOptionalForNullConfig) {
|
||||
ASSERT_FALSE(
|
||||
GetAccelerationTestParam<UnmatchedSimpleConfig>("test-3").has_value());
|
||||
}
|
||||
|
||||
} // namespace tflite
|
@ -14,8 +14,18 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/acceleration_test_util.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -236,4 +246,62 @@ std::vector<string> SingleOpModel::ExtractVector(int index) const {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns the number of partitions associated, as result of a call to
|
||||
// ModifyGraphWithDelegate, to the given delegate.
|
||||
int CountPartitionsDelegatedTo(Subgraph* subgraph,
|
||||
const TfLiteDelegate* delegate) {
|
||||
return std::count_if(
|
||||
subgraph->nodes_and_registration().begin(),
|
||||
subgraph->nodes_and_registration().end(),
|
||||
[delegate](
|
||||
std::pair<TfLiteNode, TfLiteRegistration> node_and_registration) {
|
||||
return node_and_registration.first.delegate == delegate;
|
||||
});
|
||||
}
|
||||
|
||||
// Returns the number of partitions associated, as result of a call to
|
||||
// ModifyGraphWithDelegate, to the given delegate.
|
||||
int CountPartitionsDelegatedTo(Interpreter* interpreter,
|
||||
const TfLiteDelegate* delegate) {
|
||||
int result = 0;
|
||||
for (int i = 0; i < interpreter->subgraphs_size(); i++) {
|
||||
Subgraph* subgraph = interpreter->subgraph(i);
|
||||
|
||||
result += CountPartitionsDelegatedTo(subgraph, delegate);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) {
|
||||
absl::optional<NnapiAccelerationTestParams> validation_params =
|
||||
GetNnapiAccelerationTestParam(test_id);
|
||||
if (!validation_params.has_value()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const NnApi* nnapi = NnApiImplementation();
|
||||
if (nnapi && nnapi->nnapi_exists &&
|
||||
nnapi->android_sdk_version >=
|
||||
validation_params.value().MinAndroidSdkVersion()) {
|
||||
EXPECT_EQ(
|
||||
CountPartitionsDelegatedTo(interpreter_.get(), TestNnApiDelegate()), 1)
|
||||
<< "Expecting operation to be accelerated but cannot find a partition "
|
||||
"associated to the NNAPI delegate";
|
||||
}
|
||||
}
|
||||
|
||||
void SingleOpModel::ValidateAcceleration() {
|
||||
if (force_use_nnapi) {
|
||||
ExpectOpAcceleratedWithNnapi(GetCurrentTestId());
|
||||
}
|
||||
}
|
||||
|
||||
SingleOpModel::~SingleOpModel() { ValidateAcceleration(); }
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
@ -142,7 +143,7 @@ class SingleOpResolver : public OpResolver {
|
||||
class SingleOpModel {
|
||||
public:
|
||||
SingleOpModel() {}
|
||||
~SingleOpModel() {}
|
||||
~SingleOpModel();
|
||||
|
||||
// Set a function callback that is run right after graph is prepared
|
||||
// that allows applying external delegates. This is useful for testing
|
||||
@ -549,6 +550,34 @@ class SingleOpModel {
|
||||
return q;
|
||||
}
|
||||
|
||||
// Checks if acceleration has been done as expected.
|
||||
// Currently supports only NNAPI.
|
||||
// It verifies if the test was configured to run with NNAPI acceleration
|
||||
// or not (SetForceUseNnapi(true)).
|
||||
// In affirmative case it checks if:
|
||||
// - the test case has been listed in the list of nnapi-accelerated cases
|
||||
// - the test is running on a device (NNAPI has been loaded)
|
||||
//
|
||||
// The list of nnapi-accelerated test cases is a file containing regex to
|
||||
// include or exclude specific test cases plus the minimum android SDK version
|
||||
// the acceleration should be enabled for. For example:
|
||||
// To enable the test BorderFloat in TopKV2OpTest only from
|
||||
// android_sdk_version 29:
|
||||
//
|
||||
// TopKV2OpTest/BorderFloat,29
|
||||
//
|
||||
// And to have it always excluded while enabling all other Float tests
|
||||
// (the order of the rules is important, the first one matching is used):
|
||||
//
|
||||
// -TopKV2OpTest/BorderFloat
|
||||
// TopKV2OpTest/.+Float
|
||||
|
||||
void ValidateAcceleration();
|
||||
|
||||
// If the test was configured to use NNAPI and NNAPI was actually loaded,
|
||||
// checks if the single operation in the model has been accelerated.
|
||||
void ExpectOpAcceleratedWithNnapi(const std::string& test_id);
|
||||
|
||||
std::map<int, TensorData> tensor_data_;
|
||||
std::vector<int32_t> inputs_;
|
||||
std::vector<int32_t> outputs_;
|
||||
|
Loading…
Reference in New Issue
Block a user