Remove the redundant GetDelegates function, as we can now get a delegate instance from the registered delegate providers.

PiperOrigin-RevId: 301533071
Change-Id: Ia2a5c80523e5f6d1246898aafaefa20228b2bdb6
commit 2d11ad74ca (parent da95968a57)
Author: Chao Mei, committed by TensorFlower Gardener
Date: 2020-03-18 00:00:58 -07:00

2 changed files with 10 additions and 27 deletions
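For context on the pattern this change relies on, below is a minimal, self-contained sketch of a delegate-provider registry. The names here (FakeDelegate, DelegateProviderSketch, GpuLikeProvider, Registry) are invented for illustration; the real TFLite benchmark code uses its own DelegateProvider interface, registration machinery, and Interpreter::TfLiteDelegatePtr rather than these stand-ins.

    // Illustrative sketch only; not the actual TFLite benchmark API.
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Stand-in for TfLiteDelegate; the real benchmark holds delegates via a
    // unique_ptr with a custom deleter (Interpreter::TfLiteDelegatePtr).
    struct FakeDelegate {
      std::string name;
    };
    using FakeDelegatePtr = std::unique_ptr<FakeDelegate>;

    // Each provider knows how to build one kind of delegate from the
    // benchmark parameters.
    class DelegateProviderSketch {
     public:
      virtual ~DelegateProviderSketch() = default;
      virtual std::string GetName() const = 0;
      // May return nullptr when the user's params disable this delegate.
      virtual FakeDelegatePtr CreateTfLiteDelegate(bool enabled) const = 0;
    };

    // Global registry that GetRegisteredDelegateProviders()-style code walks.
    std::vector<std::unique_ptr<DelegateProviderSketch>>& Registry() {
      static auto* providers =
          new std::vector<std::unique_ptr<DelegateProviderSketch>>();
      return *providers;
    }

    class GpuLikeProvider : public DelegateProviderSketch {
     public:
      std::string GetName() const override { return "GPU-like"; }
      FakeDelegatePtr CreateTfLiteDelegate(bool enabled) const override {
        if (!enabled) return nullptr;  // mirrors the `continue` in the new loop
        return std::make_unique<FakeDelegate>(FakeDelegate{GetName()});
      }
    };

    int main() {
      Registry().push_back(std::make_unique<GpuLikeProvider>());
      // The caller no longer builds a name->delegate map up front; it simply
      // iterates the registered providers and applies whatever they create.
      for (const auto& provider : Registry()) {
        auto delegate = provider->CreateTfLiteDelegate(/*enabled=*/true);
        if (delegate == nullptr) continue;
        std::cout << "Applied " << provider->GetName() << " delegate.\n";
      }
      return 0;
    }

The point of the pattern, which the diff below reflects, is that each provider decides from the benchmark params whether to build its delegate at all (returning nullptr otherwise), so the caller no longer needs the precomputed name-to-delegate map that the removed GetDelegates() produced.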

tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc

@@ -615,11 +615,14 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
   interpreter_->UseNNAPI(params_.Get<bool>("use_legacy_nnapi"));
   interpreter_->SetAllowFp16PrecisionForFp32(params_.Get<bool>("allow_fp16"));
 
-  delegates_ = GetDelegates();
-  for (const auto& delegate : delegates_) {
-    if (interpreter_->ModifyGraphWithDelegate(delegate.second.get()) !=
-        kTfLiteOk) {
-      TFLITE_LOG(ERROR) << "Failed to apply " << delegate.first << " delegate.";
+  for (const auto& delegate_provider : GetRegisteredDelegateProviders()) {
+    auto delegate = delegate_provider->CreateTfLiteDelegate(params_);
+    // It's possible that a delegate of certain type won't be created as
+    // user-specified benchmark params tells not to.
+    if (delegate == nullptr) continue;
+    if (interpreter_->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
+      TFLITE_LOG(ERROR) << "Failed to apply " << delegate_provider->GetName()
+                        << " delegate.";
       return kTfLiteError;
     } else {
       bool fully_delegated = true;
@@ -629,7 +632,7 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
         int first_node_id = interpreter_->execution_plan()[0];
         const TfLiteNode first_node =
             interpreter_->node_and_registration(first_node_id)->first;
-        if (delegate.second.get() != first_node.delegate) {
+        if (delegate.get() != first_node.delegate) {
          fully_delegated = false;
        }
      }
@@ -639,7 +642,7 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
      }
      const std::string delegate_status =
          fully_delegated ? "completely" : "partially";
-     TFLITE_LOG(INFO) << "Applied " << delegate.first
+     TFLITE_LOG(INFO) << "Applied " << delegate_provider->GetName()
                       << " delegate, and the model graph will be "
                       << delegate_status << " executed w/ the delegate.";
    }
@@ -698,19 +701,6 @@ TfLiteStatus BenchmarkTfLiteModel::LoadModel() {
   return kTfLiteOk;
 }
 
-BenchmarkTfLiteModel::TfLiteDelegatePtrMap BenchmarkTfLiteModel::GetDelegates()
-    const {
-  TfLiteDelegatePtrMap delegates;
-  for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
-    auto delegate = delegate_util->CreateTfLiteDelegate(params_);
-    if (delegate != nullptr) {
-      delegates.emplace(delegate_util->GetName(), std::move(delegate));
-    }
-  }
-  return delegates;
-}
-
 std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
     const {
   auto resolver = new tflite::ops::builtin::BuiltinOpResolver();

tensorflow/lite/tools/benchmark/benchmark_tflite_model.h

@@ -69,11 +69,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
   int64_t MayGetModelFileSize() override;
 
-  // Allow subclasses to create custom delegates to be applied during init.
-  using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
-  using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;
-  virtual TfLiteDelegatePtrMap GetDelegates() const;
-
   virtual TfLiteStatus LoadModel();
 
   // Allow subclasses to create a customized Op resolver during init.
@@ -123,8 +118,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
   std::vector<InputTensorData> inputs_data_;
   std::unique_ptr<BenchmarkListener> profiling_listener_ = nullptr;
   std::unique_ptr<BenchmarkListener> ruy_profiling_listener_ = nullptr;
-
-  TfLiteDelegatePtrMap delegates_;
 
   std::mt19937 random_engine_;
 };