diff --git a/tensorflow/lite/micro/all_ops_resolver.cc b/tensorflow/lite/micro/all_ops_resolver.cc index b0021a2e771..e728a95360a 100644 --- a/tensorflow/lite/micro/all_ops_resolver.cc +++ b/tensorflow/lite/micro/all_ops_resolver.cc @@ -26,74 +26,60 @@ const char* GetString_ETHOSU(); AllOpsResolver::AllOpsResolver() { // Please keep this list of Builtin Operators in alphabetical order. - AddBuiltin(BuiltinOperator_ABS, tflite::ops::micro::Register_ABS()); - AddBuiltin(BuiltinOperator_ADD, tflite::ops::micro::Register_ADD()); - AddBuiltin(BuiltinOperator_ARG_MAX, tflite::ops::micro::Register_ARG_MAX()); - AddBuiltin(BuiltinOperator_ARG_MIN, tflite::ops::micro::Register_ARG_MIN()); - AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, - tflite::ops::micro::Register_AVERAGE_POOL_2D()); - AddBuiltin(BuiltinOperator_CEIL, tflite::ops::micro::Register_CEIL()); - AddBuiltin(BuiltinOperator_CONCATENATION, - tflite::ops::micro::Register_CONCATENATION()); - AddBuiltin(BuiltinOperator_CONV_2D, tflite::ops::micro::Register_CONV_2D()); - AddBuiltin(BuiltinOperator_COS, tflite::ops::micro::Register_COS()); - AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); - AddBuiltin(BuiltinOperator_DEQUANTIZE, - tflite::ops::micro::Register_DEQUANTIZE()); - AddBuiltin(BuiltinOperator_EQUAL, tflite::ops::micro::Register_EQUAL()); - AddBuiltin(BuiltinOperator_FLOOR, tflite::ops::micro::Register_FLOOR()); - AddBuiltin(BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()); - AddBuiltin(BuiltinOperator_GREATER, tflite::ops::micro::Register_GREATER()); - AddBuiltin(BuiltinOperator_GREATER_EQUAL, - tflite::ops::micro::Register_GREATER_EQUAL()); - AddBuiltin(BuiltinOperator_L2_NORMALIZATION, - tflite::ops::micro::Register_L2_NORMALIZATION()); - AddBuiltin(BuiltinOperator_LESS, tflite::ops::micro::Register_LESS()); - AddBuiltin(BuiltinOperator_LESS_EQUAL, - tflite::ops::micro::Register_LESS_EQUAL()); - AddBuiltin(BuiltinOperator_LOG, tflite::ops::micro::Register_LOG()); - AddBuiltin(BuiltinOperator_LOGICAL_AND, - tflite::ops::micro::Register_LOGICAL_AND()); - AddBuiltin(BuiltinOperator_LOGICAL_NOT, - tflite::ops::micro::Register_LOGICAL_NOT()); - AddBuiltin(BuiltinOperator_LOGICAL_OR, - tflite::ops::micro::Register_LOGICAL_OR()); - AddBuiltin(BuiltinOperator_LOGISTIC, tflite::ops::micro::Register_LOGISTIC()); - AddBuiltin(BuiltinOperator_MAX_POOL_2D, - tflite::ops::micro::Register_MAX_POOL_2D()); - AddBuiltin(BuiltinOperator_MAXIMUM, tflite::ops::micro::Register_MAXIMUM()); - AddBuiltin(BuiltinOperator_MEAN, tflite::ops::micro::Register_MEAN()); - AddBuiltin(BuiltinOperator_MINIMUM, tflite::ops::micro::Register_MINIMUM()); - AddBuiltin(BuiltinOperator_MUL, tflite::ops::micro::Register_MUL()); - AddBuiltin(BuiltinOperator_NEG, tflite::ops::micro::Register_NEG()); - AddBuiltin(BuiltinOperator_NOT_EQUAL, - tflite::ops::micro::Register_NOT_EQUAL()); - AddBuiltin(BuiltinOperator_PACK, tflite::ops::micro::Register_PACK()); - AddBuiltin(BuiltinOperator_PAD, tflite::ops::micro::Register_PAD()); - AddBuiltin(BuiltinOperator_PADV2, tflite::ops::micro::Register_PADV2()); - AddBuiltin(BuiltinOperator_PRELU, tflite::ops::micro::Register_PRELU()); - AddBuiltin(BuiltinOperator_QUANTIZE, tflite::ops::micro::Register_QUANTIZE()); - AddBuiltin(BuiltinOperator_RELU, tflite::ops::micro::Register_RELU()); - AddBuiltin(BuiltinOperator_RELU6, tflite::ops::micro::Register_RELU6()); - AddBuiltin(BuiltinOperator_RESHAPE, tflite::ops::micro::Register_RESHAPE()); - AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, - tflite::ops::micro::Register_RESIZE_NEAREST_NEIGHBOR()); - AddBuiltin(BuiltinOperator_ROUND, tflite::ops::micro::Register_ROUND()); - AddBuiltin(BuiltinOperator_RSQRT, tflite::ops::micro::Register_RSQRT()); - AddBuiltin(BuiltinOperator_SIN, tflite::ops::micro::Register_SIN()); - AddBuiltin(BuiltinOperator_SOFTMAX, tflite::ops::micro::Register_SOFTMAX()); - AddBuiltin(BuiltinOperator_SPLIT, tflite::ops::micro::Register_SPLIT()); - AddBuiltin(BuiltinOperator_SQRT, tflite::ops::micro::Register_SQRT()); - AddBuiltin(BuiltinOperator_SQUARE, tflite::ops::micro::Register_SQUARE()); - AddBuiltin(BuiltinOperator_STRIDED_SLICE, - tflite::ops::micro::Register_STRIDED_SLICE()); - AddBuiltin(BuiltinOperator_SUB, tflite::ops::micro::Register_SUB()); - AddBuiltin(BuiltinOperator_SVDF, tflite::ops::micro::Register_SVDF()); - AddBuiltin(BuiltinOperator_TANH, tflite::ops::micro::Register_TANH()); - AddBuiltin(BuiltinOperator_UNPACK, tflite::ops::micro::Register_UNPACK()); + AddAbs(); + AddAdd(); + AddArgMax(); + AddArgMin(); + AddAveragePool2D(); + AddCeil(); + AddConcatenation(); + AddConv2D(); + AddCos(); + AddDepthwiseConv2D(); + AddDequantize(); + AddEqual(); + AddFloor(); + AddFullyConnected(); + AddGreater(); + AddGreaterEqual(); + AddL2Normalization(); + AddLess(); + AddLessEqual(); + AddLog(); + AddLogicalAnd(); + AddLogicalNot(); + AddLogicalOr(); + AddLogistic(); + AddMaximum(); + AddMaxPool2D(); + AddMean(); + AddMinimum(); + AddMul(); + AddNeg(); + AddNotEqual(); + AddPack(); + AddPad(); + AddPadV2(); + AddPrelu(); + AddQuantize(); + AddRelu(); + AddRelu6(); + AddReshape(); + AddResizeNearestNeighbor(); + AddRound(); + AddRsqrt(); + AddSin(); + AddSoftmax(); + AddSplit(); + AddSqrt(); + AddSquare(); + AddStridedSlice(); + AddSub(); + AddSvdf(); + AddTanh(); + AddUnpack(); + // TODO(b/159644355): Figure out if custom Ops belong in AllOpsResolver. TfLiteRegistration* registration = tflite::ops::micro::custom::Register_ETHOSU(); if (registration) { diff --git a/tensorflow/lite/micro/examples/image_recognition_experimental/image_recognition_test.cc b/tensorflow/lite/micro/examples/image_recognition_experimental/image_recognition_test.cc index ac4de118834..5ad2fb2acbe 100644 --- a/tensorflow/lite/micro/examples/image_recognition_experimental/image_recognition_test.cc +++ b/tensorflow/lite/micro/examples/image_recognition_experimental/image_recognition_test.cc @@ -44,14 +44,10 @@ TF_LITE_MICRO_TEST(TestImageRecognitionInvoke) { tflite::MicroMutableOpResolver<4> micro_op_resolver; - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_MAX_POOL_2D, - tflite::ops::micro::Register_MAX_POOL_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()); + micro_op_resolver.AddConv2D(); + micro_op_resolver.AddMaxPool2D(); + micro_op_resolver.AddFullyConnected(); + micro_op_resolver.AddSoftmax(); const int tensor_arena_size = 50 * 1024; uint8_t tensor_arena[tensor_arena_size]; diff --git a/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc b/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc index becdbdf1bd7..fcf7b41b827 100644 --- a/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc +++ b/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc @@ -58,14 +58,10 @@ int main(int argc, char** argv) { tflite::MicroMutableOpResolver<4> micro_op_resolver; - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_MAX_POOL_2D, - tflite::ops::micro::Register_MAX_POOL_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()); + micro_op_resolver.AddConv2D(); + micro_op_resolver.AddFullyConnected(); + micro_op_resolver.AddMaxPool2D(); + micro_op_resolver.AddSoftmax(); constexpr int tensor_arena_size = 50 * 1024; uint8_t tensor_arena[tensor_arena_size]; diff --git a/tensorflow/lite/micro/examples/magic_wand/magic_wand_test.cc b/tensorflow/lite/micro/examples/magic_wand/magic_wand_test.cc index 88bfad860e2..fb75afee309 100644 --- a/tensorflow/lite/micro/examples/magic_wand/magic_wand_test.cc +++ b/tensorflow/lite/micro/examples/magic_wand/magic_wand_test.cc @@ -47,17 +47,11 @@ TF_LITE_MICRO_TEST(LoadModelAndPerformInference) { // incur some penalty in code space for op implementations that are not // needed by this graph. static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT - micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_MAX_POOL_2D, - tflite::ops::micro::Register_MAX_POOL_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()); + micro_op_resolver.AddConv2D(); + micro_op_resolver.AddDepthwiseConv2D(); + micro_op_resolver.AddFullyConnected(); + micro_op_resolver.AddMaxPool2D(); + micro_op_resolver.AddSoftmax(); // Create an area of memory to use for input, output, and intermediate arrays. // Finding the minimum value for your model may require some trial and error. diff --git a/tensorflow/lite/micro/examples/magic_wand/main_functions.cc b/tensorflow/lite/micro/examples/magic_wand/main_functions.cc index 26c2eb44747..8defeaad866 100644 --- a/tensorflow/lite/micro/examples/magic_wand/main_functions.cc +++ b/tensorflow/lite/micro/examples/magic_wand/main_functions.cc @@ -66,17 +66,11 @@ void setup() { // incur some penalty in code space for op implementations that are not // needed by this graph. static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT - micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_MAX_POOL_2D, - tflite::ops::micro::Register_MAX_POOL_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()); + micro_op_resolver.AddConv2D(); + micro_op_resolver.AddDepthwiseConv2D(); + micro_op_resolver.AddFullyConnected(); + micro_op_resolver.AddMaxPool2D(); + micro_op_resolver.AddSoftmax(); // Build an interpreter to run the model with. static tflite::MicroInterpreter static_interpreter( diff --git a/tensorflow/lite/micro/examples/micro_speech/main_functions.cc b/tensorflow/lite/micro/examples/micro_speech/main_functions.cc index 30c5022b2d6..d09c4c7af06 100644 --- a/tensorflow/lite/micro/examples/micro_speech/main_functions.cc +++ b/tensorflow/lite/micro/examples/micro_speech/main_functions.cc @@ -75,24 +75,16 @@ void setup() { // tflite::AllOpsResolver resolver; // NOLINTNEXTLINE(runtime-global-variables) static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter); - if (micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()) != kTfLiteOk) { + if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) { return; } - if (micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()) != kTfLiteOk) { + if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) { return; } - if (micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()) != - kTfLiteOk) { + if (micro_op_resolver.AddSoftmax() != kTfLiteOk) { return; } - if (micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE, - tflite::ops::micro::Register_RESHAPE()) != - kTfLiteOk) { + if (micro_op_resolver.AddReshape() != kTfLiteOk) { return; } diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc index 2c442f955cc..0f6a2afd527 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc @@ -49,15 +49,10 @@ TF_LITE_MICRO_TEST(TestInvoke) { // // tflite::AllOpsResolver resolver; tflite::MicroMutableOpResolver<4> micro_op_resolver; - micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE, - tflite::ops::micro::Register_RESHAPE()); + micro_op_resolver.AddDepthwiseConv2D(); + micro_op_resolver.AddFullyConnected(); + micro_op_resolver.AddReshape(); + micro_op_resolver.AddSoftmax(); // Create an area of memory to use for input, output, and intermediate arrays. const int tensor_arena_size = 10 * 1024; diff --git a/tensorflow/lite/micro/examples/person_detection/main_functions.cc b/tensorflow/lite/micro/examples/person_detection/main_functions.cc index aa4d83a3334..d7e9f6826c4 100644 --- a/tensorflow/lite/micro/examples/person_detection/main_functions.cc +++ b/tensorflow/lite/micro/examples/person_detection/main_functions.cc @@ -66,13 +66,9 @@ void setup() { // tflite::AllOpsResolver resolver; // NOLINTNEXTLINE(runtime-global-variables) static tflite::MicroMutableOpResolver<3> micro_op_resolver; - micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D, - tflite::ops::micro::Register_AVERAGE_POOL_2D()); + micro_op_resolver.AddAveragePool2D(); + micro_op_resolver.AddConv2D(); + micro_op_resolver.AddDepthwiseConv2D(); // Build an interpreter to run the model with. static tflite::MicroInterpreter static_interpreter( diff --git a/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc b/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc index bc53a8410da..7e706d49fcc 100644 --- a/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc +++ b/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc @@ -57,13 +57,9 @@ TF_LITE_MICRO_TEST(TestInvoke) { // // tflite::AllOpsResolver resolver; tflite::MicroMutableOpResolver<3> micro_op_resolver; - micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D, - tflite::ops::micro::Register_AVERAGE_POOL_2D()); + micro_op_resolver.AddAveragePool2D(); + micro_op_resolver.AddConv2D(); + micro_op_resolver.AddDepthwiseConv2D(); // Build an interpreter to run the model with. tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena, diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/main_functions.cc b/tensorflow/lite/micro/examples/person_detection_experimental/main_functions.cc index ac47e36ff8f..09a9cb2c6c4 100644 --- a/tensorflow/lite/micro/examples/person_detection_experimental/main_functions.cc +++ b/tensorflow/lite/micro/examples/person_detection_experimental/main_functions.cc @@ -73,17 +73,11 @@ void setup() { // tflite::AllOpsResolver resolver; // NOLINTNEXTLINE(runtime-global-variables) static tflite::MicroMutableOpResolver<5> micro_op_resolver; - micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D, - tflite::ops::micro::Register_AVERAGE_POOL_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE, - tflite::ops::micro::Register_RESHAPE()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()); + micro_op_resolver.AddAveragePool2D(); + micro_op_resolver.AddConv2D(); + micro_op_resolver.AddDepthwiseConv2D(); + micro_op_resolver.AddReshape(); + micro_op_resolver.AddSoftmax(); // Build an interpreter to run the model with. // NOLINTNEXTLINE(runtime-global-variables) diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/person_detection_test.cc b/tensorflow/lite/micro/examples/person_detection_experimental/person_detection_test.cc index ddec8951596..270a427b1df 100644 --- a/tensorflow/lite/micro/examples/person_detection_experimental/person_detection_test.cc +++ b/tensorflow/lite/micro/examples/person_detection_experimental/person_detection_test.cc @@ -53,17 +53,11 @@ TF_LITE_MICRO_TEST(TestInvoke) { // incur some penalty in code space for op implementations that are not // needed by this graph. tflite::MicroMutableOpResolver<5> micro_op_resolver; - micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D, - tflite::ops::micro::Register_AVERAGE_POOL_2D()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE, - tflite::ops::micro::Register_RESHAPE()); - micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()); + micro_op_resolver.AddAveragePool2D(); + micro_op_resolver.AddConv2D(); + micro_op_resolver.AddDepthwiseConv2D(); + micro_op_resolver.AddReshape(); + micro_op_resolver.AddSoftmax(); // Build an interpreter to run the model with. tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena, diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index b9ce2bb4bba..1b76f440a61 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -104,39 +104,72 @@ class MicroMutableOpResolver : public MicroOpResolver { return kTfLiteOk; } - // Registers a Builtin Operator with the MicroOpResolver. - // - // Only the first call for a given BuiltinOperator enum will be successful. - // i.e. if this function is called again for a previously added - // BuiltinOperator, the MicroOpResolver will be unchanged and this function - // will return kTfLiteError. - // - // TODO(b/149408647): remove this API once the BuiltinOperator specific Add - // functions are fully implemented. - TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, - TfLiteRegistration* registration) { - TFLITE_DCHECK(registration != nullptr); - // For code that is not switched over to the new selective registration of - // the parse function, we pass in ParseOpData. This allows for backwards - // compatibility. - return AddBuiltin(op, *registration, ParseOpData); - } - // The Add* functions below add the various Builtin operators to the // MicroMutableOpResolver object. - // - // This API is currently experimental (and only supported for a small subset - // of operators). It will soon be preferred over the AddBuiltin function for - // the following reason: - // * If all calls to AddBuiltin for an application use this API, the code - // size will be smaller by 5-8K (compared to the using the AddBuiltin - // override). + + TfLiteStatus AddAbs() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_ABS, *tflite::ops::micro::Register_ABS(), + ParseOpData); + } + + TfLiteStatus AddAdd() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_ADD, *tflite::ops::micro::Register_ADD(), + ParseOpData); + } + + TfLiteStatus AddArgMax() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_ARG_MAX, + *tflite::ops::micro::Register_ARG_MAX(), ParseOpData); + } + + TfLiteStatus AddArgMin() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_ARG_MIN, + *tflite::ops::micro::Register_ARG_MIN(), ParseOpData); + } + + TfLiteStatus AddAveragePool2D() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, + *tflite::ops::micro::Register_AVERAGE_POOL_2D(), + ParseOpData); + } + + TfLiteStatus AddCeil() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_CEIL, + *tflite::ops::micro::Register_CEIL(), ParseOpData); + } + + TfLiteStatus AddConcatenation() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_CONCATENATION, + *tflite::ops::micro::Register_CONCATENATION(), + ParseOpData); + } TfLiteStatus AddConv2D() { return AddBuiltin(BuiltinOperator_CONV_2D, *tflite::ops::micro::Register_CONV_2D(), ParseConv2D); } + TfLiteStatus AddCos() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_COS, *tflite::ops::micro::Register_COS(), + ParseOpData); + } + TfLiteStatus AddDepthwiseConv2D() { return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, *tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), @@ -149,12 +182,91 @@ class MicroMutableOpResolver : public MicroOpResolver { ParseDequantize); } + TfLiteStatus AddEqual() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_EQUAL, + *tflite::ops::micro::Register_EQUAL(), ParseOpData); + } + + TfLiteStatus AddFloor() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_FLOOR, + *tflite::ops::micro::Register_FLOOR(), ParseOpData); + } + TfLiteStatus AddFullyConnected() { return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, *tflite::ops::micro::Register_FULLY_CONNECTED(), ParseFullyConnected); } + TfLiteStatus AddGreater() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_GREATER, + *tflite::ops::micro::Register_GREATER(), ParseOpData); + } + + TfLiteStatus AddGreaterEqual() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_GREATER_EQUAL, + *tflite::ops::micro::Register_GREATER_EQUAL(), + ParseOpData); + } + + TfLiteStatus AddL2Normalization() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_L2_NORMALIZATION, + *tflite::ops::micro::Register_L2_NORMALIZATION(), + ParseOpData); + } + + TfLiteStatus AddLess() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_LESS, + *tflite::ops::micro::Register_LESS(), ParseOpData); + } + + TfLiteStatus AddLessEqual() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_LESS_EQUAL, + *tflite::ops::micro::Register_LESS_EQUAL(), ParseOpData); + } + + TfLiteStatus AddLog() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_LOG, *tflite::ops::micro::Register_LOG(), + ParseOpData); + } + + TfLiteStatus AddLogicalAnd() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_LOGICAL_AND, + *tflite::ops::micro::Register_LOGICAL_AND(), ParseOpData); + } + + TfLiteStatus AddLogicalNot() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_LOGICAL_NOT, + *tflite::ops::micro::Register_LOGICAL_NOT(), ParseOpData); + } + + TfLiteStatus AddLogicalOr() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_LOGICAL_OR, + *tflite::ops::micro::Register_LOGICAL_OR(), ParseOpData); + } + TfLiteStatus AddLogistic() { // TODO(b/149408647): Replace ParseOpData with the operator specific parse // function. @@ -162,26 +274,196 @@ class MicroMutableOpResolver : public MicroOpResolver { *tflite::ops::micro::Register_LOGISTIC(), ParseOpData); } + TfLiteStatus AddMaximum() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_MAXIMUM, + *tflite::ops::micro::Register_MAXIMUM(), ParseOpData); + } + + TfLiteStatus AddMaxPool2D() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_MAX_POOL_2D, + *tflite::ops::micro::Register_MAX_POOL_2D(), ParseOpData); + } + + TfLiteStatus AddMean() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_MEAN, + *tflite::ops::micro::Register_MEAN(), ParseOpData); + } + + TfLiteStatus AddMinimum() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_MINIMUM, + *tflite::ops::micro::Register_MINIMUM(), ParseOpData); + } + + TfLiteStatus AddMul() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_MUL, *tflite::ops::micro::Register_MUL(), + ParseOpData); + } + + TfLiteStatus AddNeg() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_NEG, *tflite::ops::micro::Register_NEG(), + ParseOpData); + } + + TfLiteStatus AddNotEqual() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_NOT_EQUAL, + *tflite::ops::micro::Register_NOT_EQUAL(), ParseOpData); + } + + TfLiteStatus AddPack() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_PACK, + *tflite::ops::micro::Register_PACK(), ParseOpData); + } + + TfLiteStatus AddPad() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_PAD, *tflite::ops::micro::Register_PAD(), + ParseOpData); + } + + TfLiteStatus AddPadV2() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_PADV2, + *tflite::ops::micro::Register_PADV2(), ParseOpData); + } + + TfLiteStatus AddPrelu() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_PRELU, + *tflite::ops::micro::Register_PRELU(), ParseOpData); + } + TfLiteStatus AddQuantize() { return AddBuiltin(BuiltinOperator_QUANTIZE, *tflite::ops::micro::Register_QUANTIZE(), ParseQuantize); } + TfLiteStatus AddRelu() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_RELU, + *tflite::ops::micro::Register_RELU(), ParseOpData); + } + + TfLiteStatus AddRelu6() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_RELU6, + *tflite::ops::micro::Register_RELU6(), ParseOpData); + } + TfLiteStatus AddReshape() { return AddBuiltin(BuiltinOperator_RESHAPE, *tflite::ops::micro::Register_RESHAPE(), ParseReshape); } + TfLiteStatus AddResizeNearestNeighbor() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, + *tflite::ops::micro::Register_RESIZE_NEAREST_NEIGHBOR(), + ParseOpData); + } + + TfLiteStatus AddRound() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_ROUND, + *tflite::ops::micro::Register_ROUND(), ParseOpData); + } + + TfLiteStatus AddRsqrt() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_RSQRT, + *tflite::ops::micro::Register_RSQRT(), ParseOpData); + } + + TfLiteStatus AddSin() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_SIN, *tflite::ops::micro::Register_SIN(), + ParseOpData); + } + TfLiteStatus AddSoftmax() { return AddBuiltin(BuiltinOperator_SOFTMAX, *tflite::ops::micro::Register_SOFTMAX(), ParseSoftmax); } + TfLiteStatus AddSplit() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_SPLIT, + *tflite::ops::micro::Register_SPLIT(), ParseOpData); + } + + TfLiteStatus AddSqrt() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_SQRT, + *tflite::ops::micro::Register_SQRT(), ParseOpData); + } + + TfLiteStatus AddSquare() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_SQUARE, + *tflite::ops::micro::Register_SQUARE(), ParseOpData); + } + + TfLiteStatus AddStridedSlice() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_STRIDED_SLICE, + *tflite::ops::micro::Register_STRIDED_SLICE(), + ParseOpData); + } + + TfLiteStatus AddSub() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_SUB, *tflite::ops::micro::Register_SUB(), + ParseOpData); + } + TfLiteStatus AddSvdf() { return AddBuiltin(BuiltinOperator_SVDF, *tflite::ops::micro::Register_SVDF(), ParseSvdf); } + TfLiteStatus AddTanh() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_TANH, + *tflite::ops::micro::Register_TANH(), ParseOpData); + } + + TfLiteStatus AddUnpack() { + // TODO(b/149408647): Replace ParseOpData with the operator specific parse + // function. + return AddBuiltin(BuiltinOperator_UNPACK, + *tflite::ops::micro::Register_UNPACK(), ParseOpData); + } + unsigned int GetRegistrationLength() { return registrations_len_; } private: diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc b/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc index ff5dfdf3a9a..fe9c8de5959 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc +++ b/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc @@ -68,14 +68,7 @@ TF_LITE_MICRO_TEST(TestOperations) { static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; - MicroMutableOpResolver<2> micro_op_resolver; - TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r)); - - // Only one AddBuiltin per operator should return kTfLiteOk. - TF_LITE_MICRO_EXPECT_EQ( - kTfLiteError, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r)); - + MicroMutableOpResolver<1> micro_op_resolver; TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddCustom("mock_custom", &r)); @@ -85,16 +78,10 @@ TF_LITE_MICRO_TEST(TestOperations) { tflite::MicroOpResolver* resolver = µ_op_resolver; + TF_LITE_MICRO_EXPECT_EQ(1, micro_op_resolver.GetRegistrationLength()); + const TfLiteRegistration* registration = - resolver->FindOp(BuiltinOperator_CONV_2D); - TF_LITE_MICRO_EXPECT_NE(nullptr, registration); - TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); - - TF_LITE_MICRO_EXPECT_EQ(2, micro_op_resolver.GetRegistrationLength()); - - registration = resolver->FindOp(BuiltinOperator_RELU); + resolver->FindOp(BuiltinOperator_RELU); TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); registration = resolver->FindOp("mock_custom"); @@ -116,12 +103,7 @@ TF_LITE_MICRO_TEST(TestErrorReporting) { tflite::MockPrepare, tflite::MockInvoke}; tflite::MockErrorReporter mock_reporter; - MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter); - TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); - mock_reporter.ResetState(); - - TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r)); + MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter); TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); mock_reporter.ResetState(); @@ -132,10 +114,7 @@ TF_LITE_MICRO_TEST(TestErrorReporting) { // Attempting to Add more operators than the class template parameter for // MicroMutableOpResolver should result in errors. - TF_LITE_MICRO_EXPECT_EQ( - kTfLiteError, micro_op_resolver.AddBuiltin(BuiltinOperator_RELU, &r)); - TF_LITE_MICRO_EXPECT_EQ(true, mock_reporter.HasBeenCalled()); - mock_reporter.ResetState(); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, micro_op_resolver.AddRelu()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, micro_op_resolver.AddCustom("mock_custom_1", &r));