diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index 18a19ab599c..3b0d393fe30 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -36,6 +36,7 @@ cc_library(
         "comparisons.cc",
         "concatenation.cc",
         "dequantize.cc",
+        "detection_postprocess.cc",
         "elementwise.cc",
        "ethosu.cc",
         "floor.cc",
@@ -115,6 +116,7 @@ cc_library(
        "//tensorflow/lite/kernels/internal:types",
         "//tensorflow/lite/micro:memory_helpers",
         "//tensorflow/lite/micro:micro_utils",
+        "@flatbuffers",
     ] + select({
         "//conditions:default": [],
         ":xtensa_hifimini": [
@@ -166,6 +168,20 @@ tflite_micro_cc_test(
     ],
 )
 
+tflite_micro_cc_test(
+    name = "detection_postprocess_test",
+    srcs = [
+        "detection_postprocess_test.cc",
+    ],
+    deps = [
+        ":kernel_runner",
+        ":micro_ops",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/micro/testing:micro_test",
+        "@flatbuffers",
+    ],
+)
+
 tflite_micro_cc_test(
     name = "fully_connected_test",
     srcs = [
diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc
new file mode 100644
index 00000000000..9146cc096c8
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc
@@ -0,0 +1,868 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <numeric>
+
+#define FLATBUFFERS_LOCALE_INDEPENDENT 0
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace detection_postprocess {
+
+/**
+ * This version of detection_postprocess is specific to TFLite Micro. It
+ * contains the following differences from the TFLite version:
+ *
+ * 1.) Temporaries (temporary tensors) - Micro uses the scratch buffer API
+ *     instead.
+ * 2.) Output dimensions - the TFLite version determines the output size
+ *     and resizes the output tensor. The Micro runtime does not support
+ *     tensor resizing, so if the output dimensions are undefined, the TFLite
+ *     Micro memory API is used to allocate the new dimensions.
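+ *
+ * As an illustration of 1.), where the reference kernel would create a
+ * temporary tensor for e.g. the decoded boxes, this port requests an
+ * arena-backed scratch buffer in Prepare() and looks the same buffer up
+ * again in Eval() through context->GetScratchBuffer().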
+ */
+
+// Input tensors
+constexpr int kInputTensorBoxEncodings = 0;
+constexpr int kInputTensorClassPredictions = 1;
+constexpr int kInputTensorAnchors = 2;
+
+// Output tensors
+constexpr int kOutputTensorDetectionBoxes = 0;
+constexpr int kOutputTensorDetectionClasses = 1;
+constexpr int kOutputTensorDetectionScores = 2;
+constexpr int kOutputTensorNumDetections = 3;
+
+constexpr int kNumCoordBox = 4;
+constexpr int kBatchSize = 1;
+
+constexpr int kNumDetectionsPerClass = 100;
+
+// Object Detection model produces axis-aligned boxes in two formats:
+// BoxCorner represents the lower left corner (xmin, ymin) and
+// the upper right corner (xmax, ymax).
+// CenterSize represents the center (xcenter, ycenter), height and width.
+// BoxCornerEncoding and CenterSizeEncoding are related as follows:
+// ycenter = y / y_scale * anchor.h + anchor.y;
+// xcenter = x / x_scale * anchor.w + anchor.x;
+// half_h = 0.5 * exp(h / h_scale) * anchor.h;
+// half_w = 0.5 * exp(w / w_scale) * anchor.w;
+// ymin = ycenter - half_h
+// ymax = ycenter + half_h
+// xmin = xcenter - half_w
+// xmax = xcenter + half_w
+struct BoxCornerEncoding {
+  float ymin;
+  float xmin;
+  float ymax;
+  float xmax;
+};
+
+struct CenterSizeEncoding {
+  float y;
+  float x;
+  float h;
+  float w;
+};
+// We make sure that the memory allocations are contiguous with static assert.
+static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
+              "Size of BoxCornerEncoding is 4 float values");
+static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
+              "Size of CenterSizeEncoding is 4 float values");
+
+struct OpData {
+  int max_detections;
+  int max_classes_per_detection;  // Fast Non-Max-Suppression
+  int detections_per_class;       // Regular Non-Max-Suppression
+  float non_max_suppression_score_threshold;
+  float intersection_over_union_threshold;
+  int num_classes;
+  bool use_regular_non_max_suppression;
+  CenterSizeEncoding scale_values;
+
+  // Scratch buffer indexes
+  int active_candidate_idx;
+  int decoded_boxes_idx;
+  int scores_idx;
+  int score_buffer_idx;
+  int keep_scores_idx;
+  int scores_after_regular_non_max_suppression_idx;
+  int sorted_values_idx;
+  int keep_indices_idx;
+  int sorted_indices_idx;
+  int buffer_idx;
+  int selected_idx;
+
+  // Cached tensor scale and zero point values for quantized operations
+  TfLiteQuantizationParams input_box_encodings;
+  TfLiteQuantizationParams input_class_predictions;
+  TfLiteQuantizationParams input_anchors;
+};
+
+TfLiteStatus AllocateOutDimensions(TfLiteContext* context,
+                                   TfLiteIntArray** dims, int x, int y = 0,
+                                   int z = 0) {
+  int size = 1;
+
+  size = size * x;
+  size = (y > 0) ? size * y : size;
+  size = (z > 0) ? size * z : size;
+
+  *dims = reinterpret_cast<TfLiteIntArray*>(context->AllocatePersistentBuffer(
+      context, TfLiteIntArrayGetSizeInBytes(size)));
+
+  (*dims)->size = size;
+  (*dims)->data[0] = x;
+  if (y > 0) {
+    (*dims)->data[1] = y;
+  }
+  if (z > 0) {
+    (*dims)->data[2] = z;
+  }
+
+  return kTfLiteOk;
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  OpData* op_data = nullptr;
+
+  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  op_data = reinterpret_cast<OpData*>(
+      context->AllocatePersistentBuffer(context, sizeof(OpData)));
+
+  op_data->max_detections = m["max_detections"].AsInt32();
+  op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
+  if (m["detections_per_class"].IsNull())
+    op_data->detections_per_class = kNumDetectionsPerClass;
+  else
+    op_data->detections_per_class = m["detections_per_class"].AsInt32();
+  if (m["use_regular_nms"].IsNull())
+    op_data->use_regular_non_max_suppression = false;
+  else
+    op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();
+
+  op_data->non_max_suppression_score_threshold =
+      m["nms_score_threshold"].AsFloat();
+  op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
+  op_data->num_classes = m["num_classes"].AsInt32();
+  op_data->scale_values.y = m["y_scale"].AsFloat();
+  op_data->scale_values.x = m["x_scale"].AsFloat();
+  op_data->scale_values.h = m["h_scale"].AsFloat();
+  op_data->scale_values.w = m["w_scale"].AsFloat();
+
+  return op_data;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  auto* op_data = static_cast<OpData*>(node->user_data);
+
+  // Inputs: box_encodings, scores, anchors
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+  const TfLiteTensor* input_box_encodings =
+      GetInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteTensor* input_class_predictions =
+      GetInput(context, node, kInputTensorClassPredictions);
+  const TfLiteTensor* input_anchors =
+      GetInput(context, node, kInputTensorAnchors);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
+
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const int num_classes = op_data->num_classes;
+
+  op_data->input_box_encodings.scale = input_box_encodings->params.scale;
+  op_data->input_box_encodings.zero_point =
+      input_box_encodings->params.zero_point;
+  op_data->input_class_predictions.scale =
+      input_class_predictions->params.scale;
+  op_data->input_class_predictions.zero_point =
+      input_class_predictions->params.zero_point;
+  op_data->input_anchors.scale = input_anchors->params.scale;
+  op_data->input_anchors.zero_point = input_anchors->params.zero_point;
+
+  // Scratch tensors
+  context->RequestScratchBufferInArena(context, num_boxes,
+                                       &op_data->active_candidate_idx);
+  context->RequestScratchBufferInArena(context,
+                                       num_boxes * kNumCoordBox * sizeof(float),
+                                       &op_data->decoded_boxes_idx);
+  context->RequestScratchBufferInArena(
+      context,
+      input_class_predictions->dims->data[1] *
+          input_class_predictions->dims->data[2] * sizeof(float),
+      &op_data->scores_idx);
+
+  // Additional buffers
+  context->RequestScratchBufferInArena(context, num_boxes * sizeof(float),
+                                       &op_data->score_buffer_idx);
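+  // Sizing note (illustrative): every scratch buffer here is dimensioned from
+  // shapes known in Prepare(). For example, with num_boxes = 6 (as in the
+  // unit tests), the decoded-boxes buffer requested above spans
+  // 6 * kNumCoordBox * sizeof(float) = 96 bytes.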
+  context->RequestScratchBufferInArena(context, num_boxes * sizeof(float),
+                                       &op_data->keep_scores_idx);
+  context->RequestScratchBufferInArena(
+      context, op_data->max_detections * num_boxes * sizeof(float),
+      &op_data->scores_after_regular_non_max_suppression_idx);
+  context->RequestScratchBufferInArena(
+      context, op_data->max_detections * num_boxes * sizeof(float),
+      &op_data->sorted_values_idx);
+  context->RequestScratchBufferInArena(context, num_boxes * sizeof(int),
+                                       &op_data->keep_indices_idx);
+  context->RequestScratchBufferInArena(
+      context, op_data->max_detections * num_boxes * sizeof(int),
+      &op_data->sorted_indices_idx);
+  int buffer_size = std::max(num_classes, op_data->max_detections);
+  context->RequestScratchBufferInArena(
+      context, buffer_size * num_boxes * sizeof(int), &op_data->buffer_idx);
+  buffer_size = std::min(num_boxes, op_data->max_detections);
+  context->RequestScratchBufferInArena(
+      context, buffer_size * num_boxes * sizeof(int), &op_data->selected_idx);
+
+  // number of detected boxes
+  const int num_detected_boxes =
+      op_data->max_detections * op_data->max_classes_per_detection;
+
+  // Outputs: detection_boxes, detection_scores, detection_classes,
+  // num_detections
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
+
+  // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
+  TfLiteTensor* detection_boxes =
+      GetOutput(context, node, kOutputTensorDetectionBoxes);
+  if (detection_boxes->dims->size == 0) {
+    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(context, &detection_boxes->dims,
+                                                1, num_detected_boxes, 4));
+  }
+
+  // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
+  TfLiteTensor* detection_classes =
+      GetOutput(context, node, kOutputTensorDetectionClasses);
+  if (detection_classes->dims->size == 0) {
+    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
+        context, &detection_classes->dims, 1, num_detected_boxes));
+  }
+
+  // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
+  TfLiteTensor* detection_scores =
+      GetOutput(context, node, kOutputTensorDetectionScores);
+  if (detection_scores->dims->size == 0) {
+    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
+        context, &detection_scores->dims, 1, num_detected_boxes));
+  }
+
+  // Output Tensor num_detections: size is set to 1
+  TfLiteTensor* num_detections =
+      GetOutput(context, node, kOutputTensorNumDetections);
+  if (num_detections->dims->size == 0) {
+    TF_LITE_ENSURE_STATUS(
+        AllocateOutDimensions(context, &num_detections->dims, 1));
+  }
+
+  return kTfLiteOk;
+}
+
+class Dequantizer {
+ public:
+  Dequantizer(int zero_point, float scale)
+      : zero_point_(zero_point), scale_(scale) {}
+  float operator()(uint8_t x) {
+    return (static_cast<float>(x) - zero_point_) * scale_;
+  }
+
+ private:
+  int zero_point_;
+  float scale_;
+};
+
+void DequantizeBoxEncodings(const TfLiteEvalTensor* input_box_encodings,
+                            int idx, float quant_zero_point, float quant_scale,
+                            int length_box_encoding,
+                            CenterSizeEncoding* box_centersize) {
+  const uint8_t* boxes =
+      tflite::micro::GetTensorData<uint8_t>(input_box_encodings) +
+      length_box_encoding * idx;
+  Dequantizer dequantize(quant_zero_point, quant_scale);
+  // See definition of the KeyPointBoxCoder at
+  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
+  // The first four elements are the box coordinates, which is the same as the
+  // FastRnnBoxCoder at
+  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
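+  // Illustrative numbers (assumed, not from a real model): with
+  // quant_zero_point = 128.0f and quant_scale = 0.1f, a raw byte of 153
+  // dequantizes to (153 - 128) * 0.1 = 2.5.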
+  box_centersize->y = dequantize(boxes[0]);
+  box_centersize->x = dequantize(boxes[1]);
+  box_centersize->h = dequantize(boxes[2]);
+  box_centersize->w = dequantize(boxes[3]);
+}
+
+template <class T>
+T ReInterpretTensor(const TfLiteEvalTensor* tensor) {
+  const float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
+  return reinterpret_cast<T>(tensor_base);
+}
+
+template <class T>
+T ReInterpretTensor(TfLiteEvalTensor* tensor) {
+  float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
+  return reinterpret_cast<T>(tensor_base);
+}
+
+TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
+                                   OpData* op_data) {
+  // Parse input tensor boxencodings
+  const TfLiteEvalTensor* input_box_encodings =
+      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
+  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
+  const int num_boxes = input_box_encodings->dims->data[1];
+  TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
+  const TfLiteEvalTensor* input_anchors =
+      tflite::micro::GetEvalInput(context, node, kInputTensorAnchors);
+
+  // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
+  CenterSizeEncoding box_centersize;
+  CenterSizeEncoding scale_values = op_data->scale_values;
+  CenterSizeEncoding anchor;
+  for (int idx = 0; idx < num_boxes; ++idx) {
+    switch (input_box_encodings->type) {
+      // Quantized
+      case kTfLiteUInt8:
+        DequantizeBoxEncodings(
+            input_box_encodings, idx,
+            static_cast<float>(op_data->input_box_encodings.zero_point),
+            static_cast<float>(op_data->input_box_encodings.scale),
+            input_box_encodings->dims->data[2], &box_centersize);
+        DequantizeBoxEncodings(
+            input_anchors, idx,
+            static_cast<float>(op_data->input_anchors.zero_point),
+            static_cast<float>(op_data->input_anchors.scale), kNumCoordBox,
+            &anchor);
+        break;
+      // Float
+      case kTfLiteFloat32: {
+        // Please see DequantizeBoxEncodings function for the support detail.
+        const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
+        const float* boxes = &(tflite::micro::GetTensorData<float>(
+            input_box_encodings)[box_encoding_idx]);
+        box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
+        anchor =
+            ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
+        break;
+      }
+      default:
+        // Unsupported type.
+        return kTfLiteError;
+    }
+
+    float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y;
+    float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x;
+    float half_h =
+        0.5f * static_cast<float>(std::exp(box_centersize.h / scale_values.h)) *
+        anchor.h;
+    float half_w =
+        0.5f * static_cast<float>(std::exp(box_centersize.w / scale_values.w)) *
+        anchor.w;
+
+    float* decoded_boxes = reinterpret_cast<float*>(
+        context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
+    auto& box = reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[idx];
+    box.ymin = ycenter - half_h;
+    box.xmin = xcenter - half_w;
+    box.ymax = ycenter + half_h;
+    box.xmax = xcenter + half_w;
+  }
+  return kTfLiteOk;
+}
+
+void DecreasingPartialArgSort(const float* values, int num_values,
+                              int num_to_sort, int* indices) {
+  std::iota(indices, indices + num_values, 0);
+  std::partial_sort(
+      indices, indices + num_to_sort, indices + num_values,
+      [&values](const int i, const int j) { return values[i] > values[j]; });
+}
+
+void DecreasingPartialArgSort2(const float* values, int num_values,
+                               int num_to_sort, int* indices, int* ind) {
+  std::iota(ind, ind + num_values, 0);
+  std::partial_sort(
+      ind, ind + num_to_sort, ind + num_values,
+      [&values](const int i, const int j) { return values[i] > values[j]; });
+
+  std::iota(indices, indices + num_values, 0);
+
+  std::partial_sort(
+      indices, indices + num_to_sort, indices + num_values,
+      [&values](const int i, const int j) { return values[i] > values[j]; });
+}
+
+int SelectDetectionsAboveScoreThreshold(const float* values, int size,
+                                        const float threshold,
+                                        float* keep_values, int* keep_indices) {
+  int counter = 0;
+  for (int i = 0; i < size; i++) {
+    if (values[i] >= threshold) {
+      // Keep values and their indices in parallel arrays, compacted to the
+      // front so that entry k of keep_indices matches entry k of keep_values.
+      keep_values[counter] = values[i];
+      keep_indices[counter] = i;
+      counter++;
+    }
+  }
+  return counter;
+}
+
+bool ValidateBoxes(const float* decoded_boxes, const int num_boxes) {
+  for (int i = 0; i < num_boxes; ++i) {
+    // ymax>=ymin, xmax>=xmin
+    auto& box = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
+    if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
+      return false;
+    }
+  }
+  return true;
+}
+
+float ComputeIntersectionOverUnion(const float* decoded_boxes, const int i,
+                                   const int j) {
+  auto& box_i = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
+  auto& box_j = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[j];
+  const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
+  const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
+  if (area_i <= 0 || area_j <= 0) return 0.0;
+  const float intersection_ymin = std::max(box_i.ymin, box_j.ymin);
+  const float intersection_xmin = std::max(box_i.xmin, box_j.xmin);
+  const float intersection_ymax = std::min(box_i.ymax, box_j.ymax);
+  const float intersection_xmax = std::min(box_i.xmax, box_j.xmax);
+  const float intersection_area =
+      std::max(intersection_ymax - intersection_ymin, 0.0f) *
+      std::max(intersection_xmax - intersection_xmin, 0.0f);
+  return intersection_area / (area_i + area_j - intersection_area);
+}
+
+// NonMaxSuppressionSingleClass() prunes out the box locations with high
+// overlap before selecting the highest scoring boxes (max_detections in
+// number). It assumes all boxes are good in the beginning and sorts based on
+// the scores. If a lower-scoring box has too much overlap with a
+// higher-scoring box, we get rid of the lower-scoring box.
+// Complexity is O(N^2) pairwise comparison between boxes
+TfLiteStatus NonMaxSuppressionSingleClassHelper(
+    TfLiteContext* context, TfLiteNode* node, OpData* op_data,
+    const float* scores, int* selected, int* selected_size,
+    int max_detections) {
+  const TfLiteEvalTensor* input_box_encodings =
+      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const float non_max_suppression_score_threshold =
+      op_data->non_max_suppression_score_threshold;
+  const float intersection_over_union_threshold =
+      op_data->intersection_over_union_threshold;
+  // Maximum detections should be non-negative.
+  TF_LITE_ENSURE(context, (max_detections >= 0));
+  // intersection_over_union_threshold should be positive
+  // and should be less than 1.
+  TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
+                              (intersection_over_union_threshold <= 1.0f));
+  // Validate boxes
+  float* decoded_boxes = reinterpret_cast<float*>(
+      context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
+
+  TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
+
+  // threshold scores
+  int* keep_indices = reinterpret_cast<int*>(
+      context->GetScratchBuffer(context, op_data->keep_indices_idx));
+  float* keep_scores = reinterpret_cast<float*>(
+      context->GetScratchBuffer(context, op_data->keep_scores_idx));
+  int num_scores_kept = SelectDetectionsAboveScoreThreshold(
+      scores, num_boxes, non_max_suppression_score_threshold, keep_scores,
+      keep_indices);
+  int* sorted_indices = reinterpret_cast<int*>(
+      context->GetScratchBuffer(context, op_data->sorted_indices_idx));
+
+  DecreasingPartialArgSort(keep_scores, num_scores_kept, num_scores_kept,
+                           sorted_indices);
+
+  const int num_boxes_kept = num_scores_kept;
+  const int output_size = std::min(num_boxes_kept, max_detections);
+  *selected_size = 0;
+
+  int num_active_candidate = num_boxes_kept;
+  uint8_t* active_box_candidate = reinterpret_cast<uint8_t*>(
+      context->GetScratchBuffer(context, op_data->active_candidate_idx));
+
+  for (int row = 0; row < num_boxes_kept; row++) {
+    active_box_candidate[row] = 1;
+  }
+  for (int i = 0; i < num_boxes_kept; ++i) {
+    if (num_active_candidate == 0 || *selected_size >= output_size) break;
+    if (active_box_candidate[i] == 1) {
+      selected[(*selected_size)++] = keep_indices[sorted_indices[i]];
+      active_box_candidate[i] = 0;
+      num_active_candidate--;
+    } else {
+      continue;
+    }
+    for (int j = i + 1; j < num_boxes_kept; ++j) {
+      if (active_box_candidate[j] == 1) {
+        float intersection_over_union = ComputeIntersectionOverUnion(
+            decoded_boxes, keep_indices[sorted_indices[i]],
+            keep_indices[sorted_indices[j]]);
+
+        if (intersection_over_union > intersection_over_union_threshold) {
+          active_box_candidate[j] = 0;
+          num_active_candidate--;
+        }
+      }
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+// This function implements a regular version of Non Maximal Suppression (NMS)
+// for multiple classes where
+// 1) we do NMS separately for each class across all anchors and
+// 2) keep only the highest anchor scores across all classes
+// 3) The worst runtime of the regular NMS is O(K*N^2)
+// where N is the number of anchors and K the number of classes.
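+// For example, with N = 6 anchors and K = 2 classes (the shapes used in the
+// unit tests), this helper runs the O(N^2) single-class suppression once per
+// class column of the score matrix, i.e. twice.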
+TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
+                                                      TfLiteNode* node,
+                                                      OpData* op_data,
+                                                      const float* scores) {
+  const TfLiteEvalTensor* input_box_encodings =
+      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteEvalTensor* input_class_predictions =
+      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
+  TfLiteEvalTensor* detection_boxes =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
+  TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
+      context, node, kOutputTensorDetectionClasses);
+  TfLiteEvalTensor* detection_scores =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
+  TfLiteEvalTensor* num_detections =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);
+
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const int num_classes = op_data->num_classes;
+  const int num_detections_per_class = op_data->detections_per_class;
+  const int max_detections = op_data->max_detections;
+  const int num_classes_with_background =
+      input_class_predictions->dims->data[2];
+  // The row index offset is 1 if background class is included and 0 otherwise.
+  int label_offset = num_classes_with_background - num_classes;
+  TF_LITE_ENSURE(context, num_detections_per_class > 0);
+
+  // For each class, perform non-max suppression.
+  float* class_scores = reinterpret_cast<float*>(
+      context->GetScratchBuffer(context, op_data->score_buffer_idx));
+  int* box_indices_after_regular_non_max_suppression = reinterpret_cast<int*>(
+      context->GetScratchBuffer(context, op_data->buffer_idx));
+  float* scores_after_regular_non_max_suppression =
+      reinterpret_cast<float*>(context->GetScratchBuffer(
+          context, op_data->scores_after_regular_non_max_suppression_idx));
+
+  int size_of_sorted_indices = 0;
+  int* sorted_indices = reinterpret_cast<int*>(
+      context->GetScratchBuffer(context, op_data->sorted_indices_idx));
+  float* sorted_values = reinterpret_cast<float*>(
+      context->GetScratchBuffer(context, op_data->sorted_values_idx));
+
+  for (int col = 0; col < num_classes; col++) {
+    for (int row = 0; row < num_boxes; row++) {
+      // Get scores of boxes corresponding to all anchors for single class
+      class_scores[row] =
+          *(scores + row * num_classes_with_background + col + label_offset);
+    }
+    // Perform non-maximal suppression on single class
+    int selected_size = 0;
+    int* selected = reinterpret_cast<int*>(
+        context->GetScratchBuffer(context, op_data->selected_idx));
+    TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
+        context, node, op_data, class_scores, selected, &selected_size,
+        num_detections_per_class));
+    // Add selected indices from non-max suppression of boxes in this class
+    int output_index = size_of_sorted_indices;
+    for (int i = 0; i < selected_size; i++) {
+      int selected_index = selected[i];
+
+      box_indices_after_regular_non_max_suppression[output_index] =
+          (selected_index * num_classes_with_background + col + label_offset);
+      scores_after_regular_non_max_suppression[output_index] =
+          class_scores[selected_index];
+      output_index++;
+    }
+    // Sort the max scores among the selected indices
+    // Get the indices for top scores
+    int num_indices_to_sort = std::min(output_index, max_detections);
+    DecreasingPartialArgSort(scores_after_regular_non_max_suppression,
+                             output_index, num_indices_to_sort, sorted_indices);
+
+    // Copy values to temporary vectors
+    for (int row = 0; row < num_indices_to_sort; row++) {
+      int temp = sorted_indices[row];
+      sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
+      sorted_values[row] = scores_after_regular_non_max_suppression[temp];
+    }
+    // Copy scores and indices from temporary vectors
+    for (int row = 0; row < num_indices_to_sort; row++) {
+      box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
+      scores_after_regular_non_max_suppression[row] = sorted_values[row];
+    }
+    size_of_sorted_indices = num_indices_to_sort;
+  }
+
+  // Allocate output tensors
+  for (int output_box_index = 0; output_box_index < max_detections;
+       output_box_index++) {
+    if (output_box_index < size_of_sorted_indices) {
+      const int anchor_index = floor(
+          box_indices_after_regular_non_max_suppression[output_box_index] /
+          num_classes_with_background);
+      const int class_index =
+          box_indices_after_regular_non_max_suppression[output_box_index] -
+          anchor_index * num_classes_with_background - label_offset;
+      const float selected_score =
+          scores_after_regular_non_max_suppression[output_box_index];
+      // detection_boxes
+      float* decoded_boxes = reinterpret_cast<float*>(
+          context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
+      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
+          reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[anchor_index];
+      // detection_classes
+      tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
+          class_index;
+      // detection_scores
+      tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
+          selected_score;
+    } else {
+      ReInterpretTensor<BoxCornerEncoding*>(
+          detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
+      // detection_classes
+      tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
+          0.0f;
+      // detection_scores
+      tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
+          0.0f;
+    }
+  }
+  tflite::micro::GetTensorData<float>(num_detections)[0] =
+      size_of_sorted_indices;
+
+  return kTfLiteOk;
+}
+
+// This function implements a fast version of Non Maximal Suppression for
+// multiple classes where
+// 1) we keep the top-k scores for each anchor and
+// 2) during NMS, each anchor only uses the highest class score for sorting.
+// 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
+// instead of O(KN^2) where N is the number of anchors and K the number of
+// classes.
+TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
+                                                   TfLiteNode* node,
+                                                   OpData* op_data,
+                                                   const float* scores) {
+  const TfLiteEvalTensor* input_box_encodings =
+      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteEvalTensor* input_class_predictions =
+      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
+  TfLiteEvalTensor* detection_boxes =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
+
+  TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
+      context, node, kOutputTensorDetectionClasses);
+  TfLiteEvalTensor* detection_scores =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
+  TfLiteEvalTensor* num_detections =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);
+
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const int num_classes = op_data->num_classes;
+  const int max_categories_per_anchor = op_data->max_classes_per_detection;
+  const int num_classes_with_background =
+      input_class_predictions->dims->data[2];
+
+  // The row index offset is 1 if background class is included and 0 otherwise.
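+  // Example: a class-prediction tensor of shape (1, num_boxes, 3) with
+  // num_classes = 2 carries one leading background column, so label_offset
+  // here evaluates to 3 - 2 = 1 and scores are read starting at column 1.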
+  int label_offset = num_classes_with_background - num_classes;
+  TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
+  const int num_categories_per_anchor =
+      std::min(max_categories_per_anchor, num_classes);
+  float* max_scores = reinterpret_cast<float*>(
+      context->GetScratchBuffer(context, op_data->score_buffer_idx));
+  int* sorted_class_indices = reinterpret_cast<int*>(
+      context->GetScratchBuffer(context, op_data->buffer_idx));
+
+  for (int row = 0; row < num_boxes; row++) {
+    const float* box_scores =
+        scores + row * num_classes_with_background + label_offset;
+    int* class_indices = sorted_class_indices + row * num_classes;
+    DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
+                             class_indices);
+    max_scores[row] = box_scores[class_indices[0]];
+  }
+
+  // Perform non-maximal suppression on max scores
+  int selected_size = 0;
+  int* selected = reinterpret_cast<int*>(
+      context->GetScratchBuffer(context, op_data->selected_idx));
+  TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
+      context, node, op_data, max_scores, selected, &selected_size,
+      op_data->max_detections));
+
+  // Allocate output tensors
+  int output_box_index = 0;
+
+  for (int i = 0; i < selected_size; i++) {
+    int selected_index = selected[i];
+
+    const float* box_scores =
+        scores + selected_index * num_classes_with_background + label_offset;
+    const int* class_indices =
+        sorted_class_indices + selected_index * num_classes;
+
+    for (int col = 0; col < num_categories_per_anchor; ++col) {
+      int box_offset = num_categories_per_anchor * output_box_index + col;
+
+      // detection_boxes
+      float* decoded_boxes = reinterpret_cast<float*>(
+          context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
+      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
+          reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[selected_index];
+
+      // detection_classes
+      tflite::micro::GetTensorData<float>(detection_classes)[box_offset] =
+          class_indices[col];
+
+      // detection_scores
+      tflite::micro::GetTensorData<float>(detection_scores)[box_offset] =
+          box_scores[class_indices[col]];
+
+      output_box_index++;
+    }
+  }
+
+  tflite::micro::GetTensorData<float>(num_detections)[0] = output_box_index;
+  return kTfLiteOk;
+}
+
+void DequantizeClassPredictions(const TfLiteEvalTensor* input_class_predictions,
+                                const int num_boxes,
+                                const int num_classes_with_background,
+                                float* scores, OpData* op_data) {
+  float quant_zero_point =
+      static_cast<float>(op_data->input_class_predictions.zero_point);
+  float quant_scale =
+      static_cast<float>(op_data->input_class_predictions.scale);
+  Dequantizer dequantize(quant_zero_point, quant_scale);
+  const uint8_t* scores_quant =
+      tflite::micro::GetTensorData<uint8_t>(input_class_predictions);
+  for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
+    scores[idx] = dequantize(scores_quant[idx]);
+  }
+}
+
+TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
+                                         TfLiteNode* node, OpData* op_data) {
+  // Get the input tensors
+  const TfLiteEvalTensor* input_box_encodings =
+      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteEvalTensor* input_class_predictions =
+      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const int num_classes = op_data->num_classes;
+
+  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
+                    kBatchSize);
+  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
+  const int num_classes_with_background =
+      input_class_predictions->dims->data[2];
+
+  TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
+  TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));
+
+  const float* scores;
+  switch (input_class_predictions->type) {
+    case kTfLiteUInt8: {
+      float* temporary_scores = reinterpret_cast<float*>(
+          context->GetScratchBuffer(context, op_data->scores_idx));
+      DequantizeClassPredictions(input_class_predictions, num_boxes,
+                                 num_classes_with_background, temporary_scores,
+                                 op_data);
+      scores = temporary_scores;
+    } break;
+    case kTfLiteFloat32:
+      scores = tflite::micro::GetTensorData<float>(input_class_predictions);
+      break;
+    default:
+      // Unsupported type.
+      return kTfLiteError;
+  }
+
+  if (op_data->use_regular_non_max_suppression)
+    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
+        context, node, op_data, scores));
+  else
+    TF_LITE_ENSURE_STATUS(
+        NonMaxSuppressionMultiClassFastHelper(context, node, op_data, scores));
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE(context, (kBatchSize == 1));
+  auto* op_data = static_cast<OpData*>(node->user_data);
+
+  // These two functions correspond to two blocks in the Object Detection
+  // model. In the future, we would like to break the custom op into two
+  // blocks, which is currently not feasible because we would like to input
+  // quantized inputs and do all calculations in float. Mixed quantized/float
+  // calculations are currently not supported in TFLite.
+
+  // This fills in temporary decoded_boxes
+  // by transforming input_box_encodings and input_anchors from
+  // CenterSizeEncodings to BoxCornerEncoding
+  TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));
+
+  // This fills in the output tensors
+  // by choosing an effective set of decoded boxes
+  // based on Non Maximal Suppression, i.e. selecting the
+  // highest scoring non-overlapping boxes.
+  TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));
+
+  return kTfLiteOk;
+}
+
+}  // namespace detection_postprocess
+
+TfLiteRegistration Register_DETECTION_POSTPROCESS() {
+  return {/*init=*/detection_postprocess::Init,
+          /*free=*/nullptr,
+          /*prepare=*/detection_postprocess::Prepare,
+          /*invoke=*/detection_postprocess::Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/detection_postprocess_test.cc b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc
new file mode 100644
index 00000000000..7bdaa0306e9
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc
@@ -0,0 +1,496 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/kernel_runner.h"
+#include "tensorflow/lite/micro/kernels/micro_ops.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+#include "tensorflow/lite/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// Common inputs and outputs.
+
+static const int kInputShape1[] = {3, 1, 6, 4};
+static const int kInputShape2[] = {3, 1, 6, 3};
+static const int kInputShape3[] = {2, 6, 4};
+static const int kOutputShape1[] = {3, 1, 3, 4};
+static const int kOutputShape2[] = {2, 1, 3};
+static const int kOutputShape3[] = {2, 1, 3};
+static const int kOutputShape4[] = {1, 1};
+
+// six boxes in center-size encoding
+static const float kInputData1[] = {
+    0.0, 0.0,  0.0, 0.0,  // box #1
+    0.0, 1.0,  0.0, 0.0,  // box #2
+    0.0, -1.0, 0.0, 0.0,  // box #3
+    0.0, 0.0,  0.0, 0.0,  // box #4
+    0.0, 1.0,  0.0, 0.0,  // box #5
+    0.0, 0.0,  0.0, 0.0   // box #6
+};
+
+// class scores - two classes with background
+static const float kInputData2[] = {0., .9,  .8,  0., .75, .72, 0., .6, .5,
+                                    0., .93, .95, 0., .5,  .4,  0., .3, .2};
+
+// six anchors in center-size encoding
+static const float kInputData3[] = {
+    0.5, 0.5,   1.0, 1.0,  // anchor #1
+    0.5, 0.5,   1.0, 1.0,  // anchor #2
+    0.5, 0.5,   1.0, 1.0,  // anchor #3
+    0.5, 10.5,  1.0, 1.0,  // anchor #4
+    0.5, 10.5,  1.0, 1.0,  // anchor #5
+    0.5, 100.5, 1.0, 1.0   // anchor #6
+};
+// Same boxes in box-corner encoding:
+// { 0.0, 0.0, 1.0, 1.0,
+//   0.0, 0.1, 1.0, 1.1,
+//   0.0, -0.1, 1.0, 0.9,
+//   0.0, 10.0, 1.0, 11.0,
+//   0.0, 10.1, 1.0, 11.1,
+//   0.0, 100.0, 1.0, 101.0}
+
+static const float kGolden1[] = {0.0, 10.0, 1.0, 11.0,  0.0, 0.0,
+                                 1.0, 1.0,  0.0, 100.0, 1.0, 101.0};
+static const float kGolden2[] = {1, 0, 0};
+static const float kGolden3[] = {0.95, 0.9, 0.3};
+static const float kGolden4[] = {3.0};
+
+void TestDetectionPostprocess(
+    const int* input_dims_data1, const float* input_data1,
+    const int* input_dims_data2, const float* input_data2,
+    const int* input_dims_data3, const float* input_data3,
+    const int* output_dims_data1, float* output_data1,
+    const int* output_dims_data2, float* output_data2,
+    const int* output_dims_data3, float* output_data3,
+    const int* output_dims_data4, float* output_data4, const float* golden1,
+    const float* golden2, const float* golden3, const float* golden4,
+    const float tolerance, bool use_regular_nms,
+    uint8_t* input_data_quantized1 = nullptr,
+    uint8_t* input_data_quantized2 = nullptr,
+    uint8_t* input_data_quantized3 = nullptr, const float input_min1 = 0,
+    const float input_max1 = 0, const float input_min2 = 0,
+    const float input_max2 = 0, const float input_min3 = 0,
+    const float input_max3 = 0) {
+  TfLiteIntArray* input_dims1 = IntArrayFromInts(input_dims_data1);
+  TfLiteIntArray* input_dims2 = IntArrayFromInts(input_dims_data2);
+  TfLiteIntArray* input_dims3 = IntArrayFromInts(input_dims_data3);
+  TfLiteIntArray* output_dims1 = nullptr;
+  TfLiteIntArray* output_dims2 = nullptr;
+  TfLiteIntArray* output_dims3 = nullptr;
+  TfLiteIntArray* output_dims4 = nullptr;
+
+  const int zero_length_int_array_data[] = {0};
+  TfLiteIntArray* zero_length_int_array =
+      IntArrayFromInts(zero_length_int_array_data);
+
+  output_dims1 = output_dims_data1 == nullptr
+                     ? const_cast<TfLiteIntArray*>(zero_length_int_array)
+                     : IntArrayFromInts(output_dims_data1);
+  output_dims2 = output_dims_data2 == nullptr
+                     ? const_cast<TfLiteIntArray*>(zero_length_int_array)
+                     : IntArrayFromInts(output_dims_data2);
+  output_dims3 = output_dims_data3 == nullptr
+                     ? const_cast<TfLiteIntArray*>(zero_length_int_array)
+                     : IntArrayFromInts(output_dims_data3);
+  output_dims4 = output_dims_data4 == nullptr
+                     ? const_cast<TfLiteIntArray*>(zero_length_int_array)
+                     : IntArrayFromInts(output_dims_data4);
+
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 4;
+  constexpr int tensors_size = inputs_size + outputs_size;
+
+  TfLiteTensor tensors[tensors_size];
+  if (input_min1 != 0 || input_max1 != 0 || input_min2 != 0 ||
+      input_max2 != 0 || input_min3 != 0 || input_max3 != 0) {
+    const float input_scale1 = ScaleFromMinMax<uint8_t>(input_min1, input_max1);
+    const int input_zero_point1 =
+        ZeroPointFromMinMax<uint8_t>(input_min1, input_max1);
+    const float input_scale2 = ScaleFromMinMax<uint8_t>(input_min2, input_max2);
+    const int input_zero_point2 =
+        ZeroPointFromMinMax<uint8_t>(input_min2, input_max2);
+    const float input_scale3 = ScaleFromMinMax<uint8_t>(input_min3, input_max3);
+    const int input_zero_point3 =
+        ZeroPointFromMinMax<uint8_t>(input_min3, input_max3);
+
+    tensors[0] =
+        CreateQuantizedTensor(input_data1, input_data_quantized1, input_dims1,
+                              input_scale1, input_zero_point1);
+    tensors[1] =
+        CreateQuantizedTensor(input_data2, input_data_quantized2, input_dims2,
+                              input_scale2, input_zero_point2);
+    tensors[2] =
+        CreateQuantizedTensor(input_data3, input_data_quantized3, input_dims3,
+                              input_scale3, input_zero_point3);
+  } else {
+    tensors[0] = CreateFloatTensor(input_data1, input_dims1);
+    tensors[1] = CreateFloatTensor(input_data2, input_dims2);
+    tensors[2] = CreateFloatTensor(input_data3, input_dims3);
+  }
+  tensors[3] = CreateFloatTensor(output_data1, output_dims1);
+  tensors[4] = CreateFloatTensor(output_data2, output_dims2);
+  tensors[5] = CreateFloatTensor(output_data3, output_dims3);
+  tensors[6] = CreateFloatTensor(output_data4, output_dims4);
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
+
+  flexbuffers::Builder fbb;
+  fbb.Map([&]() {
+    fbb.Int("max_detections", 3);
+    fbb.Int("max_classes_per_detection", 1);
+    fbb.Int("detections_per_class", 1);
+    fbb.Bool("use_regular_nms", use_regular_nms);
+    fbb.Float("nms_score_threshold", 0.0);
+    fbb.Float("nms_iou_threshold", 0.5);
+    fbb.Int("num_classes", 2);
+    fbb.Float("y_scale", 10.0);
+    fbb.Float("x_scale", 10.0);
+    fbb.Float("h_scale", 5.0);
+    fbb.Float("w_scale", 5.0);
+  });
+  fbb.Finish();
+
+  const TfLiteRegistration& registration =
+      tflite::ops::micro::Register_DETECTION_POSTPROCESS();
+  TF_LITE_MICRO_EXPECT_NE(nullptr, &registration);
+
+  int inputs_array_data[] = {3, 0, 1, 2};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {4, 3, 4, 5, 6};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
+                             outputs_array, nullptr, micro_test::reporter);
+
+  const char* init_data = reinterpret_cast<const char*>(fbb.GetBuffer().data());
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, runner.InitAndPrepare(init_data, fbb.GetBuffer().size()));
+
+  // Output dimensions should not be undefined after Prepare
+  TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[3].dims);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[4].dims);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[5].dims);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[6].dims);
+
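+  // With the flexbuffer options above (max_detections = 3,
+  // max_classes_per_detection = 1), Prepare() should have sized
+  // detection_boxes as (1, 3, 4) and the class/score outputs as (1, 3).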
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
+
+  const int output_elements_count1 = tensors[3].dims->size;
+  const int output_elements_count2 = tensors[4].dims->size;
+  const int output_elements_count3 = tensors[5].dims->size;
+  const int output_elements_count4 = tensors[6].dims->size;
+
+  for (int i = 0; i < output_elements_count1; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(golden1[i], output_data1[i], tolerance);
+  }
+  for (int i = 0; i < output_elements_count2; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(golden2[i], output_data2[i], tolerance);
+  }
+  for (int i = 0; i < output_elements_count3; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(golden3[i], output_data3[i], tolerance);
+  }
+  for (int i = 0; i < output_elements_count4; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(golden4[i], output_data4[i], tolerance);
+  }
+}
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMS) {
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+
+  tflite::testing::TestDetectionPostprocess(
+      tflite::testing::kInputShape1, tflite::testing::kInputData1,
+      tflite::testing::kInputShape2, tflite::testing::kInputData2,
+      tflite::testing::kInputShape3, tflite::testing::kInputData3,
+      tflite::testing::kOutputShape1, output_data1,
+      tflite::testing::kOutputShape2, output_data2,
+      tflite::testing::kOutputShape3, output_data3,
+      tflite::testing::kOutputShape4, output_data4, tflite::testing::kGolden1,
+      tflite::testing::kGolden2, tflite::testing::kGolden3,
+      tflite::testing::kGolden4,
+      /* tolerance */ 0, /* Use regular NMS: */ false);
+}
+
+TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedFastNMS) {
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+  const int kInputElements1 = tflite::testing::kInputShape1[1] *
+                              tflite::testing::kInputShape1[2] *
+                              tflite::testing::kInputShape1[3];
+  const int kInputElements2 = tflite::testing::kInputShape2[1] *
+                              tflite::testing::kInputShape2[2] *
+                              tflite::testing::kInputShape2[3];
+  const int kInputElements3 =
+      tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2];
+
+  uint8_t input_data_quantized1[kInputElements1 + 10];
+  uint8_t input_data_quantized2[kInputElements2 + 10];
+  uint8_t input_data_quantized3[kInputElements3 + 10];
+
+  tflite::testing::TestDetectionPostprocess(
+      tflite::testing::kInputShape1, tflite::testing::kInputData1,
+      tflite::testing::kInputShape2, tflite::testing::kInputData2,
+      tflite::testing::kInputShape3, tflite::testing::kInputData3,
+      tflite::testing::kOutputShape1, output_data1,
+      tflite::testing::kOutputShape2, output_data2,
+      tflite::testing::kOutputShape3, output_data3,
+      tflite::testing::kOutputShape4, output_data4, tflite::testing::kGolden1,
+      tflite::testing::kGolden2, tflite::testing::kGolden3,
+      tflite::testing::kGolden4,
+      /* tolerance */ 3e-1, /* Use regular NMS: */ false, input_data_quantized1,
+      input_data_quantized2, input_data_quantized3,
+      /* input1 min/max*/ -1.0, 1.0, /* input2 min/max */ 0.0, 1.0,
+      /* input3 min/max */ 0.0, 100.5);
+}
+
+TF_LITE_MICRO_TEST(DetectionPostprocessFloatRegularNMS) {
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+  const float kGolden1[] = {0.0, 10.0, 1.0, 11.0, 0.0, 10.0,
+                            1.0, 11.0, 0.0, 0.0,  0.0, 0.0};
+  const float kGolden3[] = {0.95, 0.9, 0.0};
+  const float kGolden4[] = {2.0};
+
+  tflite::testing::TestDetectionPostprocess(
+      tflite::testing::kInputShape1, tflite::testing::kInputData1,
+      tflite::testing::kInputShape2, tflite::testing::kInputData2,
+      tflite::testing::kInputShape3, tflite::testing::kInputData3,
+      tflite::testing::kOutputShape1, output_data1,
+      tflite::testing::kOutputShape2, output_data2,
+      tflite::testing::kOutputShape3, output_data3,
+      tflite::testing::kOutputShape4, output_data4, kGolden1,
+      tflite::testing::kGolden2, kGolden3, kGolden4,
+      /* tolerance */ 1e-1, /* Use regular NMS: */ true);
+}
+
+TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedRegularNMS) {
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+  const int kInputElements1 = tflite::testing::kInputShape1[1] *
+                              tflite::testing::kInputShape1[2] *
+                              tflite::testing::kInputShape1[3];
+  const int kInputElements2 = tflite::testing::kInputShape2[1] *
+                              tflite::testing::kInputShape2[2] *
+                              tflite::testing::kInputShape2[3];
+  const int kInputElements3 =
+      tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2];
+
+  uint8_t input_data_quantized1[kInputElements1 + 10];
+  uint8_t input_data_quantized2[kInputElements2 + 10];
+  uint8_t input_data_quantized3[kInputElements3 + 10];
+
+  const float kGolden1[] = {0.0, 10.0, 1.0, 11.0, 0.0, 10.0,
+                            1.0, 11.0, 0.0, 0.0,  0.0, 0.0};
+  const float kGolden3[] = {0.95, 0.9, 0.0};
+  const float kGolden4[] = {2.0};
+
+  tflite::testing::TestDetectionPostprocess(
+      tflite::testing::kInputShape1, tflite::testing::kInputData1,
+      tflite::testing::kInputShape2, tflite::testing::kInputData2,
+      tflite::testing::kInputShape3, tflite::testing::kInputData3,
+      tflite::testing::kOutputShape1, output_data1,
+      tflite::testing::kOutputShape2, output_data2,
+      tflite::testing::kOutputShape3, output_data3,
+      tflite::testing::kOutputShape4, output_data4, kGolden1,
+      tflite::testing::kGolden2, kGolden3, kGolden4,
+      /* tolerance */ 3e-1, /* Use regular NMS: */ true, input_data_quantized1,
+      input_data_quantized2, input_data_quantized3,
+      /* input1 min/max*/ -1.0, 1.0, /* input2 min/max */ 0.0, 1.0,
+      /* input3 min/max */ 0.0, 100.5);
+}
+
+TF_LITE_MICRO_TEST(
+    DetectionPostprocessFloatFastNMSwithNoBackgroundClassAndKeypoints) {
+  const int kInputShape1[] = {3, 1, 6, 5};
+  const int kInputShape2[] = {3, 1, 6, 2};
+
+  // six boxes in center-size encoding
+  const float kInputData1[] = {
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #1
+      0.0, 1.0,  0.0, 0.0, 1.0,  // box #2
+      0.0, -1.0, 0.0, 0.0, 1.0,  // box #3
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #4
+      0.0, 1.0,  0.0, 0.0, 1.0,  // box #5
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #6
+  };
+
+  // class scores - two classes without background
+  const float kInputData2[] = {.9,  .8,  .75, .72, .6, .5,
+                               .93, .95, .5,  .4,  .3, .2};
+
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+
+  tflite::testing::TestDetectionPostprocess(
+      kInputShape1, kInputData1, kInputShape2, kInputData2,
+      tflite::testing::kInputShape3, tflite::testing::kInputData3,
+      tflite::testing::kOutputShape1, output_data1,
+      tflite::testing::kOutputShape2, output_data2,
+      tflite::testing::kOutputShape3, output_data3,
+      tflite::testing::kOutputShape4, output_data4, tflite::testing::kGolden1,
+      tflite::testing::kGolden2, tflite::testing::kGolden3,
+      tflite::testing::kGolden4,
+      /* tolerance */ 0, /* Use regular NMS: */ false);
+}
+
+TF_LITE_MICRO_TEST(
+    DetectionPostprocessFloatRegularNMSwithNoBackgroundClassAndKeypoints) {
+  const int kInputShape2[] = {3, 1, 6, 2};
+
+  // class scores - two classes without background
+  const float kInputData2[] = {.9,  .8,  .75, .72, .6, .5,
+                               .93, .95, .5,  .4,  .3, .2};
+
+  const float kGolden1[] = {0.0, 10.0, 1.0, 11.0, 0.0, 10.0,
+                            1.0, 11.0, 0.0, 0.0,  0.0, 0.0};
+  const float kGolden3[] = {0.95, 0.9, 0.0};
+  const float kGolden4[] = {2.0};
+
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+
+  tflite::testing::TestDetectionPostprocess(
+      tflite::testing::kInputShape1, tflite::testing::kInputData1, kInputShape2,
+      kInputData2, tflite::testing::kInputShape3, tflite::testing::kInputData3,
+      tflite::testing::kOutputShape1, output_data1,
+      tflite::testing::kOutputShape2, output_data2,
+      tflite::testing::kOutputShape3, output_data3,
+      tflite::testing::kOutputShape4, output_data4, kGolden1,
+      tflite::testing::kGolden2, kGolden3, kGolden4,
+      /* tolerance */ 1e-1, /* Use regular NMS: */ true);
+}
+
+TF_LITE_MICRO_TEST(
+    DetectionPostprocessFloatFastNMSWithBackgroundClassAndKeypoints) {
+  const int kInputShape1[] = {3, 1, 6, 5};
+
+  // six boxes in center-size encoding
+  const float kInputData1[] = {
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #1
+      0.0, 1.0,  0.0, 0.0, 1.0,  // box #2
+      0.0, -1.0, 0.0, 0.0, 1.0,  // box #3
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #4
+      0.0, 1.0,  0.0, 0.0, 1.0,  // box #5
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #6
+  };
+
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+
+  tflite::testing::TestDetectionPostprocess(
+      kInputShape1, kInputData1, tflite::testing::kInputShape2,
+      tflite::testing::kInputData2, tflite::testing::kInputShape3,
+      tflite::testing::kInputData3, tflite::testing::kOutputShape1,
+      output_data1, tflite::testing::kOutputShape2, output_data2,
+      tflite::testing::kOutputShape3, output_data3,
+      tflite::testing::kOutputShape4, output_data4, tflite::testing::kGolden1,
+      tflite::testing::kGolden2, tflite::testing::kGolden3,
+      tflite::testing::kGolden4,
+      /* tolerance */ 0, /* Use regular NMS: */ false);
+}
+
+TF_LITE_MICRO_TEST(
+    DetectionPostprocessQuantizedFastNMSwithNoBackgroundClassAndKeypoints) {
+  const int kInputShape1[] = {3, 1, 6, 5};
+  const int kInputShape2[] = {3, 1, 6, 2};
+
+  // six boxes in center-size encoding
+  const float kInputData1[] = {
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #1
+      0.0, 1.0,  0.0, 0.0, 1.0,  // box #2
+      0.0, -1.0, 0.0, 0.0, 1.0,  // box #3
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #4
+      0.0, 1.0,  0.0, 0.0, 1.0,  // box #5
+      0.0, 0.0,  0.0, 0.0, 1.0,  // box #6
+  };
+
+  // class scores - two classes without background
+  const float kInputData2[] = {.9,  .8,  .75, .72, .6, .5,
+                               .93, .95, .5,  .4,  .3, .2};
+
+  const int kInputElements1 = tflite::testing::kInputShape1[1] *
+                              tflite::testing::kInputShape1[2] *
+                              tflite::testing::kInputShape1[3];
+  const int kInputElements2 = tflite::testing::kInputShape2[1] *
+                              tflite::testing::kInputShape2[2] *
+                              tflite::testing::kInputShape2[3];
+  const int kInputElements3 =
+      tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2];
+
+  uint8_t input_data_quantized1[kInputElements1 + 10];
+  uint8_t input_data_quantized2[kInputElements2 + 10];
+  uint8_t input_data_quantized3[kInputElements3 + 10];
+
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+
+  tflite::testing::TestDetectionPostprocess(
+      kInputShape1, kInputData1, kInputShape2, kInputData2,
+      tflite::testing::kInputShape3, tflite::testing::kInputData3,
+      tflite::testing::kOutputShape1, output_data1,
+      tflite::testing::kOutputShape2, output_data2,
+      tflite::testing::kOutputShape3, output_data3,
+      tflite::testing::kOutputShape4, output_data4, tflite::testing::kGolden1,
+      tflite::testing::kGolden2, tflite::testing::kGolden3,
+      tflite::testing::kGolden4,
+      /* tolerance */ 3e-1, /* Use regular NMS: */ false, input_data_quantized1,
+      input_data_quantized2, input_data_quantized3,
+      /* input1 min/max*/ -1.0, 1.0, /* input2 min/max */ 0.0, 1.0,
+      /* input3 min/max */ 0.0, 100.5);
+}
+
+TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMSUndefinedOutputDimensions) {
+  float output_data1[12];
+  float output_data2[3];
+  float output_data3[3];
+  float output_data4[1];
+
+  tflite::testing::TestDetectionPostprocess(
+      tflite::testing::kInputShape1, tflite::testing::kInputData1,
+      tflite::testing::kInputShape2, tflite::testing::kInputData2,
+      tflite::testing::kInputShape3, tflite::testing::kInputData3, nullptr,
+      output_data1, nullptr, output_data2, nullptr, output_data3, nullptr,
+      output_data4, tflite::testing::kGolden1, tflite::testing::kGolden2,
+      tflite::testing::kGolden3, tflite::testing::kGolden4,
+      /* tolerance */ 0, /* Use regular NMS: */ false);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/kernel_runner.cc b/tensorflow/lite/micro/kernels/kernel_runner.cc
index cef6c01cf45..e4ff277cc09 100644
--- a/tensorflow/lite/micro/kernels/kernel_runner.cc
+++ b/tensorflow/lite/micro/kernels/kernel_runner.cc
@@ -23,16 +23,16 @@ constexpr size_t kBufferAlignment = 16;
 }  // namespace
 
 // TODO(b/161841696): Consider moving away from global arena buffers:
-constexpr int KernelRunner::kNumScratchBuffers_;
-constexpr int KernelRunner::kKernelRunnerBufferSize_;
-uint8_t KernelRunner::kKernelRunnerBuffer_[];
+constexpr int KernelRunner::kNumScratchBuffers;
+constexpr int KernelRunner::kKernelRunnerBufferSize;
+uint8_t KernelRunner::kernel_runner_buffer_[];
 
 KernelRunner::KernelRunner(const TfLiteRegistration& registration,
                            TfLiteTensor* tensors, int tensors_size,
                            TfLiteIntArray* inputs, TfLiteIntArray* outputs,
                            void* builtin_data, ErrorReporter* error_reporter)
     : allocator_(SimpleMemoryAllocator::Create(
-          error_reporter, kKernelRunnerBuffer_, kKernelRunnerBufferSize_)),
+          error_reporter, kernel_runner_buffer_, kKernelRunnerBufferSize)),
       registration_(registration),
       tensors_(tensors),
       error_reporter_(error_reporter) {
@@ -52,9 +52,10 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
   node_.builtin_data = builtin_data;
 }
 
-TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data) {
+TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
+                                          size_t length) {
   if (registration_.init) {
-    node_.user_data = registration_.init(&context_, init_data, /*length=*/0);
+    node_.user_data = registration_.init(&context_, init_data, length);
   }
   if (registration_.prepare) {
     TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
   }
@@ -117,11 +118,11 @@ TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
   KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
   TFLITE_DCHECK(runner != nullptr);
 
-  if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
+  if (runner->scratch_buffer_count_ == kNumScratchBuffers) {
     TF_LITE_REPORT_ERROR(
         runner->error_reporter_,
         "Exceeded the maximum number of scratch tensors allowed (%d).",
-        kNumScratchBuffers_);
+        kNumScratchBuffers);
     return kTfLiteError;
   }
 
@@ -142,7 +143,7 @@ void* KernelRunner::GetScratchBuffer(TfLiteContext* context,
                                      int buffer_index) {
   KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
   TFLITE_DCHECK(runner != nullptr);
 
-  TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers_);
+  TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers);
   if (buffer_index >= runner->scratch_buffer_count_) {
     return nullptr;
   }
diff --git a/tensorflow/lite/micro/kernels/kernel_runner.h b/tensorflow/lite/micro/kernels/kernel_runner.h
index 45d107e7a37..d88f00dfbd3 100644
--- a/tensorflow/lite/micro/kernels/kernel_runner.h
+++ b/tensorflow/lite/micro/kernels/kernel_runner.h
@@ -39,7 +39,8 @@ class KernelRunner {
   // Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
   // exceptions will be reported through the error_reporter and returned as a
   // status code here.
-  TfLiteStatus InitAndPrepare(const char* init_data = nullptr);
+  TfLiteStatus InitAndPrepare(const char* init_data = nullptr,
+                              size_t length = 0);
 
   // Calls init, prepare, and invoke on a given TfLiteRegistration pointer.
   // After successful invoke, results will be available in the output tensor as
@@ -60,10 +61,10 @@ class KernelRunner {
                               ...);
 
  private:
-  static constexpr int kNumScratchBuffers_ = 5;
+  static constexpr int kNumScratchBuffers = 12;
 
-  static constexpr int kKernelRunnerBufferSize_ = 10000;
-  static uint8_t kKernelRunnerBuffer_[kKernelRunnerBufferSize_];
+  static constexpr int kKernelRunnerBufferSize = 10000;
+  static uint8_t kernel_runner_buffer_[kKernelRunnerBufferSize];
 
   SimpleMemoryAllocator* allocator_ = nullptr;
   const TfLiteRegistration& registration_;
@@ -74,7 +75,7 @@ class KernelRunner {
   TfLiteNode node_ = {};
 
   int scratch_buffer_count_ = 0;
-  uint8_t* scratch_buffers_[kNumScratchBuffers_];
+  uint8_t* scratch_buffers_[kNumScratchBuffers];
 };
 
 }  // namespace micro
diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h
index f86b28a9ff2..7b76d64a816 100644
--- a/tensorflow/lite/micro/kernels/micro_ops.h
+++ b/tensorflow/lite/micro/kernels/micro_ops.h
@@ -42,6 +42,7 @@ TfLiteRegistration Register_CONCATENATION();
 TfLiteRegistration Register_COS();
 TfLiteRegistration Register_DEPTHWISE_CONV_2D();
 TfLiteRegistration Register_DEQUANTIZE();
+TfLiteRegistration Register_DETECTION_POSTPROCESS();
 TfLiteRegistration Register_EQUAL();
 TfLiteRegistration Register_FLOOR();
 TfLiteRegistration Register_FULLY_CONNECTED();
diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc
index 26575a4d98d..ae3984566a6 100644
--- a/tensorflow/lite/micro/test_helpers.cc
+++ b/tensorflow/lite/micro/test_helpers.cc
@@ -884,7 +884,11 @@ TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims,
   TfLiteTensor result = CreateTensor(dims, is_variable);
   result.type = kTfLiteFloat32;
   result.data.f = const_cast<float*>(data);
-  result.bytes = ElementCount(*dims) * sizeof(float);
+  if (dims == nullptr) {
+    result.bytes = 0;
+  } else {
+    result.bytes = ElementCount(*dims) * sizeof(float);
+  }
   return result;
 }
 
diff --git a/tensorflow/lite/micro/testing/test_utils.cc b/tensorflow/lite/micro/testing/test_utils.cc
index 8c16ed70bbc..1bc2b4427d3 100644
--- a/tensorflow/lite/micro/testing/test_utils.cc
+++ b/tensorflow/lite/micro/testing/test_utils.cc
@@ -36,7 +36,7 @@ constexpr size_t kBufferAlignment = 16;
 // We store the pointer to the ith scratch buffer to implement the Request/Get
 // ScratchBuffer API for the tests. scratch_buffers_[i] will be the ith scratch
 // buffer and will still be allocated from within raw_arena_.
-constexpr int kNumScratchBuffers = 5; +constexpr int kNumScratchBuffers = 12; uint8_t* scratch_buffers_[kNumScratchBuffers]; int scratch_buffer_count_ = 0; diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index c44d52a3cc7..09b90c13db9 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -281,6 +281,8 @@ third_party/gemmlowp/LICENSE \ third_party/flatbuffers/include/flatbuffers/base.h \ third_party/flatbuffers/include/flatbuffers/stl_emulation.h \ third_party/flatbuffers/include/flatbuffers/flatbuffers.h \ +third_party/flatbuffers/include/flatbuffers/flexbuffers.h \ +third_party/flatbuffers/include/flatbuffers/util.h \ third_party/flatbuffers/LICENSE.txt \ third_party/ruy/ruy/profiler/instrumentation.h diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc index 06fdddbbeb9..bbe03eed00b 100644 --- a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc @@ -65,7 +65,8 @@ ifeq ($(TARGET), bluepill) tensorflow/lite/micro/micro_allocator_test.cc \ tensorflow/lite/micro/memory_helpers_test.cc \ tensorflow/lite/micro/memory_arena_threshold_test.cc \ - tensorflow/lite/micro/kernels/circular_buffer_test.cc + tensorflow/lite/micro/kernels/circular_buffer_test.cc \ + tensorflow/lite/micro/kernels/detection_postprocess_test.cc MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) EXCLUDED_EXAMPLE_TESTS := \ diff --git a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc index e9ee7296999..ec5396c1030 100644 --- a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc @@ -71,8 +71,8 @@ ifeq ($(TARGET), stm32f4) tensorflow/lite/micro/recording_micro_allocator_test.cc \ tensorflow/lite/micro/kernels/circular_buffer_test.cc \ tensorflow/lite/micro/kernels/conv_test.cc \ - tensorflow/lite/micro/kernels/fully_connected_test.cc - + tensorflow/lite/micro/kernels/fully_connected_test.cc \ + tensorflow/lite/micro/kernels/detection_postprocess_test.cc MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) EXCLUDED_EXAMPLE_TESTS := \
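
Aside for reviewers — a minimal sketch of how an application might register the new custom op, assuming the templated MicroMutableOpResolver of this codebase. The wrapper function, the resolver capacity of 1, and the converter-emitted custom name "TFLite_Detection_PostProcess" are illustrative assumptions, not part of this change:

// Hypothetical usage sketch; not part of the patch above.
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

inline void RegisterDetectionPostprocess(
    tflite::MicroMutableOpResolver<1>* resolver) {
  // The registration must outlive the resolver, hence static storage here.
  static TfLiteRegistration registration =
      tflite::ops::micro::Register_DETECTION_POSTPROCESS();
  resolver->AddCustom("TFLite_Detection_PostProcess", &registration);
}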