Remove deprecated AddBuiltin API from MicroMutableOpResolver.

* Added new API hooks for all the OPs currently supported in TFLM.
* These new APIs still need to be implemented with operator specific parse
  functions but this change allows us to remove the old API and incrementally
  update the implementations.

PiperOrigin-RevId: 317770205
Change-Id: Idaaa687401f2bac5fbf9925e27c04bf536b154ea
This commit is contained in:
Advait Jain 2020-06-22 17:39:59 -07:00 committed by TensorFlower Gardener
parent fe6e64b098
commit 34b4fab30a
13 changed files with 408 additions and 214 deletions

View File

@ -26,74 +26,60 @@ const char* GetString_ETHOSU();
AllOpsResolver::AllOpsResolver() { AllOpsResolver::AllOpsResolver() {
// Please keep this list of Builtin Operators in alphabetical order. // Please keep this list of Builtin Operators in alphabetical order.
AddBuiltin(BuiltinOperator_ABS, tflite::ops::micro::Register_ABS()); AddAbs();
AddBuiltin(BuiltinOperator_ADD, tflite::ops::micro::Register_ADD()); AddAdd();
AddBuiltin(BuiltinOperator_ARG_MAX, tflite::ops::micro::Register_ARG_MAX()); AddArgMax();
AddBuiltin(BuiltinOperator_ARG_MIN, tflite::ops::micro::Register_ARG_MIN()); AddArgMin();
AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, AddAveragePool2D();
tflite::ops::micro::Register_AVERAGE_POOL_2D()); AddCeil();
AddBuiltin(BuiltinOperator_CEIL, tflite::ops::micro::Register_CEIL()); AddConcatenation();
AddBuiltin(BuiltinOperator_CONCATENATION, AddConv2D();
tflite::ops::micro::Register_CONCATENATION()); AddCos();
AddBuiltin(BuiltinOperator_CONV_2D, tflite::ops::micro::Register_CONV_2D()); AddDepthwiseConv2D();
AddBuiltin(BuiltinOperator_COS, tflite::ops::micro::Register_COS()); AddDequantize();
AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, AddEqual();
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); AddFloor();
AddBuiltin(BuiltinOperator_DEQUANTIZE, AddFullyConnected();
tflite::ops::micro::Register_DEQUANTIZE()); AddGreater();
AddBuiltin(BuiltinOperator_EQUAL, tflite::ops::micro::Register_EQUAL()); AddGreaterEqual();
AddBuiltin(BuiltinOperator_FLOOR, tflite::ops::micro::Register_FLOOR()); AddL2Normalization();
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, AddLess();
tflite::ops::micro::Register_FULLY_CONNECTED()); AddLessEqual();
AddBuiltin(BuiltinOperator_GREATER, tflite::ops::micro::Register_GREATER()); AddLog();
AddBuiltin(BuiltinOperator_GREATER_EQUAL, AddLogicalAnd();
tflite::ops::micro::Register_GREATER_EQUAL()); AddLogicalNot();
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, AddLogicalOr();
tflite::ops::micro::Register_L2_NORMALIZATION()); AddLogistic();
AddBuiltin(BuiltinOperator_LESS, tflite::ops::micro::Register_LESS()); AddMaximum();
AddBuiltin(BuiltinOperator_LESS_EQUAL, AddMaxPool2D();
tflite::ops::micro::Register_LESS_EQUAL()); AddMean();
AddBuiltin(BuiltinOperator_LOG, tflite::ops::micro::Register_LOG()); AddMinimum();
AddBuiltin(BuiltinOperator_LOGICAL_AND, AddMul();
tflite::ops::micro::Register_LOGICAL_AND()); AddNeg();
AddBuiltin(BuiltinOperator_LOGICAL_NOT, AddNotEqual();
tflite::ops::micro::Register_LOGICAL_NOT()); AddPack();
AddBuiltin(BuiltinOperator_LOGICAL_OR, AddPad();
tflite::ops::micro::Register_LOGICAL_OR()); AddPadV2();
AddBuiltin(BuiltinOperator_LOGISTIC, tflite::ops::micro::Register_LOGISTIC()); AddPrelu();
AddBuiltin(BuiltinOperator_MAX_POOL_2D, AddQuantize();
tflite::ops::micro::Register_MAX_POOL_2D()); AddRelu();
AddBuiltin(BuiltinOperator_MAXIMUM, tflite::ops::micro::Register_MAXIMUM()); AddRelu6();
AddBuiltin(BuiltinOperator_MEAN, tflite::ops::micro::Register_MEAN()); AddReshape();
AddBuiltin(BuiltinOperator_MINIMUM, tflite::ops::micro::Register_MINIMUM()); AddResizeNearestNeighbor();
AddBuiltin(BuiltinOperator_MUL, tflite::ops::micro::Register_MUL()); AddRound();
AddBuiltin(BuiltinOperator_NEG, tflite::ops::micro::Register_NEG()); AddRsqrt();
AddBuiltin(BuiltinOperator_NOT_EQUAL, AddSin();
tflite::ops::micro::Register_NOT_EQUAL()); AddSoftmax();
AddBuiltin(BuiltinOperator_PACK, tflite::ops::micro::Register_PACK()); AddSplit();
AddBuiltin(BuiltinOperator_PAD, tflite::ops::micro::Register_PAD()); AddSqrt();
AddBuiltin(BuiltinOperator_PADV2, tflite::ops::micro::Register_PADV2()); AddSquare();
AddBuiltin(BuiltinOperator_PRELU, tflite::ops::micro::Register_PRELU()); AddStridedSlice();
AddBuiltin(BuiltinOperator_QUANTIZE, tflite::ops::micro::Register_QUANTIZE()); AddSub();
AddBuiltin(BuiltinOperator_RELU, tflite::ops::micro::Register_RELU()); AddSvdf();
AddBuiltin(BuiltinOperator_RELU6, tflite::ops::micro::Register_RELU6()); AddTanh();
AddBuiltin(BuiltinOperator_RESHAPE, tflite::ops::micro::Register_RESHAPE()); AddUnpack();
AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
tflite::ops::micro::Register_RESIZE_NEAREST_NEIGHBOR());
AddBuiltin(BuiltinOperator_ROUND, tflite::ops::micro::Register_ROUND());
AddBuiltin(BuiltinOperator_RSQRT, tflite::ops::micro::Register_RSQRT());
AddBuiltin(BuiltinOperator_SIN, tflite::ops::micro::Register_SIN());
AddBuiltin(BuiltinOperator_SOFTMAX, tflite::ops::micro::Register_SOFTMAX());
AddBuiltin(BuiltinOperator_SPLIT, tflite::ops::micro::Register_SPLIT());
AddBuiltin(BuiltinOperator_SQRT, tflite::ops::micro::Register_SQRT());
AddBuiltin(BuiltinOperator_SQUARE, tflite::ops::micro::Register_SQUARE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE,
tflite::ops::micro::Register_STRIDED_SLICE());
AddBuiltin(BuiltinOperator_SUB, tflite::ops::micro::Register_SUB());
AddBuiltin(BuiltinOperator_SVDF, tflite::ops::micro::Register_SVDF());
AddBuiltin(BuiltinOperator_TANH, tflite::ops::micro::Register_TANH());
AddBuiltin(BuiltinOperator_UNPACK, tflite::ops::micro::Register_UNPACK());
// TODO(b/159644355): Figure out if custom Ops belong in AllOpsResolver.
TfLiteRegistration* registration = TfLiteRegistration* registration =
tflite::ops::micro::custom::Register_ETHOSU(); tflite::ops::micro::custom::Register_ETHOSU();
if (registration) { if (registration) {

View File

@ -44,14 +44,10 @@ TF_LITE_MICRO_TEST(TestImageRecognitionInvoke) {
tflite::MicroMutableOpResolver<4> micro_op_resolver; tflite::MicroMutableOpResolver<4> micro_op_resolver;
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, micro_op_resolver.AddConv2D();
tflite::ops::micro::Register_CONV_2D()); micro_op_resolver.AddMaxPool2D();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_MAX_POOL_2D, micro_op_resolver.AddFullyConnected();
tflite::ops::micro::Register_MAX_POOL_2D()); micro_op_resolver.AddSoftmax();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX());
const int tensor_arena_size = 50 * 1024; const int tensor_arena_size = 50 * 1024;
uint8_t tensor_arena[tensor_arena_size]; uint8_t tensor_arena[tensor_arena_size];

