Reimplement Coral plugin in Acceleration@Scale using delegate

PiperOrigin-RevId: 355061584
Change-Id: I9ba1837fcd73356a990fb647edd16747f6eb0a48
This commit is contained in:
Lu Wang 2021-02-01 17:24:21 -08:00 committed by TensorFlower Gardener
parent f3e7ae3965
commit 636990dd7a
4 changed files with 84 additions and 83 deletions
tensorflow/lite/experimental/acceleration/configuration

View File

@ -96,12 +96,11 @@ cc_library(
cc_library(
name = "delegate_registry",
srcs = ["delegate_registry.cc"],
hdrs = ["delegate_registry.h"],
deps = [
":configuration_fbs",
"//tensorflow/lite:mutable_op_resolver",
"//tensorflow/lite/c:common",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)

View File

@ -44,20 +44,16 @@ enum ExecutionPreference {
FORCE_CPU = 3;
}
// TFLite accelerator to use. It can be either a delegate or an external
// context.
// TFLite accelerator to use.
enum Delegate {
NONE = 0;
// DELEGATE OPTIONS.
NNAPI = 1;
GPU = 2;
HEXAGON = 3;
XNNPACK = 4;
// The EdgeTpu in Pixel devices.
EDGETPU = 5;
// EXTERNAL CONTEXT OPTIONS.
// The Coral EdgeTpu Dev Board / USB accelerator.
EDGETPU_CORAL = 6;
}
@ -263,10 +259,10 @@ message EdgeTpuSettings {
optional EdgeTpuDeviceSpec edgetpu_device_spec = 4;
}
// Coral Dev Board / USB accelerator external context settings.
// Coral Dev Board / USB accelerator delegate settings.
//
// See
// https://github.com/google-coral/edgetpu/blob/master/libedgetpu/edgetpu.h
// https://github.com/google-coral/edgetpu/blob/master/libedgetpu/edgetpu_c.h
message CoralSettings {
enum Performance {
UNDEFINED = 0;
@ -320,7 +316,7 @@ message TFLiteSettings {
// For configuring the EdgeTpuDelegate.
optional EdgeTpuSettings edgetpu_settings = 8;
// For configuring the Coral External Context (EdgeTpuContext).
// For configuring the Coral EdgeTpu Delegate.
optional CoralSettings coral_settings = 10;
// Whether to automatically fall back to TFLite CPU path.

View File

@ -0,0 +1,56 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
#include "absl/synchronization/mutex.h"
namespace tflite {
namespace delegates {
void DelegatePluginRegistry::RegisterImpl(
const std::string& name,
std::function<
std::unique_ptr<DelegatePluginInterface>(const TFLiteSettings&)>
creator_function) {
absl::MutexLock lock(&mutex_);
factories_[name] = creator_function;
}
std::unique_ptr<DelegatePluginInterface> DelegatePluginRegistry::CreateImpl(
const std::string& name, const TFLiteSettings& settings) {
absl::MutexLock lock(&mutex_);
auto it = factories_.find(name);
return (it != factories_.end()) ? it->second(settings) : nullptr;
}
DelegatePluginRegistry* DelegatePluginRegistry::GetSingleton() {
static auto* instance = new DelegatePluginRegistry();
return instance;
}
std::unique_ptr<DelegatePluginInterface> DelegatePluginRegistry::CreateByName(
const std::string& name, const TFLiteSettings& settings) {
auto* const instance = DelegatePluginRegistry::GetSingleton();
return instance->CreateImpl(name, settings);
}
DelegatePluginRegistry::Register::Register(const std::string& name,
CreatorFunction creator_function) {
auto* const instance = DelegatePluginRegistry::GetSingleton();
instance->RegisterImpl(name, creator_function);
}
} // namespace delegates
} // namespace tflite

View File

@ -18,11 +18,9 @@ limitations under the License.
#include <memory>
#include <unordered_map>
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
#include "tensorflow/lite/mutable_op_resolver.h"
// Defines an interface for TFLite delegate plugins.
//
@ -46,101 +44,53 @@ namespace delegates {
using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
// A shared pointer to `TfLiteExternalContext`, similar to `TfLiteDelegatePtr`.
using TfLiteExternalContextPtr = std::shared_ptr<TfLiteExternalContext>;
template <typename AcceleratorType, typename AcceleratorPtrType>
class AcceleratorPluginInterface {
class DelegatePluginInterface {
public:
virtual AcceleratorPtrType Create() = 0;
// Some accelerators require their own custom ops, such as the Coral plugin.
// Default to an empty MutableOpResolver.
virtual std::unique_ptr<MutableOpResolver> CreateOpResolver() {
return absl::make_unique<MutableOpResolver>();
}
virtual int GetDelegateErrno(AcceleratorType* from_delegate) = 0;
virtual ~AcceleratorPluginInterface() = default;
virtual TfLiteDelegatePtr Create() = 0;
virtual int GetDelegateErrno(TfLiteDelegate* from_delegate) = 0;
virtual ~DelegatePluginInterface() = default;
};
// `AcceleratorPluginInterface` implemented for `TfLiteDelegate`.
using DelegatePluginInterface =
AcceleratorPluginInterface<TfLiteDelegate, TfLiteDelegatePtr>;
// `AcceleratorPluginInterface` implemented for `TfLiteExternalContext`.
using ContextPluginInterface =
AcceleratorPluginInterface<TfLiteExternalContext, TfLiteExternalContextPtr>;
// A stripped-down registry that allows delegate plugins to be created by name.
//
// Limitations:
// - Doesn't allow deregistration.
// - Doesn't check for duplication registration.
//
template <typename AcceleratorType, typename AcceleratorPtrType>
class AcceleratorRegistry {
class DelegatePluginRegistry {
public:
typedef std::function<std::unique_ptr<AcceleratorPluginInterface<
AcceleratorType, AcceleratorPtrType>>(const TFLiteSettings&)>
typedef std::function<std::unique_ptr<DelegatePluginInterface>(
const TFLiteSettings&)>
CreatorFunction;
// Returns a AcceleratorPluginInterface registered with `name` or nullptr if
// no matching plugin found. TFLiteSettings is per-plugin, so that the
// corresponding delegate options data lifetime is maintained.
static std::unique_ptr<
AcceleratorPluginInterface<AcceleratorType, AcceleratorPtrType>>
CreateByName(const std::string& name, const TFLiteSettings& settings) {
auto* const instance = AcceleratorRegistry::GetSingleton();
return instance->CreateImpl(name, settings);
}
// Returns a DelegatePluginInterface registered with `name` or nullptr if no
// matching plugin found.
// TFLiteSettings is per-plugin, so that the corresponding delegate options
// data lifetime is maintained.
static std::unique_ptr<DelegatePluginInterface> CreateByName(
const std::string& name, const TFLiteSettings& settings);
// Struct to be statically allocated for registration.
struct Register {
Register(const std::string& name, CreatorFunction creator_function) {
auto* const instance = AcceleratorRegistry::GetSingleton();
instance->RegisterImpl(name, creator_function);
}
Register(const std::string& name, CreatorFunction creator_function);
};
private:
void RegisterImpl(const std::string& name, CreatorFunction creator_function) {
absl::MutexLock lock(&mutex_);
factories_[name] = creator_function;
}
std::unique_ptr<
AcceleratorPluginInterface<AcceleratorType, AcceleratorPtrType>>
CreateImpl(const std::string& name, const TFLiteSettings& settings) {
absl::MutexLock lock(&mutex_);
auto it = factories_.find(name);
return (it != factories_.end()) ? it->second(settings) : nullptr;
}
static AcceleratorRegistry* GetSingleton() {
static auto* instance = new AcceleratorRegistry();
return instance;
}
void RegisterImpl(const std::string& name, CreatorFunction creator_function);
std::unique_ptr<DelegatePluginInterface> CreateImpl(
const std::string& name, const TFLiteSettings& settings);
static DelegatePluginRegistry* GetSingleton();
absl::Mutex mutex_;
std::unordered_map<std::string, CreatorFunction> factories_
ABSL_GUARDED_BY(mutex_);
};
using DelegatePluginRegistry =
AcceleratorRegistry<TfLiteDelegate, TfLiteDelegatePtr>;
using ContextPluginRegistry =
AcceleratorRegistry<TfLiteExternalContext, TfLiteExternalContextPtr>;
} // namespace delegates
} // namespace tflite
#define TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME( \
name, f, accelerator_type, accelerator_ptr_type) \
static auto* g_delegate_plugin_##name##_ = \
new AcceleratorRegistry<accelerator_type, \
accelerator_ptr_type>::Register(#name, f);
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(name, f) \
TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME(name, f, TfLiteDelegate, \
TfLiteDelegatePtr);
#define TFLITE_REGISTER_EXTERNAL_CONTEXT_FACTORY_FUNCTION(name, f) \
TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME( \
name, f, TfLiteExternalContext, TfLiteExternalContextPtr);
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION_VNAME(name, f) \
static auto* g_delegate_plugin_##name##_ = \
new DelegatePluginRegistry::Register(#name, f);
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(name, f) \
TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION_VNAME(name, f);
#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_DELEGATE_REGISTRY_H_