Reimplement Coral plugin in Acceleration@Scale using delegate

PiperOrigin-RevId: 355061584
Change-Id: I9ba1837fcd73356a990fb647edd16747f6eb0a48
This commit is contained in:
Lu Wang 2021-02-01 17:24:21 -08:00 committed by TensorFlower Gardener
parent f3e7ae3965
commit 636990dd7a
4 changed files with 84 additions and 83 deletions

View File

@ -96,12 +96,11 @@ cc_library(
cc_library( cc_library(
name = "delegate_registry", name = "delegate_registry",
srcs = ["delegate_registry.cc"],
hdrs = ["delegate_registry.h"], hdrs = ["delegate_registry.h"],
deps = [ deps = [
":configuration_fbs", ":configuration_fbs",
"//tensorflow/lite:mutable_op_resolver",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
], ],
) )

View File

@ -44,20 +44,16 @@ enum ExecutionPreference {
FORCE_CPU = 3; FORCE_CPU = 3;
} }
// TFLite accelerator to use. It can be either a delegate or an external // TFLite accelerator to use.
// context.
enum Delegate { enum Delegate {
NONE = 0; NONE = 0;
// DELEGATE OPTIONS.
NNAPI = 1; NNAPI = 1;
GPU = 2; GPU = 2;
HEXAGON = 3; HEXAGON = 3;
XNNPACK = 4; XNNPACK = 4;
// The EdgeTpu in Pixel devices. // The EdgeTpu in Pixel devices.
EDGETPU = 5; EDGETPU = 5;
// EXTERNAL CONTEXT OPTIONS.
// The Coral EdgeTpu Dev Board / USB accelerator. // The Coral EdgeTpu Dev Board / USB accelerator.
EDGETPU_CORAL = 6; EDGETPU_CORAL = 6;
} }
@ -263,10 +259,10 @@ message EdgeTpuSettings {
optional EdgeTpuDeviceSpec edgetpu_device_spec = 4; optional EdgeTpuDeviceSpec edgetpu_device_spec = 4;
} }
// Coral Dev Board / USB accelerator external context settings. // Coral Dev Board / USB accelerator delegate settings.
// //
// See // See
// https://github.com/google-coral/edgetpu/blob/master/libedgetpu/edgetpu.h // https://github.com/google-coral/edgetpu/blob/master/libedgetpu/edgetpu_c.h
message CoralSettings { message CoralSettings {
enum Performance { enum Performance {
UNDEFINED = 0; UNDEFINED = 0;
@ -320,7 +316,7 @@ message TFLiteSettings {
// For configuring the EdgeTpuDelegate. // For configuring the EdgeTpuDelegate.
optional EdgeTpuSettings edgetpu_settings = 8; optional EdgeTpuSettings edgetpu_settings = 8;
// For configuring the Coral External Context (EdgeTpuContext). // For configuring the Coral EdgeTpu Delegate.
optional CoralSettings coral_settings = 10; optional CoralSettings coral_settings = 10;
// Whether to automatically fall back to TFLite CPU path. // Whether to automatically fall back to TFLite CPU path.

View File

@ -0,0 +1,56 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
#include "absl/synchronization/mutex.h"
namespace tflite {
namespace delegates {
void DelegatePluginRegistry::RegisterImpl(
const std::string& name,
std::function<
std::unique_ptr<DelegatePluginInterface>(const TFLiteSettings&)>
creator_function) {
absl::MutexLock lock(&mutex_);
factories_[name] = creator_function;
}
std::unique_ptr<DelegatePluginInterface> DelegatePluginRegistry::CreateImpl(
const std::string& name, const TFLiteSettings& settings) {
absl::MutexLock lock(&mutex_);
auto it = factories_.find(name);
return (it != factories_.end()) ? it->second(settings) : nullptr;
}
DelegatePluginRegistry* DelegatePluginRegistry::GetSingleton() {
static auto* instance = new DelegatePluginRegistry();
return instance;
}
std::unique_ptr<DelegatePluginInterface> DelegatePluginRegistry::CreateByName(
const std::string& name, const TFLiteSettings& settings) {
auto* const instance = DelegatePluginRegistry::GetSingleton();
return instance->CreateImpl(name, settings);
}
DelegatePluginRegistry::Register::Register(const std::string& name,
CreatorFunction creator_function) {
auto* const instance = DelegatePluginRegistry::GetSingleton();
instance->RegisterImpl(name, creator_function);
}
} // namespace delegates
} // namespace tflite

View File

@ -18,11 +18,9 @@ limitations under the License.
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
#include "tensorflow/lite/mutable_op_resolver.h"
// Defines an interface for TFLite delegate plugins. // Defines an interface for TFLite delegate plugins.
// //
@ -46,101 +44,53 @@ namespace delegates {
using TfLiteDelegatePtr = using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>; std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
// A shared pointer to `TfLiteExternalContext`, similar to `TfLiteDelegatePtr`. class DelegatePluginInterface {
using TfLiteExternalContextPtr = std::shared_ptr<TfLiteExternalContext>;
template <typename AcceleratorType, typename AcceleratorPtrType>
class AcceleratorPluginInterface {
public: public:
virtual AcceleratorPtrType Create() = 0; virtual TfLiteDelegatePtr Create() = 0;
// Some accelerators require their own custom ops, such as the Coral plugin. virtual int GetDelegateErrno(TfLiteDelegate* from_delegate) = 0;
// Default to an empty MutableOpResolver. virtual ~DelegatePluginInterface() = default;
virtual std::unique_ptr<MutableOpResolver> CreateOpResolver() {
return absl::make_unique<MutableOpResolver>();
}
virtual int GetDelegateErrno(AcceleratorType* from_delegate) = 0;
virtual ~AcceleratorPluginInterface() = default;
}; };
// `AcceleratorPluginInterface` implemented for `TfLiteDelegate`.
using DelegatePluginInterface =
AcceleratorPluginInterface<TfLiteDelegate, TfLiteDelegatePtr>;
// `AcceleratorPluginInterface` implemented for `TfLiteExternalContext`.
using ContextPluginInterface =
AcceleratorPluginInterface<TfLiteExternalContext, TfLiteExternalContextPtr>;
// A stripped-down registry that allows delegate plugins to be created by name. // A stripped-down registry that allows delegate plugins to be created by name.
// //
// Limitations: // Limitations:
// - Doesn't allow deregistration. // - Doesn't allow deregistration.
// - Doesn't check for duplication registration. // - Doesn't check for duplication registration.
// //
template <typename AcceleratorType, typename AcceleratorPtrType> class DelegatePluginRegistry {
class AcceleratorRegistry {
public: public:
typedef std::function<std::unique_ptr<AcceleratorPluginInterface< typedef std::function<std::unique_ptr<DelegatePluginInterface>(
AcceleratorType, AcceleratorPtrType>>(const TFLiteSettings&)> const TFLiteSettings&)>
CreatorFunction; CreatorFunction;
// Returns a AcceleratorPluginInterface registered with `name` or nullptr if // Returns a DelegatePluginInterface registered with `name` or nullptr if no
// no matching plugin found. TFLiteSettings is per-plugin, so that the // matching plugin found.
// corresponding delegate options data lifetime is maintained. // TFLiteSettings is per-plugin, so that the corresponding delegate options
static std::unique_ptr< // data lifetime is maintained.
AcceleratorPluginInterface<AcceleratorType, AcceleratorPtrType>> static std::unique_ptr<DelegatePluginInterface> CreateByName(
CreateByName(const std::string& name, const TFLiteSettings& settings) { const std::string& name, const TFLiteSettings& settings);
auto* const instance = AcceleratorRegistry::GetSingleton();
return instance->CreateImpl(name, settings);
}
// Struct to be statically allocated for registration. // Struct to be statically allocated for registration.
struct Register { struct Register {
Register(const std::string& name, CreatorFunction creator_function) { Register(const std::string& name, CreatorFunction creator_function);
auto* const instance = AcceleratorRegistry::GetSingleton();
instance->RegisterImpl(name, creator_function);
}
}; };
private: private:
void RegisterImpl(const std::string& name, CreatorFunction creator_function) { void RegisterImpl(const std::string& name, CreatorFunction creator_function);
absl::MutexLock lock(&mutex_); std::unique_ptr<DelegatePluginInterface> CreateImpl(
factories_[name] = creator_function; const std::string& name, const TFLiteSettings& settings);
} static DelegatePluginRegistry* GetSingleton();
std::unique_ptr<
AcceleratorPluginInterface<AcceleratorType, AcceleratorPtrType>>
CreateImpl(const std::string& name, const TFLiteSettings& settings) {
absl::MutexLock lock(&mutex_);
auto it = factories_.find(name);
return (it != factories_.end()) ? it->second(settings) : nullptr;
}
static AcceleratorRegistry* GetSingleton() {
static auto* instance = new AcceleratorRegistry();
return instance;
}
absl::Mutex mutex_; absl::Mutex mutex_;
std::unordered_map<std::string, CreatorFunction> factories_ std::unordered_map<std::string, CreatorFunction> factories_
ABSL_GUARDED_BY(mutex_); ABSL_GUARDED_BY(mutex_);
}; };
using DelegatePluginRegistry =
AcceleratorRegistry<TfLiteDelegate, TfLiteDelegatePtr>;
using ContextPluginRegistry =
AcceleratorRegistry<TfLiteExternalContext, TfLiteExternalContextPtr>;
} // namespace delegates } // namespace delegates
} // namespace tflite } // namespace tflite
#define TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME( \ #define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION_VNAME(name, f) \
name, f, accelerator_type, accelerator_ptr_type) \ static auto* g_delegate_plugin_##name##_ = \
static auto* g_delegate_plugin_##name##_ = \ new DelegatePluginRegistry::Register(#name, f);
new AcceleratorRegistry<accelerator_type, \ #define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(name, f) \
accelerator_ptr_type>::Register(#name, f); TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION_VNAME(name, f);
#define TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(name, f) \
TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME(name, f, TfLiteDelegate, \
TfLiteDelegatePtr);
#define TFLITE_REGISTER_EXTERNAL_CONTEXT_FACTORY_FUNCTION(name, f) \
TFLITE_REGISTER_ACCELERATOR_FACTORY_FUNCTION_VNAME( \
name, f, TfLiteExternalContext, TfLiteExternalContextPtr);
#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_DELEGATE_REGISTRY_H_ #endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_CONFIGURATION_DELEGATE_REGISTRY_H_