View File

@ -58,14 +58,10 @@ int main(int argc, char** argv) {
tflite::MicroMutableOpResolver<4> micro_op_resolver; tflite::MicroMutableOpResolver<4> micro_op_resolver;
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, micro_op_resolver.AddConv2D();
tflite::ops::micro::Register_CONV_2D()); micro_op_resolver.AddFullyConnected();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_MAX_POOL_2D, micro_op_resolver.AddMaxPool2D();
tflite::ops::micro::Register_MAX_POOL_2D()); micro_op_resolver.AddSoftmax();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX());
constexpr int tensor_arena_size = 50 * 1024; constexpr int tensor_arena_size = 50 * 1024;
uint8_t tensor_arena[tensor_arena_size]; uint8_t tensor_arena[tensor_arena_size];

View File

@ -47,17 +47,11 @@ TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
// incur some penalty in code space for op implementations that are not // incur some penalty in code space for op implementations that are not
// needed by this graph. // needed by this graph.
static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT
micro_op_resolver.AddBuiltin( micro_op_resolver.AddConv2D();
tflite::BuiltinOperator_DEPTHWISE_CONV_2D, micro_op_resolver.AddDepthwiseConv2D();
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); micro_op_resolver.AddFullyConnected();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_MAX_POOL_2D, micro_op_resolver.AddMaxPool2D();
tflite::ops::micro::Register_MAX_POOL_2D()); micro_op_resolver.AddSoftmax();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::micro::Register_CONV_2D());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX());
// Create an area of memory to use for input, output, and intermediate arrays. // Create an area of memory to use for input, output, and intermediate arrays.
// Finding the minimum value for your model may require some trial and error. // Finding the minimum value for your model may require some trial and error.

