Remove defined(__ANDROID__) in nnapi_delegate_provider as it's been handled by the NNAPI delegate implementation.

PiperOrigin-RevId: 320321656
Change-Id: Ie41f85039f8784d13f59e692c3f0d43411d0aba6
This commit is contained in:
Chao Mei 2020-07-08 20:46:45 -07:00 committed by TensorFlower Gardener
parent 5232497868
commit 30f6b22eb4
4 changed files with 37 additions and 36 deletions

View File

@ -23,9 +23,16 @@ TfLiteDelegate* NnApiDelegate() {
return delegate; return delegate;
} }
StatefulNnApiDelegate::StatefulNnApiDelegate(const NnApi* /* nnapi */)
: StatefulNnApiDelegate() {}
StatefulNnApiDelegate::StatefulNnApiDelegate(Options /* options */) StatefulNnApiDelegate::StatefulNnApiDelegate(Options /* options */)
: StatefulNnApiDelegate() {} : StatefulNnApiDelegate() {}
StatefulNnApiDelegate::StatefulNnApiDelegate(const NnApi* /* nnapi */,
Options /* options */)
: StatefulNnApiDelegate() {}
StatefulNnApiDelegate::StatefulNnApiDelegate() StatefulNnApiDelegate::StatefulNnApiDelegate()
: TfLiteDelegate(TfLiteDelegateCreate()), : TfLiteDelegate(TfLiteDelegateCreate()),
delegate_data_(/*nnapi=*/nullptr) { delegate_data_(/*nnapi=*/nullptr) {

View File

@ -81,7 +81,9 @@ cc_library(
copts = common_copts, copts = common_copts,
deps = [ deps = [
":delegate_provider_hdr", ":delegate_provider_hdr",
"//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"//tensorflow/lite/nnapi:nnapi_implementation",
"//tensorflow/lite/nnapi:nnapi_util",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -14,11 +14,10 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <string> #include <string>
#include "tensorflow/lite/tools/delegates/delegate_provider.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/tools/evaluation/utils.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h"
#if defined(__ANDROID__)
#include "tensorflow/lite/nnapi/nnapi_util.h" #include "tensorflow/lite/nnapi/nnapi_util.h"
#endif #include "tensorflow/lite/tools/delegates/delegate_provider.h"
namespace tflite { namespace tflite {
namespace tools { namespace tools {
@ -26,7 +25,6 @@ namespace tools {
class NnapiDelegateProvider : public DelegateProvider { class NnapiDelegateProvider : public DelegateProvider {
public: public:
NnapiDelegateProvider() { NnapiDelegateProvider() {
#if defined(__ANDROID__)
default_params_.AddParam("use_nnapi", ToolParam::Create<bool>(false)); default_params_.AddParam("use_nnapi", ToolParam::Create<bool>(false));
default_params_.AddParam("nnapi_execution_preference", default_params_.AddParam("nnapi_execution_preference",
ToolParam::Create<std::string>("")); ToolParam::Create<std::string>(""));
@ -38,7 +36,6 @@ class NnapiDelegateProvider : public DelegateProvider {
ToolParam::Create<bool>(false)); ToolParam::Create<bool>(false));
default_params_.AddParam("nnapi_allow_fp16", default_params_.AddParam("nnapi_allow_fp16",
ToolParam::Create<bool>(false)); ToolParam::Create<bool>(false));
#endif
} }
std::vector<Flag> CreateFlags(ToolParams* params) const final; std::vector<Flag> CreateFlags(ToolParams* params) const final;
@ -53,32 +50,28 @@ REGISTER_DELEGATE_PROVIDER(NnapiDelegateProvider);
std::vector<Flag> NnapiDelegateProvider::CreateFlags(ToolParams* params) const { std::vector<Flag> NnapiDelegateProvider::CreateFlags(ToolParams* params) const {
std::vector<Flag> flags = { std::vector<Flag> flags = {
#if defined(__ANDROID__) CreateFlag<bool>("use_nnapi", params, "use nnapi delegate api"),
CreateFlag<bool>("use_nnapi", params, "use nnapi delegate api"), CreateFlag<std::string>("nnapi_execution_preference", params,
CreateFlag<std::string>("nnapi_execution_preference", params, "execution preference for nnapi delegate. Should "
"execution preference for nnapi delegate. Should " "be one of the following: fast_single_answer, "
"be one of the following: fast_single_answer, " "sustained_speed, low_power, undefined"),
"sustained_speed, low_power, undefined"), CreateFlag<std::string>("nnapi_execution_priority", params,
CreateFlag<std::string>("nnapi_execution_priority", params, "The model execution priority in nnapi, and it "
"The model execution priority in nnapi, and it " "should be one of the following: default, low, "
"should be one of the following: default, low, " "medium and high. This requires Android 11+."),
"medium and high. This requires Android 11+."), CreateFlag<std::string>(
CreateFlag<std::string>( "nnapi_accelerator_name", params,
"nnapi_accelerator_name", params, "the name of the nnapi accelerator to use (requires Android Q+)"),
"the name of the nnapi accelerator to use (requires Android Q+)"), CreateFlag<bool>("disable_nnapi_cpu", params,
CreateFlag<bool>("disable_nnapi_cpu", params, "Disable the NNAPI CPU device"),
"Disable the NNAPI CPU device"), CreateFlag<bool>("nnapi_allow_fp16", params,
CreateFlag<bool>("nnapi_allow_fp16", params, "Allow fp32 computation to be run in fp16")};
"Allow fp32 computation to be run in fp16")
#endif
};
return flags; return flags;
} }
void NnapiDelegateProvider::LogParams(const ToolParams& params, void NnapiDelegateProvider::LogParams(const ToolParams& params,
bool verbose) const { bool verbose) const {
#if defined(__ANDROID__)
LOG_TOOL_PARAM(params, bool, "use_nnapi", "Use NNAPI", verbose); LOG_TOOL_PARAM(params, bool, "use_nnapi", "Use NNAPI", verbose);
if (!params.Get<bool>("use_nnapi")) return; if (!params.Get<bool>("use_nnapi")) return;
@ -100,13 +93,11 @@ void NnapiDelegateProvider::LogParams(const ToolParams& params,
verbose); verbose);
LOG_TOOL_PARAM(params, bool, "nnapi_allow_fp16", "Allow fp16 in NNAPI", LOG_TOOL_PARAM(params, bool, "nnapi_allow_fp16", "Allow fp16 in NNAPI",
verbose); verbose);
#endif
} }
TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate( TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate(
const ToolParams& params) const { const ToolParams& params) const {
TfLiteDelegatePtr delegate(nullptr, [](TfLiteDelegate*) {}); TfLiteDelegatePtr delegate(nullptr, [](TfLiteDelegate*) {});
#if defined(__ANDROID__)
if (params.Get<bool>("use_nnapi")) { if (params.Get<bool>("use_nnapi")) {
StatefulNnApiDelegate::Options options; StatefulNnApiDelegate::Options options;
std::string accelerator_name = std::string accelerator_name =
@ -174,10 +165,16 @@ TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate(
if (max_delegated_partitions >= 0) { if (max_delegated_partitions >= 0) {
options.max_number_delegated_partitions = max_delegated_partitions; options.max_number_delegated_partitions = max_delegated_partitions;
} }
delegate = evaluation::CreateNNAPIDelegate(options); const auto* nnapi_impl = NnApiImplementation();
if (!delegate.get()) { if (!nnapi_impl->nnapi_exists) {
TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform."; TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform.";
return delegate;
} }
return TfLiteDelegatePtr(
new StatefulNnApiDelegate(nnapi_impl, options),
[](TfLiteDelegate* delegate) {
delete reinterpret_cast<StatefulNnApiDelegate*>(delegate);
});
} else if (!params.Get<std::string>("nnapi_accelerator_name").empty()) { } else if (!params.Get<std::string>("nnapi_accelerator_name").empty()) {
TFLITE_LOG(WARN) TFLITE_LOG(WARN)
<< "`--use_nnapi=true` must be set for the provided NNAPI accelerator (" << "`--use_nnapi=true` must be set for the provided NNAPI accelerator ("
@ -188,7 +185,6 @@ TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate(
<< params.Get<std::string>("nnapi_execution_preference") << params.Get<std::string>("nnapi_execution_preference")
<< ") to be used."; << ") to be used.";
} }
#endif
return delegate; return delegate;
} }

View File

@ -43,11 +43,7 @@ TEST(EvaluationDelegateProviderTest, CreateTfLiteDelegate) {
TEST(EvaluationDelegateProviderTest, DelegateProvidersParams) { TEST(EvaluationDelegateProviderTest, DelegateProvidersParams) {
DelegateProviders providers; DelegateProviders providers;
const auto& params = providers.GetAllParams(); const auto& params = providers.GetAllParams();
#if defined(__ANDROID__)
EXPECT_TRUE(params.HasParam("use_nnapi")); EXPECT_TRUE(params.HasParam("use_nnapi"));
#else
EXPECT_FALSE(params.HasParam("use_nnapi"));
#endif
EXPECT_TRUE(params.HasParam("use_gpu")); EXPECT_TRUE(params.HasParam("use_gpu"));
int argc = 3; int argc = 3;