Reimplement Coral plugin in Acceleration@Scale using delegate
PiperOrigin-RevId: 355061584 Change-Id: I9ba1837fcd73356a990fb647edd16747f6eb0a48
This commit is contained in:
parent
f3e7ae3965
commit
636990dd7a
tensorflow/lite/experimental/acceleration/configuration
@ -96,12 +96,11 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "delegate_registry",
|
||||
srcs = ["delegate_registry.cc"],
|
||||
hdrs = ["delegate_registry.h"],
|
||||
deps = [
|
||||
":configuration_fbs",
|
||||
"//tensorflow/lite:mutable_op_resolver",
|
||||
"//tensorflow/lite/c:common",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
@ -44,20 +44,16 @@ enum ExecutionPreference {
|
||||
FORCE_CPU = 3;
|
||||
}
|
||||
|
||||
// TFLite accelerator to use. It can be either a delegate or an external
|
||||
// context.
|
||||
// TFLite accelerator to use.
|
||||
enum Delegate {
|
||||
NONE = 0;
|
||||
|
||||
// DELEGATE OPTIONS.
|
||||
NNAPI = 1;
|
||||
GPU = 2;
|
||||
HEXAGON = 3;
|
||||
XNNPACK = 4;
|
||||
// The EdgeTpu in Pixel devices.
|
||||
EDGETPU = 5;
|
||||
|
||||
// EXTERNAL CONTEXT OPTIONS.
|
||||
// The Coral EdgeTpu Dev Board / USB accelerator.
|
||||
EDGETPU_CORAL = 6;
|
||||
}
|
||||
@ -263,10 +259,10 @@ message EdgeTpuSettings {
|
||||
optional EdgeTpuDeviceSpec edgetpu_device_spec = 4;
|
||||
}
|
||||
|
||||
// Coral Dev Board / USB accelerator external context settings.
|
||||
// Coral Dev Board / USB accelerator delegate settings.
|
||||
//
|
||||
// See
|
||||
// https://github.com/google-coral/edgetpu/blob/master/libedgetpu/edgetpu.h
|
||||
// https://github.com/google-coral/edgetpu/blob/master/libedgetpu/edgetpu_c.h
|
||||
message CoralSettings {
|
||||
enum Performance {
|
||||
UNDEFINED = 0;
|
||||
@ -320,7 +316,7 @@ message TFLiteSettings {
|
||||
// For configuring the EdgeTpuDelegate.
|
||||
optional EdgeTpuSettings edgetpu_settings = 8;
|
||||
|
||||
// For configuring the Coral External Context (EdgeTpuContext).
|
||||
// For configuring the Coral EdgeTpu Delegate.
|
||||
optional CoralSettings coral_settings = 10;
|
||||
|
||||
// Whether to automatically fall back to TFLite CPU path.
|
||||
|
@ -0,0 +1,56 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
|
||||
void DelegatePluginRegistry::RegisterImpl(
|
||||
const std::string& name,
|
||||
std::function<
|
||||
std::unique_ptr<DelegatePluginInterface>(const TFLiteSettings&)>
|
||||
creator_function) {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
factories_[name] = creator_function;
|
||||
}
|
||||
|
||||
std::unique_ptr<DelegatePluginInterface> DelegatePluginRegistry::CreateImpl(
|
||||
const std::string& name, const TFLiteSettings& settings) {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
auto it = factories_.find(name);
|
||||
return (it != factories_.end()) ? it->second(settings) : nullptr;
|
||||
}
|
||||
|
||||
DelegatePluginRegistry* DelegatePluginRegistry::GetSingleton() {
|
||||
static auto* instance = new DelegatePluginRegistry();
|
||||
return instance;
|
||||
}
|
||||
|
||||
std::unique_ptr<DelegatePluginInterface> DelegatePluginRegistry::CreateByName(
|
||||
const std::string& name, const TFLiteSettings& settings) {
|
||||
auto* const instance = DelegatePluginRegistry::GetSingleton();
|
||||
return instance->CreateImpl(name, settings);
|
||||
}
|
||||
|
||||
DelegatePluginRegistry::Register::Register(const std::string& name,
|
||||
CreatorFunction creator_function) {
|
||||
auto* const instance = DelegatePluginRegistry::GetSingleton();
|
||||
instance->RegisterImpl(name, creator_function);
|
||||
}
|
||||
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
@ -18,11 +18,9 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
|
||||
// Defines an interface for TFLite delegate plugins.
|
||||
//
|
||||
@ -46,101 +44,53 @@ namespace delegates {
|
||||
using TfLiteDelegatePtr =
|
||||
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
|
||||
|
||||
// A shared pointer to `TfLiteExternalContext`, similar to `TfLiteDelegatePtr`.
|
||||
using TfLiteExternalContextPtr = std::shared_ptr<TfLiteExternalContext>;
|
||||
|
||||
template <typename AcceleratorType, typename AcceleratorPtrType>
|
||||
class AcceleratorPluginInterface {
|
||||
class DelegatePluginInterface {
|
||||
public:
|
||||
virtual AcceleratorPtrType Create() = 0;
|
||||
// Some accelerators require their own custom ops, such as the Coral plugin.
|
||||
// Default to an empty MutableOpResolver.
|
||||
virtual std::unique_ptr<MutableOpResolver> CreateOpResolver() {
|
||||
return absl::make_unique<MutableOpResolver>();
|
||||
}
|
||||
virtual int GetDelegateErrno(AcceleratorType* from_delegate) = 0;
|
||||
virtual ~AcceleratorPluginInterface() = default;
|
||||
virtual TfLiteDelegatePtr Create() = 0;
|
||||
virtual int GetDelegateErrno(TfLiteDelegate* from_delegate) = 0;
|
||||
virtual ~DelegatePluginInterface() = default;
|
||||
};
|
||||
|
||||
// `AcceleratorPluginInterface` implemented for `TfLiteDelegate`.
|
||||
using DelegatePluginInterface =
|
||||
AcceleratorPluginInterface<TfLiteDelegate, TfLiteDelegatePtr>;
|
||||
// `AcceleratorPluginInterface` implemented for `TfLiteExternalContext`.
|
||||
using ContextPluginInterface =
|
||||
AcceleratorPluginInterface<TfLiteExternalContext, TfLiteExternalContextPtr>;
|
||||
|
||||
// A stripped-down registry that allows delegate plugins to be created by name.
|
||||
//
|
||||
// Limitations:
|
||||
// - Doesn't allow deregistration.
|
||||
// - Doesn't check for duplication registration.
|
||||
//
|
||||
template <typename AcceleratorType, typename AcceleratorPtrType>
|
||||
class AcceleratorRegistry {
|
||||
class DelegatePluginRegistry {
|
||||
public:
|
||||
typedef std::function<std::unique_ptr<AcceleratorPluginInterface<
|
||||
AcceleratorType, AcceleratorPtrType>>(const TFLiteSettings&)>
|
||||
typedef std::function<std::unique_ptr<DelegatePluginInterface>(
|
||||
const TFLiteSettings&)>
|
||||
CreatorFunction;
|
||||
// Returns a AcceleratorPluginInterface registered with `name` or nullptr if
|
||||
// no matching plugin found. TFLiteSettings is per-plugin, so that the
|
||||
// corresponding delegate options data lifetime is maintained.
|
||||
static std::unique_ptr<
|
||||
AcceleratorPluginInterface<AcceleratorType, AcceleratorPtrType>>
|
||||
CreateByName(const std::string& name, const TFLiteSettings& settings) {
|
||||
auto* const instance = AcceleratorRegistry::GetSingleton();
|
||||
return instance->CreateImpl(name, settings);
|
||||
}
|
||||
// Returns a DelegatePluginInterface registered with `name` or nullptr if no
|
||||
// matching plugin found.
|
||||
// TFLiteSettings is per-plugin, so that the corresponding delegate options
|
||||
// data lifetime is maintained.
|
||||
static std::unique_ptr<DelegatePluginInterface> CreateByName(
|
||||
const std::string& name, const TFLiteSettings& settings);
|
||||
|
||||
// Struct to be statically allocated for registration.
|
||||
struct Register {
|
||||
Register(const std::string& name, CreatorFunction creator_function) {
|
||||
auto* const instance = AcceleratorRegistry::GetSingleton();
|
||||
instance->RegisterImpl(name, creator_function);
|
||||
}
|
||||
Register(const std::string& name, CreatorFunction creator_function);
|
||||
};
|
||||
|
||||
private:
|
||||
void RegisterImpl(const std::string& name, CreatorFunction creator_function) {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
factories_[name] = creator_function;
|
||||
}
|
||||
|
||||
std::unique_ptr<
|
||||
AcceleratorPluginInterface<AcceleratorType, AcceleratorPtrType>>
|
||||
CreateImpl(const std::string& name, const TFLiteSettings& settings) {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
auto it = factories_.find(name);
|
||||
return (it != factories_.end()) ? it->second(settings) : nullptr;
|
||||
}
|
||||
|
||||
static AcceleratorRegistry* GetSingleton() {
|
||||
static auto* instance = new AcceleratorRegistry();
|
||||
return instance;
|
||||
}
|
||||
|
||||
void RegisterImpl(const std::string& name, CreatorFunction creator_function);
|
||||
std::unique_ptr<DelegatePluginInterface> CreateImpl(
|
||||
const std::string& name, const TFLiteSettings& settings);
|
||||
static DelegatePluginRegistry* GetSingleton();
|
||||
absl::Mutex mutex_;
|
||||
std::unordered_map<std::string, CreatorFunction> factories_
|
||||
ABSL_GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
using DelegatePluginRegistry =
|
||||
AcceleratorRegistry<TfLiteDelegate, TfLiteDelegatePtr>;
|
||||
using ContextPluginRegistry =
|
||||
AcceleratorRegistry<TfLiteExternalContext, TfLiteExternalContextPtr>;
|
||||
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
||||
|
||||
#define TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME( \
|
||||
name, f, accelerator_type, accelerator_ptr_type) \
|
||||
static auto* g_delegate_plugin_##name##_ = \
|
||||
new AcceleratorRegistry<accelerator_type, \
|
||||
accelerator_ptr_type>::Register(#name, f);
|
||||
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(name, f) \
|
||||
TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME(name, f, TfLiteDelegate, \
|
||||
TfLiteDelegatePtr);
|
||||
#define TFLITE_REGISTER_EXTERNAL_CONTEXT_FACTORY_FUNCTION(name, f) \
|
||||
TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME( \
|
||||
name, f, TfLiteExternalContext, TfLiteExternalContextPtr);
|
||||
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION_VNAME(name, f) \
|
||||
static auto* g_delegate_plugin_##name##_ = \
|
||||
new DelegatePluginRegistry::Register(#name, f);
|
||||
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(name, f) \
|
||||
TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION_VNAME(name, f);
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_DELEGATE_REGISTRY_H_
|
||||
|
Loading…
Reference in New Issue
Block a user