View File

@ -66,17 +66,11 @@ void setup() {
// incur some penalty in code space for op implementations that are not // incur some penalty in code space for op implementations that are not
// needed by this graph. // needed by this graph.
static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT
micro_op_resolver.AddBuiltin( micro_op_resolver.AddConv2D();
tflite::BuiltinOperator_DEPTHWISE_CONV_2D, micro_op_resolver.AddDepthwiseConv2D();
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); micro_op_resolver.AddFullyConnected();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_MAX_POOL_2D, micro_op_resolver.AddMaxPool2D();
tflite::ops::micro::Register_MAX_POOL_2D()); micro_op_resolver.AddSoftmax();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::micro::Register_CONV_2D());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX());
// Build an interpreter to run the model with. // Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter( static tflite::MicroInterpreter static_interpreter(

View File

@ -75,24 +75,16 @@ void setup() {
// tflite::AllOpsResolver resolver; // tflite::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables) // NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter); static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter);
if (micro_op_resolver.AddBuiltin( if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()) != kTfLiteOk) {
return; return;
} }
if (micro_op_resolver.AddBuiltin( if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) {
tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED()) != kTfLiteOk) {
return; return;
} }
if (micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, if (micro_op_resolver.AddSoftmax() != kTfLiteOk) {
tflite::ops::micro::Register_SOFTMAX()) !=
kTfLiteOk) {
return; return;
} }
if (micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE, if (micro_op_resolver.AddReshape() != kTfLiteOk) {
tflite::ops::micro::Register_RESHAPE()) !=
kTfLiteOk) {
return; return;
} }

View File

@ -49,15 +49,10 @@ TF_LITE_MICRO_TEST(TestInvoke) {
// //
// tflite::AllOpsResolver resolver; // tflite::AllOpsResolver resolver;
tflite::MicroMutableOpResolver<4> micro_op_resolver; tflite::MicroMutableOpResolver<4> micro_op_resolver;
micro_op_resolver.AddBuiltin( micro_op_resolver.AddDepthwiseConv2D();
tflite::BuiltinOperator_DEPTHWISE_CONV_2D, micro_op_resolver.AddFullyConnected();
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); micro_op_resolver.AddReshape();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, micro_op_resolver.AddSoftmax();
tflite::ops::micro::Register_FULLY_CONNECTED());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE,
tflite::ops::micro::Register_RESHAPE());
// Create an area of memory to use for input, output, and intermediate arrays. // Create an area of memory to use for input, output, and intermediate arrays.
const int tensor_arena_size = 10 * 1024; const int tensor_arena_size = 10 * 1024;

