TFLu: Move Ethos-U custom op out of AllOpsResolver

This commit is contained in:
Måns Nilsson 2020-11-17 14:19:11 +01:00
parent b858de3779
commit 7f79b013c3
5 changed files with 25 additions and 40 deletions

View File

@ -15,14 +15,6 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/micro_ops.h"
namespace tflite {
namespace ops {
namespace micro {
namespace custom {
TfLiteRegistration* Register_ETHOSU();
const char* GetString_ETHOSU();
} // namespace custom
} // namespace micro
} // namespace ops
AllOpsResolver::AllOpsResolver() {
// Please keep this list of Builtin Operators in alphabetical order.
@ -38,6 +30,7 @@ AllOpsResolver::AllOpsResolver() {
AddDepthwiseConv2D();
AddDequantize();
AddEqual();
AddEthosU();
AddFloor();
AddFullyConnected();
AddGreater();
@ -82,13 +75,6 @@ AllOpsResolver::AllOpsResolver() {
AddSvdf();
AddTanh();
AddUnpack();
// TODO(b/159644355): Figure out if custom Ops belong in AllOpsResolver.
TfLiteRegistration* registration =
tflite::ops::micro::custom::Register_ETHOSU();
if (registration) {
AddCustom(tflite::ops::micro::custom::GetString_ETHOSU(), registration);
}
}
} // namespace tflite

View File

@ -19,10 +19,10 @@ required as compiler as well.
| tensor1
|
v
+---------+
| softmax |
| |
+----|----+
+-----------+
| transpose |
| |
+----|------+
|
| tensor2
|

View File

@ -17,13 +17,10 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/tools/make/downloads/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "flatbuffers/flexbuffers.h"
namespace tflite {
namespace ops {
namespace micro {
namespace custom {
namespace ethosu {
namespace {
constexpr uint8_t CO_TYPE_ETHOSU = 1;
@ -93,7 +90,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tensor = context->GetEvalTensor(context, node->inputs->data[0]);
cms_data = reinterpret_cast<void*>(tensor->data.uint8);
// Get adresses to weights/scratch/input data
// Get addresses to weights/scratch/input data
for (i = 1; i < node->inputs->size; ++i) {
tensor = context->GetEvalTensor(context, node->inputs->data[i]);
base_addrs[num_tensors] = reinterpret_cast<uint64_t>(tensor->data.uint8);
@ -101,7 +98,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
num_tensors++;
}
// Get adresses to output data
// Get addresses to output data
for (i = 0; i < node->outputs->size; ++i) {
tensor = context->GetEvalTensor(context, node->outputs->data[i]);
base_addrs[num_tensors] = reinterpret_cast<uint64_t>(tensor->data.uint8);
@ -122,13 +119,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
}
} // namespace ethosu
} // namespace
TfLiteRegistration* Register_ETHOSU() {
static TfLiteRegistration r = {ethosu::Init,
ethosu::Free,
ethosu::Prepare,
ethosu::Eval,
static TfLiteRegistration r = {Init,
Free,
Prepare,
Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
@ -138,7 +135,4 @@ TfLiteRegistration* Register_ETHOSU() {
const char* GetString_ETHOSU() { return "ethos-u"; }
} // namespace custom
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -19,14 +19,9 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace custom {
TfLiteRegistration* Register_ETHOSU() { return nullptr; }
const char* GetString_ETHOSU() { return ""; }
} // namespace custom
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -30,6 +30,8 @@ limitations under the License.
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
TfLiteRegistration* Register_ETHOSU();
const char* GetString_ETHOSU();
template <unsigned int tOpCount>
class MicroMutableOpResolver : public MicroOpResolver {
@ -175,6 +177,14 @@ class MicroMutableOpResolver : public MicroOpResolver {
tflite::ops::micro::Register_EQUAL(), ParseEqual);
}
TfLiteStatus AddEthosU() {
TfLiteRegistration* registration = tflite::Register_ETHOSU();
if (registration) {
return AddCustom(tflite::GetString_ETHOSU(), registration);
}
return kTfLiteOk;
}
TfLiteStatus AddFloor() {
return AddBuiltin(BuiltinOperator_FLOOR,
tflite::ops::micro::Register_FLOOR(), ParseFloor);