Remove the redundant GetDelegates function as we could now get a delegate instance from registered delegate providers.

PiperOrigin-RevId: 301533071
Change-Id: Ia2a5c80523e5f6d1246898aafaefa20228b2bdb6
This commit is contained in:
Chao Mei 2020-03-18 00:00:58 -07:00 committed by TensorFlower Gardener
parent da95968a57
commit 2d11ad74ca
2 changed files with 10 additions and 27 deletions

View File

@ -615,11 +615,14 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
interpreter_->UseNNAPI(params_.Get<bool>("use_legacy_nnapi"));
interpreter_->SetAllowFp16PrecisionForFp32(params_.Get<bool>("allow_fp16"));
delegates_ = GetDelegates();
for (const auto& delegate : delegates_) {
if (interpreter_->ModifyGraphWithDelegate(delegate.second.get()) !=
kTfLiteOk) {
TFLITE_LOG(ERROR) << "Failed to apply " << delegate.first << " delegate.";
for (const auto& delegate_provider : GetRegisteredDelegateProviders()) {
auto delegate = delegate_provider->CreateTfLiteDelegate(params_);
// It's possible that a delegate of certain type won't be created as
// user-specified benchmark params tells not to.
if (delegate == nullptr) continue;
if (interpreter_->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
TFLITE_LOG(ERROR) << "Failed to apply " << delegate_provider->GetName()
<< " delegate.";
return kTfLiteError;
} else {
bool fully_delegated = true;
@ -629,7 +632,7 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
int first_node_id = interpreter_->execution_plan()[0];
const TfLiteNode first_node =
interpreter_->node_and_registration(first_node_id)->first;
if (delegate.second.get() != first_node.delegate) {
if (delegate.get() != first_node.delegate) {
fully_delegated = false;
}
}
@ -639,7 +642,7 @@ TfLiteStatus BenchmarkTfLiteModel::Init() {
}
const std::string delegate_status =
fully_delegated ? "completely" : "partially";
TFLITE_LOG(INFO) << "Applied " << delegate.first
TFLITE_LOG(INFO) << "Applied " << delegate_provider->GetName()
<< " delegate, and the model graph will be "
<< delegate_status << " executed w/ the delegate.";
}
@ -698,19 +701,6 @@ TfLiteStatus BenchmarkTfLiteModel::LoadModel() {
return kTfLiteOk;
}
BenchmarkTfLiteModel::TfLiteDelegatePtrMap BenchmarkTfLiteModel::GetDelegates()
const {
TfLiteDelegatePtrMap delegates;
for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
auto delegate = delegate_util->CreateTfLiteDelegate(params_);
if (delegate != nullptr) {
delegates.emplace(delegate_util->GetName(), std::move(delegate));
}
}
return delegates;
}
std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
const {
auto resolver = new tflite::ops::builtin::BuiltinOpResolver();

View File

@ -69,11 +69,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
int64_t MayGetModelFileSize() override;
// Allow subclasses to create custom delegates to be applied during init.
using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;
virtual TfLiteDelegatePtrMap GetDelegates() const;
virtual TfLiteStatus LoadModel();
// Allow subclasses to create a customized Op resolver during init.
@ -123,8 +118,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
std::vector<InputTensorData> inputs_data_;
std::unique_ptr<BenchmarkListener> profiling_listener_ = nullptr;
std::unique_ptr<BenchmarkListener> ruy_profiling_listener_ = nullptr;
TfLiteDelegatePtrMap delegates_;
std::mt19937 random_engine_;
};