View File

@ -66,13 +66,9 @@ void setup() {
// tflite::AllOpsResolver resolver; // tflite::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables) // NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroMutableOpResolver<3> micro_op_resolver; static tflite::MicroMutableOpResolver<3> micro_op_resolver;
micro_op_resolver.AddBuiltin( micro_op_resolver.AddAveragePool2D();
tflite::BuiltinOperator_DEPTHWISE_CONV_2D, micro_op_resolver.AddConv2D();
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::micro::Register_CONV_2D());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
tflite::ops::micro::Register_AVERAGE_POOL_2D());
// Build an interpreter to run the model with. // Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter( static tflite::MicroInterpreter static_interpreter(

View File

@ -57,13 +57,9 @@ TF_LITE_MICRO_TEST(TestInvoke) {
// //
// tflite::AllOpsResolver resolver; // tflite::AllOpsResolver resolver;
tflite::MicroMutableOpResolver<3> micro_op_resolver; tflite::MicroMutableOpResolver<3> micro_op_resolver;
micro_op_resolver.AddBuiltin( micro_op_resolver.AddAveragePool2D();
tflite::BuiltinOperator_DEPTHWISE_CONV_2D, micro_op_resolver.AddConv2D();
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::micro::Register_CONV_2D());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
tflite::ops::micro::Register_AVERAGE_POOL_2D());
// Build an interpreter to run the model with. // Build an interpreter to run the model with.
tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena, tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena,

View File

@ -73,17 +73,11 @@ void setup() {
// tflite::AllOpsResolver resolver; // tflite::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables) // NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroMutableOpResolver<5> micro_op_resolver; static tflite::MicroMutableOpResolver<5> micro_op_resolver;
micro_op_resolver.AddBuiltin( micro_op_resolver.AddAveragePool2D();
tflite::BuiltinOperator_DEPTHWISE_CONV_2D, micro_op_resolver.AddConv2D();
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, micro_op_resolver.AddReshape();
tflite::ops::micro::Register_CONV_2D()); micro_op_resolver.AddSoftmax();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
tflite::ops::micro::Register_AVERAGE_POOL_2D());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE,
tflite::ops::micro::Register_RESHAPE());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX());
// Build an interpreter to run the model with. // Build an interpreter to run the model with.
// NOLINTNEXTLINE(runtime-global-variables) // NOLINTNEXTLINE(runtime-global-variables)

View File

@ -53,17 +53,11 @@ TF_LITE_MICRO_TEST(TestInvoke) {
// incur some penalty in code space for op implementations that are not // incur some penalty in code space for op implementations that are not
// needed by this graph. // needed by this graph.
tflite::MicroMutableOpResolver<5> micro_op_resolver; tflite::MicroMutableOpResolver<5> micro_op_resolver;
micro_op_resolver.AddBuiltin( micro_op_resolver.AddAveragePool2D();
tflite::BuiltinOperator_DEPTHWISE_CONV_2D, micro_op_resolver.AddConv2D();
tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, micro_op_resolver.AddReshape();
tflite::ops::micro::Register_CONV_2D()); micro_op_resolver.AddSoftmax();
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
tflite::ops::micro::Register_AVERAGE_POOL_2D());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE,
tflite::ops::micro::Register_RESHAPE());
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX());
// Build an interpreter to run the model with. // Build an interpreter to run the model with.
tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena, tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena,

View File

