Reimplement Coral plugin in Acceleration@Scale using delegate
PiperOrigin-RevId: 355061584 Change-Id: I9ba1837fcd73356a990fb647edd16747f6eb0a48
This commit is contained in:
parent
f3e7ae3965
commit
636990dd7a
@ -96,12 +96,11 @@ cc_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "delegate_registry",
|
name = "delegate_registry",
|
||||||
|
srcs = ["delegate_registry.cc"],
|
||||||
hdrs = ["delegate_registry.h"],
|
hdrs = ["delegate_registry.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":configuration_fbs",
|
":configuration_fbs",
|
||||||
"//tensorflow/lite:mutable_op_resolver",
|
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -44,20 +44,16 @@ enum ExecutionPreference {
|
|||||||
FORCE_CPU = 3;
|
FORCE_CPU = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TFLite accelerator to use. It can be either a delegate or an external
|
// TFLite accelerator to use.
|
||||||
// context.
|
|
||||||
enum Delegate {
|
enum Delegate {
|
||||||
NONE = 0;
|
NONE = 0;
|
||||||
|
|
||||||
// DELEGATE OPTIONS.
|
|
||||||
NNAPI = 1;
|
NNAPI = 1;
|
||||||
GPU = 2;
|
GPU = 2;
|
||||||
HEXAGON = 3;
|
HEXAGON = 3;
|
||||||
XNNPACK = 4;
|
XNNPACK = 4;
|
||||||
// The EdgeTpu in Pixel devices.
|
// The EdgeTpu in Pixel devices.
|
||||||
EDGETPU = 5;
|
EDGETPU = 5;
|
||||||
|
|
||||||
// EXTERNAL CONTEXT OPTIONS.
|
|
||||||
// The Coral EdgeTpu Dev Board / USB accelerator.
|
// The Coral EdgeTpu Dev Board / USB accelerator.
|
||||||
EDGETPU_CORAL = 6;
|
EDGETPU_CORAL = 6;
|
||||||
}
|
}
|
||||||
@ -263,10 +259,10 @@ message EdgeTpuSettings {
|
|||||||
optional EdgeTpuDeviceSpec edgetpu_device_spec = 4;
|
optional EdgeTpuDeviceSpec edgetpu_device_spec = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Coral Dev Board / USB accelerator external context settings.
|
// Coral Dev Board / USB accelerator delegate settings.
|
||||||
//
|
//
|
||||||
// See
|
// See
|
||||||
// https://github.com/google-coral/edgetpu/blob/master/libedgetpu/edgetpu.h
|
// https://github.com/google-coral/edgetpu/blob/master/libedgetpu/edgetpu_c.h
|
||||||
message CoralSettings {
|
message CoralSettings {
|
||||||
enum Performance {
|
enum Performance {
|
||||||
UNDEFINED = 0;
|
UNDEFINED = 0;
|
||||||
@ -320,7 +316,7 @@ message TFLiteSettings {
|
|||||||
// For configuring the EdgeTpuDelegate.
|
// For configuring the EdgeTpuDelegate.
|
||||||
optional EdgeTpuSettings edgetpu_settings = 8;
|
optional EdgeTpuSettings edgetpu_settings = 8;
|
||||||
|
|
||||||
// For configuring the Coral External Context (EdgeTpuContext).
|
// For configuring the Coral EdgeTpu Delegate.
|
||||||
optional CoralSettings coral_settings = 10;
|
optional CoralSettings coral_settings = 10;
|
||||||
|
|
||||||
// Whether to automatically fall back to TFLite CPU path.
|
// Whether to automatically fall back to TFLite CPU path.
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
|
||||||
|
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace delegates {
|
||||||
|
|
||||||
|
void DelegatePluginRegistry::RegisterImpl(
|
||||||
|
const std::string& name,
|
||||||
|
std::function<
|
||||||
|
std::unique_ptr<DelegatePluginInterface>(const TFLiteSettings&)>
|
||||||
|
creator_function) {
|
||||||
|
absl::MutexLock lock(&mutex_);
|
||||||
|
factories_[name] = creator_function;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<DelegatePluginInterface> DelegatePluginRegistry::CreateImpl(
|
||||||
|
const std::string& name, const TFLiteSettings& settings) {
|
||||||
|
absl::MutexLock lock(&mutex_);
|
||||||
|
auto it = factories_.find(name);
|
||||||
|
return (it != factories_.end()) ? it->second(settings) : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
DelegatePluginRegistry* DelegatePluginRegistry::GetSingleton() {
|
||||||
|
static auto* instance = new DelegatePluginRegistry();
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<DelegatePluginInterface> DelegatePluginRegistry::CreateByName(
|
||||||
|
const std::string& name, const TFLiteSettings& settings) {
|
||||||
|
auto* const instance = DelegatePluginRegistry::GetSingleton();
|
||||||
|
return instance->CreateImpl(name, settings);
|
||||||
|
}
|
||||||
|
|
||||||
|
DelegatePluginRegistry::Register::Register(const std::string& name,
|
||||||
|
CreatorFunction creator_function) {
|
||||||
|
auto* const instance = DelegatePluginRegistry::GetSingleton();
|
||||||
|
instance->RegisterImpl(name, creator_function);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace delegates
|
||||||
|
} // namespace tflite
|
@ -18,11 +18,9 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
|
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
|
||||||
|
|
||||||
// Defines an interface for TFLite delegate plugins.
|
// Defines an interface for TFLite delegate plugins.
|
||||||
//
|
//
|
||||||
@ -46,101 +44,53 @@ namespace delegates {
|
|||||||
using TfLiteDelegatePtr =
|
using TfLiteDelegatePtr =
|
||||||
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
|
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
|
||||||
|
|
||||||
// A shared pointer to `TfLiteExternalContext`, similar to `TfLiteDelegatePtr`.
|
class DelegatePluginInterface {
|
||||||
using TfLiteExternalContextPtr = std::shared_ptr<TfLiteExternalContext>;
|
|
||||||
|
|
||||||
template <typename AcceleratorType, typename AcceleratorPtrType>
|
|
||||||
class AcceleratorPluginInterface {
|
|
||||||
public:
|
public:
|
||||||
virtual AcceleratorPtrType Create() = 0;
|
virtual TfLiteDelegatePtr Create() = 0;
|
||||||
// Some accelerators require their own custom ops, such as the Coral plugin.
|
virtual int GetDelegateErrno(TfLiteDelegate* from_delegate) = 0;
|
||||||
// Default to an empty MutableOpResolver.
|
virtual ~DelegatePluginInterface() = default;
|
||||||
virtual std::unique_ptr<MutableOpResolver> CreateOpResolver() {
|
|
||||||
return absl::make_unique<MutableOpResolver>();
|
|
||||||
}
|
|
||||||
virtual int GetDelegateErrno(AcceleratorType* from_delegate) = 0;
|
|
||||||
virtual ~AcceleratorPluginInterface() = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// `AcceleratorPluginInterface` implemented for `TfLiteDelegate`.
|
|
||||||
using DelegatePluginInterface =
|
|
||||||
AcceleratorPluginInterface<TfLiteDelegate, TfLiteDelegatePtr>;
|
|
||||||
// `AcceleratorPluginInterface` implemented for `TfLiteExternalContext`.
|
|
||||||
using ContextPluginInterface =
|
|
||||||
AcceleratorPluginInterface<TfLiteExternalContext, TfLiteExternalContextPtr>;
|
|
||||||
|
|
||||||
// A stripped-down registry that allows delegate plugins to be created by name.
|
// A stripped-down registry that allows delegate plugins to be created by name.
|
||||||
//
|
//
|
||||||
// Limitations:
|
// Limitations:
|
||||||
// - Doesn't allow deregistration.
|
// - Doesn't allow deregistration.
|
||||||
// - Doesn't check for duplication registration.
|
// - Doesn't check for duplication registration.
|
||||||
//
|
//
|
||||||
template <typename AcceleratorType, typename AcceleratorPtrType>
|
class DelegatePluginRegistry {
|
||||||
class AcceleratorRegistry {
|
|
||||||
public:
|
public:
|
||||||
typedef std::function<std::unique_ptr<AcceleratorPluginInterface<
|
typedef std::function<std::unique_ptr<DelegatePluginInterface>(
|
||||||
AcceleratorType, AcceleratorPtrType>>(const TFLiteSettings&)>
|
const TFLiteSettings&)>
|
||||||
CreatorFunction;
|
CreatorFunction;
|
||||||
// Returns a AcceleratorPluginInterface registered with `name` or nullptr if
|
// Returns a DelegatePluginInterface registered with `name` or nullptr if no
|
||||||
// no matching plugin found. TFLiteSettings is per-plugin, so that the
|
// matching plugin found.
|
||||||
// corresponding delegate options data lifetime is maintained.
|
// TFLiteSettings is per-plugin, so that the corresponding delegate options
|
||||||
static std::unique_ptr<
|
// data lifetime is maintained.
|
||||||
AcceleratorPluginInterface<AcceleratorType, AcceleratorPtrType>>
|
static std::unique_ptr<DelegatePluginInterface> CreateByName(
|
||||||
CreateByName(const std::string& name, const TFLiteSettings& settings) {
|
const std::string& name, const TFLiteSettings& settings);
|
||||||
auto* const instance = AcceleratorRegistry::GetSingleton();
|
|
||||||
return instance->CreateImpl(name, settings);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Struct to be statically allocated for registration.
|
// Struct to be statically allocated for registration.
|
||||||
struct Register {
|
struct Register {
|
||||||
Register(const std::string& name, CreatorFunction creator_function) {
|
Register(const std::string& name, CreatorFunction creator_function);
|
||||||
auto* const instance = AcceleratorRegistry::GetSingleton();
|
|
||||||
instance->RegisterImpl(name, creator_function);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void RegisterImpl(const std::string& name, CreatorFunction creator_function) {
|
void RegisterImpl(const std::string& name, CreatorFunction creator_function);
|
||||||
absl::MutexLock lock(&mutex_);
|
std::unique_ptr<DelegatePluginInterface> CreateImpl(
|
||||||
factories_[name] = creator_function;
|
const std::string& name, const TFLiteSettings& settings);
|
||||||
}
|
static DelegatePluginRegistry* GetSingleton();
|
||||||
|
|
||||||
std::unique_ptr<
|
|
||||||
AcceleratorPluginInterface<AcceleratorType, AcceleratorPtrType>>
|
|
||||||
CreateImpl(const std::string& name, const TFLiteSettings& settings) {
|
|
||||||
absl::MutexLock lock(&mutex_);
|
|
||||||
auto it = factories_.find(name);
|
|
||||||
return (it != factories_.end()) ? it->second(settings) : nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
static AcceleratorRegistry* GetSingleton() {
|
|
||||||
static auto* instance = new AcceleratorRegistry();
|
|
||||||
return instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Mutex mutex_;
|
absl::Mutex mutex_;
|
||||||
std::unordered_map<std::string, CreatorFunction> factories_
|
std::unordered_map<std::string, CreatorFunction> factories_
|
||||||
ABSL_GUARDED_BY(mutex_);
|
ABSL_GUARDED_BY(mutex_);
|
||||||
};
|
};
|
||||||
|
|
||||||
using DelegatePluginRegistry =
|
|
||||||
AcceleratorRegistry<TfLiteDelegate, TfLiteDelegatePtr>;
|
|
||||||
using ContextPluginRegistry =
|
|
||||||
AcceleratorRegistry<TfLiteExternalContext, TfLiteExternalContextPtr>;
|
|
||||||
|
|
||||||
} // namespace delegates
|
} // namespace delegates
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#define TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME( \
|
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION_VNAME(name, f) \
|
||||||
name, f, accelerator_type, accelerator_ptr_type) \
|
static auto* g_delegate_plugin_##name##_ = \
|
||||||
static auto* g_delegate_plugin_##name##_ = \
|
new DelegatePluginRegistry::Register(#name, f);
|
||||||
new AcceleratorRegistry<accelerator_type, \
|
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(name, f) \
|
||||||
accelerator_ptr_type>::Register(#name, f);
|
TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION_VNAME(name, f);
|
||||||
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(name, f) \
|
|
||||||
TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME(name, f, TfLiteDelegate, \
|
|
||||||
TfLiteDelegatePtr);
|
|
||||||
#define TFLITE_REGISTER_EXTERNAL_CONTEXT_FACTORY_FUNCTION(name, f) \
|
|
||||||
TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME( \
|
|
||||||
name, f, TfLiteExternalContext, TfLiteExternalContextPtr);
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_DELEGATE_REGISTRY_H_
|
#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_DELEGATE_REGISTRY_H_
|
||||||
|
Loading…
x
Reference in New Issue
Block a user