TFLu: Move Ethos-U custom op out of AllOpsResolver
This commit is contained in:
parent
b858de3779
commit
7f79b013c3
@ -15,14 +15,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/micro/kernels/micro_ops.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace custom {
|
||||
TfLiteRegistration* Register_ETHOSU();
|
||||
const char* GetString_ETHOSU();
|
||||
} // namespace custom
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
|
||||
AllOpsResolver::AllOpsResolver() {
|
||||
// Please keep this list of Builtin Operators in alphabetical order.
|
||||
@ -38,6 +30,7 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddDepthwiseConv2D();
|
||||
AddDequantize();
|
||||
AddEqual();
|
||||
AddEthosU();
|
||||
AddFloor();
|
||||
AddFullyConnected();
|
||||
AddGreater();
|
||||
@ -82,13 +75,6 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddSvdf();
|
||||
AddTanh();
|
||||
AddUnpack();
|
||||
|
||||
// TODO(b/159644355): Figure out if custom Ops belong in AllOpsResolver.
|
||||
TfLiteRegistration* registration =
|
||||
tflite::ops::micro::custom::Register_ETHOSU();
|
||||
if (registration) {
|
||||
AddCustom(tflite::ops::micro::custom::GetString_ETHOSU(), registration);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -19,10 +19,10 @@ required as compiler as well.
|
||||
| tensor1
|
||||
|
|
||||
v
|
||||
+---------+
|
||||
| softmax |
|
||||
| |
|
||||
+----|----+
|
||||
+-----------+
|
||||
| transpose |
|
||||
| |
|
||||
+----|------+
|
||||
|
|
||||
| tensor2
|
||||
|
|
||||
|
@ -17,13 +17,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/tools/make/downloads/flatbuffers/include/flatbuffers/flexbuffers.h"
|
||||
#include "flatbuffers/flexbuffers.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace custom {
|
||||
namespace ethosu {
|
||||
namespace {
|
||||
|
||||
constexpr uint8_t CO_TYPE_ETHOSU = 1;
|
||||
|
||||
@ -93,7 +90,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tensor = context->GetEvalTensor(context, node->inputs->data[0]);
|
||||
cms_data = reinterpret_cast<void*>(tensor->data.uint8);
|
||||
|
||||
// Get adresses to weights/scratch/input data
|
||||
// Get addresses to weights/scratch/input data
|
||||
for (i = 1; i < node->inputs->size; ++i) {
|
||||
tensor = context->GetEvalTensor(context, node->inputs->data[i]);
|
||||
base_addrs[num_tensors] = reinterpret_cast<uint64_t>(tensor->data.uint8);
|
||||
@ -101,7 +98,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
num_tensors++;
|
||||
}
|
||||
|
||||
// Get adresses to output data
|
||||
// Get addresses to output data
|
||||
for (i = 0; i < node->outputs->size; ++i) {
|
||||
tensor = context->GetEvalTensor(context, node->outputs->data[i]);
|
||||
base_addrs[num_tensors] = reinterpret_cast<uint64_t>(tensor->data.uint8);
|
||||
@ -122,13 +119,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ethosu
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration* Register_ETHOSU() {
|
||||
static TfLiteRegistration r = {ethosu::Init,
|
||||
ethosu::Free,
|
||||
ethosu::Prepare,
|
||||
ethosu::Eval,
|
||||
static TfLiteRegistration r = {Init,
|
||||
Free,
|
||||
Prepare,
|
||||
Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
@ -138,7 +135,4 @@ TfLiteRegistration* Register_ETHOSU() {
|
||||
|
||||
const char* GetString_ETHOSU() { return "ethos-u"; }
|
||||
|
||||
} // namespace custom
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -19,14 +19,9 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_ETHOSU() { return nullptr; }
|
||||
|
||||
const char* GetString_ETHOSU() { return ""; }
|
||||
|
||||
} // namespace custom
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -30,6 +30,8 @@ limitations under the License.
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
TfLiteRegistration* Register_ETHOSU();
|
||||
const char* GetString_ETHOSU();
|
||||
|
||||
template <unsigned int tOpCount>
|
||||
class MicroMutableOpResolver : public MicroOpResolver {
|
||||
@ -175,6 +177,14 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_EQUAL(), ParseEqual);
|
||||
}
|
||||
|
||||
TfLiteStatus AddEthosU() {
|
||||
TfLiteRegistration* registration = tflite::Register_ETHOSU();
|
||||
if (registration) {
|
||||
return AddCustom(tflite::GetString_ETHOSU(), registration);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus AddFloor() {
|
||||
return AddBuiltin(BuiltinOperator_FLOOR,
|
||||
tflite::ops::micro::Register_FLOOR(), ParseFloor);
|
||||
|
Loading…
x
Reference in New Issue
Block a user