@ -104,39 +104,72 @@ class MicroMutableOpResolver : public MicroOpResolver {
return kTfLiteOk; return kTfLiteOk;
} }
// Registers a Builtin Operator with the MicroOpResolver.
//
// Only the first call for a given BuiltinOperator enum will be successful.
// i.e. if this function is called again for a previously added
// BuiltinOperator, the MicroOpResolver will be unchanged and this function
// will return kTfLiteError.
//
// TODO(b/149408647): remove this API once the BuiltinOperator specific Add
// functions are fully implemented.
TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
TfLiteRegistration* registration) {
TFLITE_DCHECK(registration != nullptr);
// For code that is not switched over to the new selective registration of
// the parse function, we pass in ParseOpData. This allows for backwards
// compatibility.
return AddBuiltin(op, *registration, ParseOpData);
}
// The Add* functions below add the various Builtin operators to the // The Add* functions below add the various Builtin operators to the
// MicroMutableOpResolver object. // MicroMutableOpResolver object.
//
// This API is currently experimental (and only supported for a small subset TfLiteStatus AddAbs() {
// of operators). It will soon be preferred over the AddBuiltin function for // TODO(b/149408647): Replace ParseOpData with the operator specific parse
// the following reason: // function.
// * If all calls to AddBuiltin for an application use this API, the code return AddBuiltin(BuiltinOperator_ABS, *tflite::ops::micro::Register_ABS(),
// size will be smaller by 5-8K (compared to the using the AddBuiltin ParseOpData);
// override). }
TfLiteStatus AddAdd() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_ADD, *tflite::ops::micro::Register_ADD(),
ParseOpData);
}
TfLiteStatus AddArgMax() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_ARG_MAX,
*tflite::ops::micro::Register_ARG_MAX(), ParseOpData);
}
TfLiteStatus AddArgMin() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_ARG_MIN,
*tflite::ops::micro::Register_ARG_MIN(), ParseOpData);
}
TfLiteStatus AddAveragePool2D() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
*tflite::ops::micro::Register_AVERAGE_POOL_2D(),
ParseOpData);
}
TfLiteStatus AddCeil() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_CEIL,
*tflite::ops::micro::Register_CEIL(), ParseOpData);
}
TfLiteStatus AddConcatenation() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_CONCATENATION,
*tflite::ops::micro::Register_CONCATENATION(),
ParseOpData);
}
TfLiteStatus AddConv2D() { TfLiteStatus AddConv2D() {
return AddBuiltin(BuiltinOperator_CONV_2D, return AddBuiltin(BuiltinOperator_CONV_2D,
*tflite::ops::micro::Register_CONV_2D(), ParseConv2D); *tflite::ops::micro::Register_CONV_2D(), ParseConv2D);
} }
TfLiteStatus AddCos() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_COS, *tflite::ops::micro::Register_COS(),
ParseOpData);
}
TfLiteStatus AddDepthwiseConv2D() { TfLiteStatus AddDepthwiseConv2D() {
return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
*tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), *tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
@ -149,12 +182,91 @@ class MicroMutableOpResolver : public MicroOpResolver {
ParseDequantize); ParseDequantize);
} }
TfLiteStatus AddEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_EQUAL,
*tflite::ops::micro::Register_EQUAL(), ParseOpData);
}
TfLiteStatus AddFloor() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_FLOOR,
*tflite::ops::micro::Register_FLOOR(), ParseOpData);
}
TfLiteStatus AddFullyConnected() { TfLiteStatus AddFullyConnected() {
return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
*tflite::ops::micro::Register_FULLY_CONNECTED(), *tflite::ops::micro::Register_FULLY_CONNECTED(),
ParseFullyConnected); ParseFullyConnected);
} }
TfLiteStatus AddGreater() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_GREATER,
*tflite::ops::micro::Register_GREATER(), ParseOpData);
}
TfLiteStatus AddGreaterEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_GREATER_EQUAL,
*tflite::ops::micro::Register_GREATER_EQUAL(),
ParseOpData);
}
TfLiteStatus AddL2Normalization() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_L2_NORMALIZATION,
*tflite::ops::micro::Register_L2_NORMALIZATION(),
ParseOpData);
}
TfLiteStatus AddLess() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LESS,
*tflite::ops::micro::Register_LESS(), ParseOpData);
}
TfLiteStatus AddLessEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LESS_EQUAL,
*tflite::ops::micro::Register_LESS_EQUAL(), ParseOpData);
}
TfLiteStatus AddLog() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOG, *tflite::ops::micro::Register_LOG(),
ParseOpData);
}
TfLiteStatus AddLogicalAnd() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOGICAL_AND,
*tflite::ops::micro::Register_LOGICAL_AND(), ParseOpData);
}
TfLiteStatus AddLogicalNot() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOGICAL_NOT,
*tflite::ops::micro::Register_LOGICAL_NOT(), ParseOpData);
}
TfLiteStatus AddLogicalOr() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOGICAL_OR,
*tflite::ops::micro::Register_LOGICAL_OR(), ParseOpData);
}
TfLiteStatus AddLogistic() { TfLiteStatus AddLogistic() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse // TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function. // function.
@ -162,26 +274,196 @@ class MicroMutableOpResolver : public MicroOpResolver {
*tflite::ops::micro::Register_LOGISTIC(), ParseOpData); *tflite::ops::micro::Register_LOGISTIC(), ParseOpData);
} }
TfLiteStatus AddMaximum() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MAXIMUM,
*tflite::ops::micro::Register_MAXIMUM(), ParseOpData);
}
TfLiteStatus AddMaxPool2D() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
*tflite::ops::micro::Register_MAX_POOL_2D(), ParseOpData);
}
TfLiteStatus AddMean() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MEAN,
*tflite::ops::micro::Register_MEAN(), ParseOpData);
}
TfLiteStatus AddMinimum() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MINIMUM,
*tflite::ops::micro::Register_MINIMUM(), ParseOpData);
}
TfLiteStatus AddMul() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MUL, *tflite::ops::micro::Register_MUL(),
ParseOpData);
}
TfLiteStatus AddNeg() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_NEG, *tflite::ops::micro::Register_NEG(),
ParseOpData);
}
TfLiteStatus AddNotEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_NOT_EQUAL,
*tflite::ops::micro::Register_NOT_EQUAL(), ParseOpData);
}
TfLiteStatus AddPack() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_PACK,
*tflite::ops::micro::Register_PACK(), ParseOpData);
}
TfLiteStatus AddPad() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_PAD, *tflite::ops::micro::Register_PAD(),
ParseOpData);
}
TfLiteStatus AddPadV2() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_PADV2,
*tflite::ops::micro::Register_PADV2(), ParseOpData);
}
TfLiteStatus AddPrelu() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_PRELU,
*tflite::ops::micro::Register_PRELU(), ParseOpData);
}
TfLiteStatus AddQuantize() { TfLiteStatus AddQuantize() {
return AddBuiltin(BuiltinOperator_QUANTIZE, return AddBuiltin(BuiltinOperator_QUANTIZE,
*tflite::ops::micro::Register_QUANTIZE(), ParseQuantize); *tflite::ops::micro::Register_QUANTIZE(), ParseQuantize);
} }
TfLiteStatus AddRelu() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_RELU,
*tflite::ops::micro::Register_RELU(), ParseOpData);
}
TfLiteStatus AddRelu6() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_RELU6,
*tflite::ops::micro::Register_RELU6(), ParseOpData);
}
TfLiteStatus AddReshape() { TfLiteStatus AddReshape() {
return AddBuiltin(BuiltinOperator_RESHAPE, return AddBuiltin(BuiltinOperator_RESHAPE,
*tflite::ops::micro::Register_RESHAPE(), ParseReshape); *tflite::ops::micro::Register_RESHAPE(), ParseReshape);
} }
TfLiteStatus AddResizeNearestNeighbor() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
*tflite::ops::micro::Register_RESIZE_NEAREST_NEIGHBOR(),
ParseOpData);
}
TfLiteStatus AddRound() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_ROUND,
*tflite::ops::micro::Register_ROUND(), ParseOpData);
}
TfLiteStatus AddRsqrt() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_RSQRT,
*tflite::ops::micro::Register_RSQRT(), ParseOpData);
}
TfLiteStatus AddSin() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_SIN, *tflite::ops::micro::Register_SIN(),
ParseOpData);
}
TfLiteStatus AddSoftmax() { TfLiteStatus AddSoftmax() {
return AddBuiltin(BuiltinOperator_SOFTMAX, return AddBuiltin(BuiltinOperator_SOFTMAX,
*tflite::ops::micro::Register_SOFTMAX(), ParseSoftmax); *tflite::ops::micro::Register_SOFTMAX(), ParseSoftmax);
} }
TfLiteStatus AddSplit() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_SPLIT,
*tflite::ops::micro::Register_SPLIT(), ParseOpData);
}
TfLiteStatus AddSqrt() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_SQRT,
*tflite::ops::micro::Register_SQRT(), ParseOpData);
}
TfLiteStatus AddSquare() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_SQUARE,
*tflite::ops::micro::Register_SQUARE(), ParseOpData);
}
TfLiteStatus AddStridedSlice() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_STRIDED_SLICE,
*tflite::ops::micro::Register_STRIDED_SLICE(),
ParseOpData);
}
TfLiteStatus AddSub() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_SUB, *tflite::ops::micro::Register_SUB(),
ParseOpData);
}
TfLiteStatus AddSvdf() { TfLiteStatus AddSvdf() {
return AddBuiltin(BuiltinOperator_SVDF, return AddBuiltin(BuiltinOperator_SVDF,
*tflite::ops::micro::Register_SVDF(), ParseSvdf); *tflite::ops::micro::Register_SVDF(), ParseSvdf);
} }
TfLiteStatus AddTanh() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_TANH,
*tflite::ops::micro::Register_TANH(), ParseOpData);
}
TfLiteStatus AddUnpack() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_UNPACK,
*tflite::ops::micro::Register_UNPACK(), ParseOpData);
}
unsigned int GetRegistrationLength() { return registrations_len_; } unsigned int GetRegistrationLength() { return registrations_len_; }
private: private:

