From 7f79b013c30219b1db19540939d917a49b1c31ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Tue, 17 Nov 2020 14:19:11 +0100 Subject: [PATCH] TFLu: Move Ethos-U custom op out of AllOpsResolver --- tensorflow/lite/micro/all_ops_resolver.cc | 16 +------------ .../lite/micro/kernels/ethos-u/README.md | 8 +++---- .../lite/micro/kernels/ethos-u/ethosu.cc | 24 +++++++------------ tensorflow/lite/micro/kernels/ethosu.cc | 7 +----- .../lite/micro/micro_mutable_op_resolver.h | 10 ++++++++ 5 files changed, 25 insertions(+), 40 deletions(-) diff --git a/tensorflow/lite/micro/all_ops_resolver.cc b/tensorflow/lite/micro/all_ops_resolver.cc index 0a2a0c0f7fe..b538708d309 100644 --- a/tensorflow/lite/micro/all_ops_resolver.cc +++ b/tensorflow/lite/micro/all_ops_resolver.cc @@ -15,14 +15,6 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/micro_ops.h" namespace tflite { -namespace ops { -namespace micro { -namespace custom { -TfLiteRegistration* Register_ETHOSU(); -const char* GetString_ETHOSU(); -} // namespace custom -} // namespace micro -} // namespace ops AllOpsResolver::AllOpsResolver() { // Please keep this list of Builtin Operators in alphabetical order. @@ -38,6 +30,7 @@ AllOpsResolver::AllOpsResolver() { AddDepthwiseConv2D(); AddDequantize(); AddEqual(); + AddEthosU(); AddFloor(); AddFullyConnected(); AddGreater(); @@ -82,13 +75,6 @@ AllOpsResolver::AllOpsResolver() { AddSvdf(); AddTanh(); AddUnpack(); - - // TODO(b/159644355): Figure out if custom Ops belong in AllOpsResolver. - TfLiteRegistration* registration = - tflite::ops::micro::custom::Register_ETHOSU(); - if (registration) { - AddCustom(tflite::ops::micro::custom::GetString_ETHOSU(), registration); - } } } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/ethos-u/README.md b/tensorflow/lite/micro/kernels/ethos-u/README.md index becf270e4a0..8deb8f271d8 100644 --- a/tensorflow/lite/micro/kernels/ethos-u/README.md +++ b/tensorflow/lite/micro/kernels/ethos-u/README.md @@ -19,10 +19,10 @@ required as compiler as well. | tensor1 | v -+---------+ -| softmax | -| | -+----|----+ ++-----------+ +| transpose | +| | ++----|------+ | | tensor2 | diff --git a/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc b/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc index 05b44714773..122c72dd1c0 100644 --- a/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc +++ b/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc @@ -17,13 +17,10 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/tools/make/downloads/flatbuffers/include/flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" namespace tflite { -namespace ops { -namespace micro { -namespace custom { -namespace ethosu { +namespace { constexpr uint8_t CO_TYPE_ETHOSU = 1; @@ -93,7 +90,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tensor = context->GetEvalTensor(context, node->inputs->data[0]); cms_data = reinterpret_cast(tensor->data.uint8); - // Get adresses to weights/scratch/input data + // Get addresses to weights/scratch/input data for (i = 1; i < node->inputs->size; ++i) { tensor = context->GetEvalTensor(context, node->inputs->data[i]); base_addrs[num_tensors] = reinterpret_cast(tensor->data.uint8); @@ -101,7 +98,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { num_tensors++; } - // Get adresses to output data + // Get addresses to output data for (i = 0; i < node->outputs->size; ++i) { tensor = context->GetEvalTensor(context, node->outputs->data[i]); base_addrs[num_tensors] = reinterpret_cast(tensor->data.uint8); @@ -122,13 +119,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } -} // namespace ethosu +} // namespace TfLiteRegistration* Register_ETHOSU() { - static TfLiteRegistration r = {ethosu::Init, - ethosu::Free, - ethosu::Prepare, - ethosu::Eval, + static TfLiteRegistration r = {Init, + Free, + Prepare, + Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, @@ -138,7 +135,4 @@ TfLiteRegistration* Register_ETHOSU() { const char* GetString_ETHOSU() { return "ethos-u"; } -} // namespace custom -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/ethosu.cc b/tensorflow/lite/micro/kernels/ethosu.cc index eac6cea8324..c305121e87f 100644 --- a/tensorflow/lite/micro/kernels/ethosu.cc +++ b/tensorflow/lite/micro/kernels/ethosu.cc @@ -19,14 +19,9 @@ limitations under the License. #include "tensorflow/lite/c/common.h" namespace tflite { -namespace ops { -namespace micro { -namespace custom { + TfLiteRegistration* Register_ETHOSU() { return nullptr; } const char* GetString_ETHOSU() { return ""; } -} // namespace custom -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index 0175c8dbd6a..ef5e7a77f02 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -30,6 +30,8 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { +TfLiteRegistration* Register_ETHOSU(); +const char* GetString_ETHOSU(); template class MicroMutableOpResolver : public MicroOpResolver { @@ -175,6 +177,14 @@ class MicroMutableOpResolver : public MicroOpResolver { tflite::ops::micro::Register_EQUAL(), ParseEqual); } + TfLiteStatus AddEthosU() { + TfLiteRegistration* registration = tflite::Register_ETHOSU(); + if (registration) { + return AddCustom(tflite::GetString_ETHOSU(), registration); + } + return kTfLiteOk; + } + TfLiteStatus AddFloor() { return AddBuiltin(BuiltinOperator_FLOOR, tflite::ops::micro::Register_FLOOR(), ParseFloor);