View File

@ -68,14 +68,7 @@ TF_LITE_MICRO_TEST(TestOperations) {
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
tflite::MockPrepare, tflite::MockInvoke}; tflite::MockPrepare, tflite::MockInvoke};
MicroMutableOpResolver<2> micro_op_resolver; MicroMutableOpResolver<1> micro_op_resolver;
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r));
// Only one AddBuiltin per operator should return kTfLiteOk.
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteError, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
micro_op_resolver.AddCustom("mock_custom", &r)); micro_op_resolver.AddCustom("mock_custom", &r));
@ -85,16 +78,10 @@ TF_LITE_MICRO_TEST(TestOperations) {
tflite::MicroOpResolver* resolver = &micro_op_resolver; tflite::MicroOpResolver* resolver = &micro_op_resolver;
TF_LITE_MICRO_EXPECT_EQ(1, micro_op_resolver.GetRegistrationLength());
const TfLiteRegistration* registration = const TfLiteRegistration* registration =
resolver->FindOp(BuiltinOperator_CONV_2D); resolver->FindOp(BuiltinOperator_RELU);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
TF_LITE_MICRO_EXPECT_EQ(2, micro_op_resolver.GetRegistrationLength());
registration = resolver->FindOp(BuiltinOperator_RELU);
TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
registration = resolver->FindOp("mock_custom"); registration = resolver->FindOp("mock_custom");
@ -116,12 +103,7 @@ TF_LITE_MICRO_TEST(TestErrorReporting) {
tflite::MockPrepare, tflite::MockInvoke}; tflite::MockPrepare, tflite::MockInvoke};
tflite::MockErrorReporter mock_reporter; tflite::MockErrorReporter mock_reporter;
MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter); MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter);
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
mock_reporter.ResetState();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r));
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
mock_reporter.ResetState(); mock_reporter.ResetState();
@ -132,10 +114,7 @@ TF_LITE_MICRO_TEST(TestErrorReporting) {
// Attempting to Add more operators than the class template parameter for // Attempting to Add more operators than the class template parameter for
// MicroMutableOpResolver should result in errors. // MicroMutableOpResolver should result in errors.
TF_LITE_MICRO_EXPECT_EQ( TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, micro_op_resolver.AddRelu());
kTfLiteError, micro_op_resolver.AddBuiltin(BuiltinOperator_RELU, &r));
TF_LITE_MICRO_EXPECT_EQ(true, mock_reporter.HasBeenCalled());
mock_reporter.ResetState();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
micro_op_resolver.AddCustom("mock_custom_1", &r)); micro_op_resolver.AddCustom("mock_custom_1", &r));