From 0de8f08754f333a0a9b06dc51b8cbbf9e9beb895 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Wed, 25 Mar 2020 14:04:00 +0100 Subject: [PATCH 1/9] TFLu: Port TFL detection postprocess operator Change-Id: I9edffb31a3b1485ebc43e2298762ad067bd26dda --- tensorflow/lite/micro/all_ops_resolver.cc | 3 + tensorflow/lite/micro/kernels/BUILD | 15 + .../micro/kernels/detection_postprocess.cc | 788 ++++++++++++++++++ .../kernels/detection_postprocess_test.cc | 469 +++++++++++ tensorflow/lite/micro/test_helpers.cc | 6 +- 5 files changed, 1280 insertions(+), 1 deletion(-) create mode 100644 tensorflow/lite/micro/kernels/detection_postprocess.cc create mode 100644 tensorflow/lite/micro/kernels/detection_postprocess_test.cc diff --git a/tensorflow/lite/micro/all_ops_resolver.cc b/tensorflow/lite/micro/all_ops_resolver.cc index ff461cb947e..594f69eb5fb 100644 --- a/tensorflow/lite/micro/all_ops_resolver.cc +++ b/tensorflow/lite/micro/all_ops_resolver.cc @@ -18,6 +18,7 @@ namespace tflite { namespace ops { namespace micro { namespace custom { +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); TfLiteRegistration* Register_ETHOSU(); const char* GetString_ETHOSU(); } // namespace custom @@ -81,6 +82,8 @@ AllOpsResolver::AllOpsResolver() { AddUnpack(); // TODO(b/159644355): Figure out if custom Ops belong in AllOpsResolver. 
+ AddCustom("TFLite_Detection_PostProcess", + tflite::ops::micro::custom::Register_DETECTION_POSTPROCESS()); TfLiteRegistration* registration = tflite::ops::micro::custom::Register_ETHOSU(); if (registration) { diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index d88bf91688c..488c6cfe769 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -36,6 +36,7 @@ cc_library( "comparisons.cc", "concatenation.cc", "dequantize.cc", + "detection_postprocess.cc", "elementwise.cc", "ethosu.cc", "floor.cc", @@ -165,6 +166,20 @@ tflite_micro_cc_test( ], ) +tflite_micro_cc_test( + name = "detection_postprocess_test", + srcs = [ + "detection_postprocess_test.cc", + ], + deps = [ + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro/testing:micro_test", + "@flatbuffers", + ], +) + tflite_micro_cc_test( name = "fully_connected_test", srcs = [ diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc new file mode 100644 index 00000000000..5569adce8b3 --- /dev/null +++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -0,0 +1,788 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+#include <vector>
+
+#define FLATBUFFERS_LOCALE_INDEPENDENT 0
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace custom {
+namespace detection_postprocess {
+
+/**
+ * This version of detection_postprocess is specific to TFLite Micro. It
+ * differs from the TFLite version in the following ways:
+ *
+ * 1.) Temporaries (temporary tensors) - Micro uses the scratch buffer API
+ *     instead.
+ * 2.) Output dimensions - the TFLite version determines the output size
+ *     and resizes the output tensor. The Micro runtime does not support
+ *     tensor resizing. However, if the output dimensions are undefined, the
+ *     TFLu memory API is used to allocate the new dimensions.
+ */
+
+
+// Input tensors (indices into the node's input list)
+constexpr int kInputTensorBoxEncodings = 0;
+constexpr int kInputTensorClassPredictions = 1;
+constexpr int kInputTensorAnchors = 2;
+
+// Output tensors (indices into the node's output list)
+constexpr int kOutputTensorDetectionBoxes = 0;
+constexpr int kOutputTensorDetectionClasses = 1;
+constexpr int kOutputTensorDetectionScores = 2;
+constexpr int kOutputTensorNumDetections = 3;
+
+// Number of coordinates describing one box, in either encoding.
+constexpr int kNumCoordBox = 4;
+// Only a batch size of one is supported.
+constexpr int kBatchSize = 1;
+
+// Default for the "detections_per_class" option when it is not supplied.
+constexpr int kNumDetectionsPerClass = 100;
+
+// Object Detection model produces axis-aligned boxes in two formats:
+// BoxCorner represents the lower left corner (xmin, ymin) and
+// the upper right corner (xmax, ymax).
+// CenterSize represents the center (xcenter, ycenter), height and width.
+// BoxCornerEncoding and CenterSizeEncoding are related as follows:
+// ycenter = y / y_scale * anchor.h + anchor.y;
+// xcenter = x / x_scale * anchor.w + anchor.x;
+// half_h = 0.5 * exp(h / h_scale) * anchor.h;
+// half_w = 0.5 * exp(w / w_scale) * anchor.w;
+// ymin = ycenter - half_h
+// ymax = ycenter + half_h
+// xmin = xcenter - half_w
+// xmax = xcenter + half_w
+struct BoxCornerEncoding {
+  float ymin;
+  float xmin;
+  float ymax;
+  float xmax;
+};
+
+struct CenterSizeEncoding {
+  float y;
+  float x;
+  float h;
+  float w;
+};
+// We make sure that the memory allocations are contiguous with static assert.
+static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
+              "Size of BoxCornerEncoding is 4 float values");
+static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
+              "Size of CenterSizeEncoding is 4 float values");
+
+// Per-node state. The option fields are filled in by Init() from the custom
+// options; the scratch buffer indices are filled in by Prepare() and resolved
+// to pointers at the start of every Eval().
+struct OpData {
+  int max_detections;
+  int max_classes_per_detection;  // Fast Non-Max-Suppression
+  int detections_per_class;       // Regular Non-Max-Suppression
+  float non_max_suppression_score_threshold;
+  float intersection_over_union_threshold;
+  int num_classes;
+  bool use_regular_non_max_suppression;
+  CenterSizeEncoding scale_values;
+
+  // Scratch buffers
+  int active_candidate_idx;
+  int decoded_boxes_idx;
+  int scores_idx;
+  uint8_t* active_box_candidate;
+  float* decoded_boxes;
+  float* scores;
+};
+
+// Allocates a persistent TfLiteIntArray describing an output shape and stores
+// it in `*dims`. `x` is always used; `y` and `z` are only part of the shape
+// when they are greater than zero, so the resulting rank is 1, 2 or 3.
+// Returns kTfLiteError if `z` is given without `y` or if allocation fails.
+TfLiteStatus AllocateOutDimensions(TfLiteContext* context, TfLiteIntArray** dims,
+                                   int x, int y = 0, int z = 0) {
+  // A third extent without a second one would leave data[1] unset.
+  TF_LITE_ENSURE(context, y > 0 || z <= 0);
+
+  // TfLiteIntArray::size is the number of dimensions (the rank), not the
+  // number of elements of the tensor.
+  const int rank = 1 + (y > 0 ? 1 : 0) + (z > 0 ? 1 : 0);
+
+  TF_LITE_ENSURE_STATUS(context->AllocatePersistentBuffer(
+      context, TfLiteIntArrayGetSizeInBytes(rank),
+      reinterpret_cast<void**>(dims)));
+
+  (*dims)->size = rank;
+  (*dims)->data[0] = x;
+  if (y > 0) {
+    (*dims)->data[1] = y;
+  }
+  if (z > 0) {
+    (*dims)->data[2] = z;
+  }
+
+  return kTfLiteOk;
+}
+
+// Parses the flexbuffer-encoded custom options in `buffer` into a persistently
+// allocated OpData. "detections_per_class" and "use_regular_nms" are optional
+// and fall back to kNumDetectionsPerClass and false respectively.
+// Returns nullptr if the OpData could not be allocated.
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  void* raw = nullptr;
+  if (context->AllocatePersistentBuffer(context, sizeof(OpData), &raw) !=
+          kTfLiteOk ||
+      raw == nullptr) {
+    return nullptr;
+  }
+  OpData* op_data = reinterpret_cast<OpData*>(raw);
+
+  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+
+  op_data->max_detections = m["max_detections"].AsInt32();
+  op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
+  if (m["detections_per_class"].IsNull())
+    op_data->detections_per_class = kNumDetectionsPerClass;
+  else
+    op_data->detections_per_class = m["detections_per_class"].AsInt32();
+  if (m["use_regular_nms"].IsNull())
+    op_data->use_regular_non_max_suppression = false;
+  else
+    op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();
+
+  op_data->non_max_suppression_score_threshold =
+      m["nms_score_threshold"].AsFloat();
+  op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
+  op_data->num_classes = m["num_classes"].AsInt32();
+  op_data->scale_values.y = m["y_scale"].AsFloat();
+  op_data->scale_values.x = m["x_scale"].AsFloat();
+  op_data->scale_values.h = m["h_scale"].AsFloat();
+  op_data->scale_values.w = m["w_scale"].AsFloat();
+
+  // The scratch pointers are only valid during Eval().
+  op_data->active_box_candidate = nullptr;
+  op_data->decoded_boxes = nullptr;
+  op_data->scores = nullptr;
+
+  return op_data;
+}
+
+// Nothing to release: OpData lives in the persistent arena.
+void Free(TfLiteContext* context, void* buffer) {}
+
+// Validates the number and rank of the inputs/outputs, requests the three
+// scratch buffers used during Eval() and, for every output tensor that has no
+// dimensions yet, allocates them.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  // Init() returns nullptr when OpData could not be allocated.
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
+  auto* op_data = static_cast<OpData*>(node->user_data);
+
+  // Inputs: box_encodings, scores, anchors
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+  const TfLiteTensor* input_box_encodings =
+      GetInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteTensor* input_class_predictions =
+      GetInput(context, node, kInputTensorClassPredictions);
+  const TfLiteTensor* input_anchors =
+      GetInput(context, node, kInputTensorAnchors);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
+
+  // Outputs: detection_boxes, detection_scores, detection_classes,
+  // num_detections
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
+
+  // One byte per box: the "still a candidate" flags used by NMS.
+  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+      context, input_box_encodings->dims->data[1],
+      &op_data->active_candidate_idx));
+  // Four floats per box: the decoded corner coordinates.
+  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+      context,
+      input_box_encodings->dims->data[1] * kNumCoordBox * sizeof(float),
+      &op_data->decoded_boxes_idx));
+  // One float per (box, class): dequantized class predictions.
+  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+      context,
+      input_class_predictions->dims->data[1] *
+          input_class_predictions->dims->data[2] * sizeof(float),
+      &op_data->scores_idx));
+
+  // number of detected boxes
+  const int num_detected_boxes =
+      op_data->max_detections * op_data->max_classes_per_detection;
+
+  // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
+  TfLiteTensor* detection_boxes =
+      GetOutput(context, node, kOutputTensorDetectionBoxes);
+  if (detection_boxes->dims == nullptr) {
+    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
+        context, &detection_boxes->dims, 1, num_detected_boxes, 4));
+  }
+
+  // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
+  TfLiteTensor* detection_classes =
+      GetOutput(context, node, kOutputTensorDetectionClasses);
+  if (detection_classes->dims == nullptr) {
+    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
+        context, &detection_classes->dims, 1, num_detected_boxes));
+  }
+
+  // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
+  TfLiteTensor* detection_scores =
+      GetOutput(context, node, kOutputTensorDetectionScores);
+  if (detection_scores->dims == nullptr) {
+    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
+        context, &detection_scores->dims, 1, num_detected_boxes));
+  }
+
+  // Output Tensor num_detections: size is set to 1
+  TfLiteTensor* num_detections =
+      GetOutput(context, node, kOutputTensorNumDetections);
+  if (num_detections->dims == nullptr) {
+    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
+        context, &num_detections->dims, 1));
+  }
+
+  return kTfLiteOk;
+}
+
+// Converts a quantized uint8 value to float: (x - zero_point) * scale.
+class Dequantizer {
+ public:
+  Dequantizer(int zero_point, float scale)
+      : zero_point_(zero_point), scale_(scale) {}
+  float operator()(uint8 x) {
+    return (static_cast<float>(x) - zero_point_) * scale_;
+  }
+
+ private:
+  int zero_point_;
+  float scale_;
+};
+
+// Dequantizes the first four values of entry `idx` of a uint8 tensor whose
+// entries are `length_box_encoding` values long, and stores them in
+// `box_centersize` in (y, x, h, w) order.
+void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
+                            float quant_zero_point, float quant_scale,
+                            int length_box_encoding,
+                            CenterSizeEncoding* box_centersize) {
+  const uint8* boxes =
+      GetTensorData<uint8>(input_box_encodings) + length_box_encoding * idx;
+  Dequantizer dequantize(quant_zero_point, quant_scale);
+  // See definition of the KeyPointBoxCoder at
+  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
+  // The first four elements are the box coordinates, which is the same as the
+  // FastRnnBoxCoder at
+  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
+  box_centersize->y = dequantize(boxes[0]);
+  box_centersize->x = dequantize(boxes[1]);
+  box_centersize->h = dequantize(boxes[2]);
+  box_centersize->w = dequantize(boxes[3]);
+}
+
+// Returns the float data of `tensor` reinterpreted as the pointer type T.
+template <class T>
+T ReInterpretTensor(const TfLiteTensor* tensor) {
+  // TODO (chowdhery): check float
+  const float* tensor_base = GetTensorData<float>(tensor);
+  return reinterpret_cast<T>(tensor_base);
+}
+
+// Mutable overload of the function above.
+template <class T>
+T ReInterpretTensor(TfLiteTensor* tensor) {
+  // TODO (chowdhery): check float
+  float* tensor_base = GetTensorData<float>(tensor);
+  return reinterpret_cast<T>(tensor_base);
+}
+
+// Decodes every box encoding against its anchor and writes the result, in
+// corner form, to the op_data->decoded_boxes scratch buffer. Supports uint8
+// and float32 box encodings; any other type yields kTfLiteError.
+TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
+                                   OpData* op_data) {
+  // Parse input tensor boxencodings
+  const TfLiteTensor* input_box_encodings =
+      GetInput(context, node, kInputTensorBoxEncodings);
+  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
+  const int num_boxes = input_box_encodings->dims->data[1];
+  TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
+  const TfLiteTensor* input_anchors =
+      GetInput(context, node, kInputTensorAnchors);
+
+  // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
+  CenterSizeEncoding box_centersize;
+  CenterSizeEncoding scale_values = op_data->scale_values;
+  CenterSizeEncoding anchor;
+  for (int idx = 0; idx < num_boxes; ++idx) {
+    switch (input_box_encodings->type) {
+      // Quantized
+      case kTfLiteUInt8:
+        DequantizeBoxEncodings(
+            input_box_encodings, idx,
+            static_cast<float>(input_box_encodings->params.zero_point),
+            static_cast<float>(input_box_encodings->params.scale),
+            input_box_encodings->dims->data[2], &box_centersize);
+        DequantizeBoxEncodings(
+            input_anchors, idx,
+            static_cast<float>(input_anchors->params.zero_point),
+            static_cast<float>(input_anchors->params.scale), kNumCoordBox,
+            &anchor);
+        break;
+      // Float
+      case kTfLiteFloat32: {
+        // Please see DequantizeBoxEncodings function for the support detail.
+        const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
+        const float* boxes =
+            &(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
+        box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
+        anchor =
+            ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
+        break;
+      }
+      default:
+        // Unsupported type.
+        return kTfLiteError;
+    }
+
+    float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y;
+    float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x;
+    float half_h =
+        0.5f * static_cast<float>(std::exp(box_centersize.h / scale_values.h)) *
+        anchor.h;
+    float half_w =
+        0.5f * static_cast<float>(std::exp(box_centersize.w / scale_values.w)) *
+        anchor.w;
+
+    auto& box =
+        reinterpret_cast<BoxCornerEncoding*>(op_data->decoded_boxes)[idx];
+    box.ymin = ycenter - half_h;
+    box.xmin = xcenter - half_w;
+    box.ymax = ycenter + half_h;
+    box.xmax = xcenter + half_w;
+  }
+  return kTfLiteOk;
+}
+
+// Fills `indices` with 0..num_values-1 and partially sorts it so that the
+// first `num_to_sort` entries index the largest `values`, in decreasing order.
+void DecreasingPartialArgSort(const float* values, int num_values,
+                              int num_to_sort, int* indices) {
+  std::iota(indices, indices + num_values, 0);
+  std::partial_sort(
+      indices, indices + num_to_sort, indices + num_values,
+      [&values](const int i, const int j) { return values[i] > values[j]; });
+}
+
+// Appends every value that is >= `threshold`, and its index in `values`, to
+// `keep_values` and `keep_indices`.
+void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
+                                         const float threshold,
+                                         std::vector<float>* keep_values,
+                                         std::vector<int>* keep_indices) {
+  const int num_values = static_cast<int>(values.size());
+  for (int i = 0; i < num_values; i++) {
+    if (values[i] >= threshold) {
+      keep_values->emplace_back(values[i]);
+      keep_indices->emplace_back(i);
+    }
+  }
+}
+
+// Returns true only if every decoded box has strictly positive height and
+// width.
+bool ValidateBoxes(const float* decoded_boxes, const int num_boxes) {
+  for (int i = 0; i < num_boxes; ++i) {
+    // Require ymax > ymin and xmax > xmin; zero-area boxes are rejected too.
+    auto& box = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
+    if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Returns the intersection-over-union of decoded boxes `i` and `j`, or 0 if
+// either box has no area.
+float ComputeIntersectionOverUnion(const float* decoded_boxes,
+                                   const int i, const int j) {
+  auto& box_i = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
+  auto& box_j = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[j];
+  const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
+  const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
+  if (area_i <= 0 || area_j <= 0) return 0.0f;
+  const float intersection_ymin = std::max(box_i.ymin, box_j.ymin);
+  const float intersection_xmin = std::max(box_i.xmin, box_j.xmin);
+  const float intersection_ymax = std::min(box_i.ymax, box_j.ymax);
+  const float intersection_xmax = std::min(box_i.xmax, box_j.xmax);
+  const float intersection_area =
+      std::max(intersection_ymax - intersection_ymin, 0.0f) *
+      std::max(intersection_xmax - intersection_xmin, 0.0f);
+  return intersection_area / (area_i + area_j - intersection_area);
+}
+
+// NonMaxSuppressionSingleClass() prunes out the box locations with high overlap
+// before selecting the highest scoring boxes (max_detections in number)
+// It assumes all boxes are good in beginning and sorts based on the scores.
+// If lower-scoring box has too much overlap with a higher-scoring box,
+// we get rid of the lower-scoring box.
+// Complexity is O(N^2) pairwise comparison between boxes
+//
+// `scores` holds one score per box; the indices of the surviving boxes are
+// returned in `selected`, highest score first.
+TfLiteStatus NonMaxSuppressionSingleClassHelper(
+    TfLiteContext* context, TfLiteNode* node, OpData* op_data,
+    const std::vector<float>& scores, std::vector<int>* selected,
+    int max_detections) {
+
+  const TfLiteTensor* input_box_encodings =
+      GetInput(context, node, kInputTensorBoxEncodings);
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const float non_max_suppression_score_threshold =
+      op_data->non_max_suppression_score_threshold;
+  const float intersection_over_union_threshold =
+      op_data->intersection_over_union_threshold;
+  // Maximum detections must not be negative.
+  TF_LITE_ENSURE(context, (max_detections >= 0));
+  // intersection_over_union_threshold must lie in (0, 1].
+  TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
+                              (intersection_over_union_threshold <= 1.0f));
+  // Validate boxes
+  TF_LITE_ENSURE(context, ValidateBoxes(op_data->decoded_boxes, num_boxes));
+
+  // threshold scores
+  std::vector<int> keep_indices;
+  // TODO (chowdhery): Remove the dynamic allocation and replace it
+  // with temporaries, esp for std::vector
+  std::vector<float> keep_scores;
+  SelectDetectionsAboveScoreThreshold(
+      scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices);
+
+  const int num_scores_kept = static_cast<int>(keep_scores.size());
+  std::vector<int> sorted_indices;
+  sorted_indices.resize(num_scores_kept);
+  DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept,
+                           sorted_indices.data());
+  const int num_boxes_kept = num_scores_kept;
+  const int output_size = std::min(num_boxes_kept, max_detections);
+  selected->clear();
+
+  // active_box_candidate[i] == 1 while the i-th best box is neither selected
+  // nor suppressed.
+  int num_active_candidate = num_boxes_kept;
+  uint8_t* active_box_candidate = op_data->active_box_candidate;
+  for (int row = 0; row < num_boxes_kept; row++) {
+    active_box_candidate[row] = 1;
+  }
+
+  for (int i = 0; i < num_boxes_kept; ++i) {
+    if (num_active_candidate == 0 ||
+        static_cast<int>(selected->size()) >= output_size)
+      break;
+    if (active_box_candidate[i] == 1) {
+      selected->push_back(keep_indices[sorted_indices[i]]);
+      active_box_candidate[i] = 0;
+      num_active_candidate--;
+    } else {
+      continue;
+    }
+    // Suppress every lower-scoring candidate that overlaps box i too much.
+    for (int j = i + 1; j < num_boxes_kept; ++j) {
+      if (active_box_candidate[j] == 1) {
+        float intersection_over_union = ComputeIntersectionOverUnion(
+            op_data->decoded_boxes, keep_indices[sorted_indices[i]],
+            keep_indices[sorted_indices[j]]);
+
+        if (intersection_over_union > intersection_over_union_threshold) {
+          active_box_candidate[j] = 0;
+          num_active_candidate--;
+        }
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
+// This function implements a regular version of Non Maximal Suppression (NMS)
+// for multiple classes where
+// 1) we do NMS separately for each class across all anchors and
+// 2) keep only the highest anchor scores across all classes
+// 3) The worst runtime of the regular NMS is O(K*N^2)
+// where N is the number of anchors and K the number of
+// classes.
+//
+// `scores` holds num_boxes rows of num_classes_with_background floats. The
+// first max_detections entries of the output tensors are written; entries
+// beyond the number of detections found are zero-filled.
+TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
+                                                      TfLiteNode* node,
+                                                      OpData* op_data,
+                                                      const float* scores) {
+  const TfLiteTensor* input_box_encodings =
+      GetInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteTensor* input_class_predictions =
+      GetInput(context, node, kInputTensorClassPredictions);
+  TfLiteTensor* detection_boxes =
+      GetOutput(context, node, kOutputTensorDetectionBoxes);
+  TfLiteTensor* detection_classes =
+      GetOutput(context, node, kOutputTensorDetectionClasses);
+  TfLiteTensor* detection_scores =
+      GetOutput(context, node, kOutputTensorDetectionScores);
+  TfLiteTensor* num_detections =
+      GetOutput(context, node, kOutputTensorNumDetections);
+
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const int num_classes = op_data->num_classes;
+  const int num_detections_per_class = op_data->detections_per_class;
+  const int max_detections = op_data->max_detections;
+  const int num_classes_with_background =
+      input_class_predictions->dims->data[2];
+  // The row index offset is 1 if background class is included and 0 otherwise.
+  int label_offset = num_classes_with_background - num_classes;
+  TF_LITE_ENSURE(context, num_detections_per_class > 0);
+
+  // For each class, perform non-max suppression.
+  std::vector<float> class_scores(num_boxes);
+
+  // Running top-`max_detections` list across the classes processed so far,
+  // with room for one more class worth of candidates appended behind it.
+  std::vector<int> box_indices_after_regular_non_max_suppression(
+      num_boxes + max_detections);
+  std::vector<float> scores_after_regular_non_max_suppression(num_boxes +
+                                                              max_detections);
+
+  int size_of_sorted_indices = 0;
+  std::vector<int> sorted_indices;
+  sorted_indices.resize(num_boxes + max_detections);
+  std::vector<float> sorted_values;
+  sorted_values.resize(max_detections);
+
+  for (int col = 0; col < num_classes; col++) {
+    for (int row = 0; row < num_boxes; row++) {
+      // Get scores of boxes corresponding to all anchors for single class
+      class_scores[row] =
+          *(scores + row * num_classes_with_background + col + label_offset);
+    }
+    // Perform non-maximal suppression on single class
+    std::vector<int> selected;
+    TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
+        context, node, op_data, class_scores, &selected,
+        num_detections_per_class));
+    // Add selected indices from non-max suppression of boxes in this class
+    int output_index = size_of_sorted_indices;
+    for (const auto& selected_index : selected) {
+      // Store the flat (box, class) index so both can be recovered later.
+      box_indices_after_regular_non_max_suppression[output_index] =
+          (selected_index * num_classes_with_background + col + label_offset);
+      scores_after_regular_non_max_suppression[output_index] =
+          class_scores[selected_index];
+      output_index++;
+    }
+    // Sort the max scores among the selected indices
+    // Get the indices for top scores
+    int num_indices_to_sort = std::min(output_index, max_detections);
+    DecreasingPartialArgSort(scores_after_regular_non_max_suppression.data(),
+                             output_index, num_indices_to_sort,
+                             sorted_indices.data());
+
+    // Copy values to temporary vectors
+    for (int row = 0; row < num_indices_to_sort; row++) {
+      int temp = sorted_indices[row];
+      sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
+      sorted_values[row] = scores_after_regular_non_max_suppression[temp];
+    }
+    // Copy scores and indices from temporary vectors
+    for (int row = 0; row < num_indices_to_sort; row++) {
+      box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
+      scores_after_regular_non_max_suppression[row] = sorted_values[row];
+    }
+    size_of_sorted_indices = num_indices_to_sort;
+  }
+
+  // Fill the output tensors
+  for (int output_box_index = 0; output_box_index < max_detections;
+       output_box_index++) {
+    if (output_box_index < size_of_sorted_indices) {
+      // Both operands are non-negative ints, so integer division already
+      // rounds down.
+      const int anchor_index =
+          box_indices_after_regular_non_max_suppression[output_box_index] /
+          num_classes_with_background;
+      const int class_index =
+          box_indices_after_regular_non_max_suppression[output_box_index] -
+          anchor_index * num_classes_with_background - label_offset;
+      const float selected_score =
+          scores_after_regular_non_max_suppression[output_box_index];
+      // detection_boxes
+      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
+          reinterpret_cast<BoxCornerEncoding*>(
+              op_data->decoded_boxes)[anchor_index];
+      // detection_classes
+      GetTensorData<float>(detection_classes)[output_box_index] = class_index;
+      // detection_scores
+      GetTensorData<float>(detection_scores)[output_box_index] = selected_score;
+    } else {
+      ReInterpretTensor<BoxCornerEncoding*>(
+          detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
+      // detection_classes
+      GetTensorData<float>(detection_classes)[output_box_index] = 0.0f;
+      // detection_scores
+      GetTensorData<float>(detection_scores)[output_box_index] = 0.0f;
+    }
+  }
+  GetTensorData<float>(num_detections)[0] = size_of_sorted_indices;
+  return kTfLiteOk;
+}
+
+// This function implements a fast version of Non Maximal Suppression for
+// multiple classes where
+// 1) we keep the top-k scores for each anchor and
+// 2) during NMS, each anchor only uses the highest class score for sorting.
+// 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
+// instead of O(KN^2) where N is the number of anchors and K the number of
+// classes.
+TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
+                                                   TfLiteNode* node,
+                                                   OpData* op_data,
+                                                   const float* scores) {
+  const TfLiteTensor* input_box_encodings =
+      GetInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteTensor* input_class_predictions =
+      GetInput(context, node, kInputTensorClassPredictions);
+  TfLiteTensor* detection_boxes =
+      GetOutput(context, node, kOutputTensorDetectionBoxes);
+
+  TfLiteTensor* detection_classes =
+      GetOutput(context, node, kOutputTensorDetectionClasses);
+  TfLiteTensor* detection_scores =
+      GetOutput(context, node, kOutputTensorDetectionScores);
+  TfLiteTensor* num_detections =
+      GetOutput(context, node, kOutputTensorNumDetections);
+
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const int num_classes = op_data->num_classes;
+  const int max_categories_per_anchor = op_data->max_classes_per_detection;
+  const int num_classes_with_background =
+      input_class_predictions->dims->data[2];
+
+  // The row index offset is 1 if background class is included and 0 otherwise.
+  int label_offset = num_classes_with_background - num_classes;
+  TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
+  const int num_categories_per_anchor =
+      std::min(max_categories_per_anchor, num_classes);
+  std::vector<float> max_scores;
+  max_scores.resize(num_boxes);
+  std::vector<int> sorted_class_indices;
+  sorted_class_indices.resize(num_boxes * num_classes);
+  for (int row = 0; row < num_boxes; row++) {
+    const float* box_scores =
+        scores + row * num_classes_with_background + label_offset;
+    int* class_indices = sorted_class_indices.data() + row * num_classes;
+    DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
+                             class_indices);
+    max_scores[row] = box_scores[class_indices[0]];
+  }
+
+  // Perform non-maximal suppression on max scores
+  std::vector<int> selected;
+  TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
+      context, node, op_data, max_scores, &selected, op_data->max_detections));
+
+  // Fill the output tensors: every selected box contributes
+  // num_categories_per_anchor consecutive entries, one per top class.
+  int output_box_index = 0;
+
+  for (const auto& selected_index : selected) {
+    const float* box_scores =
+        scores + selected_index * num_classes_with_background + label_offset;
+    const int* class_indices =
+        sorted_class_indices.data() + selected_index * num_classes;
+
+    for (int col = 0; col < num_categories_per_anchor; ++col) {
+      int box_offset = num_categories_per_anchor * output_box_index + col;
+
+      // detection_boxes
+      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
+          reinterpret_cast<BoxCornerEncoding*>(
+              op_data->decoded_boxes)[selected_index];
+
+      // detection_classes
+      GetTensorData<float>(detection_classes)[box_offset] = class_indices[col];
+
+      // detection_scores
+      GetTensorData<float>(detection_scores)[box_offset] =
+          box_scores[class_indices[col]];
+    }
+    // Advance once per selected box. Advancing inside the class loop would
+    // also scale the stride above and skip (and eventually overrun) output
+    // slots whenever more than one class is kept per box.
+    output_box_index++;
+  }
+
+  GetTensorData<float>(num_detections)[0] = output_box_index;
+  return kTfLiteOk;
+}
+
+// Dequantizes all num_boxes * num_classes_with_background uint8 class
+// predictions into `scores`, using the tensor's zero point and scale.
+void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
+                                const int num_boxes,
+                                const int num_classes_with_background,
+                                float* scores) {
+  float quant_zero_point =
+      static_cast<float>(input_class_predictions->params.zero_point);
+  float quant_scale = static_cast<float>(input_class_predictions->params.scale);
+  Dequantizer dequantize(quant_zero_point, quant_scale);
+  const uint8* scores_quant = GetTensorData<uint8>(input_class_predictions);
+  for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
+    scores[idx] = dequantize(scores_quant[idx]);
+  }
+}
+
+// Validates the class prediction tensor against the box tensor and the
+// configured class count, obtains float scores (dequantizing uint8 input into
+// the scores scratch buffer) and runs the regular or the fast multi-class NMS
+// depending on the "use_regular_nms" option.
+TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
+                                         TfLiteNode* node, OpData* op_data) {
+  // Get the input tensors
+  const TfLiteTensor* input_box_encodings =
+      GetInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteTensor* input_class_predictions =
+      GetInput(context, node, kInputTensorClassPredictions);
+  const int num_boxes = input_box_encodings->dims->data[1];
+  const int num_classes = op_data->num_classes;
+
+  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
+                    kBatchSize);
+  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
+  const int num_classes_with_background =
+      input_class_predictions->dims->data[2];
+
+  // At most one extra (background) class is allowed.
+  TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
+  TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));
+
+  const float* scores;
+  switch (input_class_predictions->type) {
+    case kTfLiteUInt8: {
+      float* temporary_scores = op_data->scores;
+      DequantizeClassPredictions(input_class_predictions, num_boxes,
+                                 num_classes_with_background, temporary_scores);
+      scores = temporary_scores;
+    } break;
+    case kTfLiteFloat32:
+      scores = GetTensorData<float>(input_class_predictions);
+      break;
+    default:
+      // Unsupported type.
+      return kTfLiteError;
+  }
+
+  if (op_data->use_regular_non_max_suppression) {
+    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
+        context, node, op_data, scores));
+  } else {
+    TF_LITE_ENSURE_STATUS(
+        NonMaxSuppressionMultiClassFastHelper(context, node, op_data, scores));
+  }
+
+  return kTfLiteOk;
+}
+
+// Resolves the scratch buffers requested in Prepare(), decodes the boxes and
+// runs multi-class non-max suppression to fill the four output tensors.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  // TODO(chowdhery): Generalize for any batch size
+  TF_LITE_ENSURE(context, (kBatchSize == 1));
+
+  // Set up scratch buffers
+  auto* op_data = static_cast<OpData*>(node->user_data);
+  op_data->active_box_candidate = reinterpret_cast<uint8_t*>(
+      context->GetScratchBuffer(context, op_data->active_candidate_idx));
+  op_data->decoded_boxes = reinterpret_cast<float*>(
+      context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
+  op_data->scores = reinterpret_cast<float*>(
+      context->GetScratchBuffer(context, op_data->scores_idx));
+
+  // These two functions correspond to two blocks in the Object Detection model.
+  // In future, we would like to break the custom op in two blocks, which is
+  // currently not feasible because we would like to input quantized inputs
+  // and do all calculations in float. Mixed quantized/float calculations are
+  // currently not supported in TFLite.
+
+  // This fills in temporary decoded_boxes
+  // by transforming input_box_encodings and input_anchors from
+  // CenterSizeEncodings to BoxCornerEncoding
+  TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));
+
+  // This fills in the output tensors
+  // by choosing effective set of decoded boxes
+  // based on Non Maximal Suppression, i.e. selecting
+  // highest scoring non-overlapping boxes.
+  TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));
+
+  // TODO(chowdhery): Generalize for any batch size
+
+  return kTfLiteOk;
+}
+
+}  // namespace detection_postprocess
+
+// Returns the registration (init, free, prepare, invoke) for the
+// detection postprocess custom op.
+TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
+  static TfLiteRegistration r = {
+      detection_postprocess::Init, detection_postprocess::Free,
+      detection_postprocess::Prepare, detection_postprocess::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/detection_postprocess_test.cc b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc
new file mode 100644
index 00000000000..d4fa5af0dca
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc
@@ -0,0 +1,469 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/all_ops_resolver.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+#include "tensorflow/lite/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// Common inputs and outputs.
+ +static const int kInputShape1[] = {3, 1, 6, 4}; +static const int kInputShape2[] = {3, 1, 6, 3}; +static const int kInputShape3[] = {2, 6, 4}; +static const int kOutputShape1[] = {3, 1, 3, 4}; +static const int kOutputShape2[] = {2, 1, 3}; +static const int kOutputShape3[] = {2, 1, 3}; +static const int kOutputShape4[] = {1, 1}; + +// six boxes in center-size encoding +static const float kInputData1[] = { + 0.0, 0.0, 0.0, 0.0, // box #1 + 0.0, 1.0, 0.0, 0.0, // box #2 + 0.0, -1.0, 0.0, 0.0, // box #3 + 0.0, 0.0, 0.0, 0.0, // box #4 + 0.0, 1.0, 0.0, 0.0, // box #5 + 0.0, 0.0, 0.0, 0.0 // box #6 +}; + +// class scores - two classes with background +static const float kInputData2[] = {0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., + .5, .4, 0., .3, .2}; + +// six anchors in center-size encoding +static const float kInputData3[] = { + 0.5, 0.5, 1.0, 1.0, // anchor #1 + 0.5, 0.5, 1.0, 1.0, // anchor #2 + 0.5, 0.5, 1.0, 1.0, // anchor #3 + 0.5, 10.5, 1.0, 1.0, // anchor #4 + 0.5, 10.5, 1.0, 1.0, // anchor #5 + 0.5, 100.5, 1.0, 1.0 // anchor #6 +}; +// Same boxes in box-corner encoding: +// { 0.0, 0.0, 1.0, 1.0, +// 0.0, 0.1, 1.0, 1.1, +// 0.0, -0.1, 1.0, 0.9, +// 0.0, 10.0, 1.0, 11.0, +// 0.0, 10.1, 1.0, 11.1, +// 0.0, 100.0, 1.0, 101.0} + +static const float kGolden1[] = {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}; +static const float kGolden2[] = {1, 0, 0}; +static const float kGolden3[] = {0.95, 0.9, 0.3}; +static const float kGolden4[] = {3.0}; + + +void TestDetectionPostprocess(const int* input_dims_data1, const float* input_data1, + const int* input_dims_data2, const float* input_data2, + const int* input_dims_data3, const float* input_data3, + const int* output_dims_data1, float* output_data1, + const int* output_dims_data2, float* output_data2, + const int* output_dims_data3, float* output_data3, + const int* output_dims_data4, float* output_data4, + const float* golden1, const float* golden2, + const float* golden3, const 
float* golden4, + const float tolerance, bool use_regular_nms, //TODO 4 tolerance + uint8_t* input_data_quantized1=nullptr, + uint8_t* input_data_quantized2=nullptr, + uint8_t* input_data_quantized3=nullptr, + const float input_min1=0, const float input_max1=0, + const float input_min2=0, const float input_max2=0, + const float input_min3=0, const float input_max3=0) { + TfLiteIntArray* input_dims1 = IntArrayFromInts(input_dims_data1); + TfLiteIntArray* input_dims2 = IntArrayFromInts(input_dims_data2); + TfLiteIntArray* input_dims3 = IntArrayFromInts(input_dims_data3); + TfLiteIntArray* output_dims1 = IntArrayFromInts(output_dims_data1); + TfLiteIntArray* output_dims2 = IntArrayFromInts(output_dims_data2); + TfLiteIntArray* output_dims3 = IntArrayFromInts(output_dims_data3); + TfLiteIntArray* output_dims4 = IntArrayFromInts(output_dims_data4); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 4; + constexpr int tensors_size = inputs_size + outputs_size; + + TfLiteTensor tensors[tensors_size]; + if (input_min1 != 0 || input_max1 != 0 || + input_min2 != 0 || input_max2 != 0 || + input_min3 != 0 || input_max3 != 0) { + const float input_scale1 = ScaleFromMinMax(input_min1, input_max1); + const int input_zero_point1 = ZeroPointFromMinMax(input_min1, input_max1); + const float input_scale2 = ScaleFromMinMax(input_min2, input_max2); + const int input_zero_point2 = ZeroPointFromMinMax(input_min2, input_max2); + const float input_scale3 = ScaleFromMinMax(input_min3, input_max3); + const int input_zero_point3 = ZeroPointFromMinMax(input_min3, input_max3); + + tensors[0] = CreateQuantizedTensor(input_data1, input_data_quantized1,input_dims1, + input_scale1, input_zero_point1); + tensors[1] = CreateQuantizedTensor(input_data2, input_data_quantized2, input_dims2, + input_scale1, input_zero_point2); + tensors[2] = CreateQuantizedTensor(input_data3, input_data_quantized3, input_dims3, + input_scale3, input_zero_point3); + } + else { + tensors[0] = 
CreateFloatTensor(input_data1, input_dims1); + tensors[1] = CreateFloatTensor(input_data2, input_dims2); + tensors[2] = CreateFloatTensor(input_data3, input_dims3); + } + tensors[3] = CreateFloatTensor(output_data1, output_dims1); + tensors[4] = CreateFloatTensor(output_data2, output_dims2); + tensors[5] = CreateFloatTensor(output_data3, output_dims3); + tensors[6] = CreateFloatTensor(output_data4, output_dims4); + + TfLiteContext context; + PopulateContext(tensors, tensors_size, micro_test::reporter, &context); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("max_detections", 3); + fbb.Int("max_classes_per_detection", 1); + fbb.Int("detections_per_class", 1); + fbb.Bool("use_regular_nms", use_regular_nms); + fbb.Float("nms_score_threshold", 0.0); + fbb.Float("nms_iou_threshold", 0.5); + fbb.Int("num_classes", 2); + fbb.Float("y_scale", 10.0); + fbb.Float("x_scale", 10.0); + fbb.Float("h_scale", 5.0); + fbb.Float("w_scale", 5.0); + }); + fbb.Finish(); + + ::tflite::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp("TFLite_Detection_PostProcess"); + + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, (const char*)fbb.GetBuffer().data(), fbb.GetBuffer().size()); + } + + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {4, 3, 4, 5, 6}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = nullptr; + node.user_data = user_data; + node.builtin_data = nullptr; + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, 
registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + + // Output dimensions should not be undefined after Prepare + TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[3].dims); + TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[4].dims); + TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[5].dims); + TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[6].dims); + + const int output_elements_count1 = tensors[3].dims->size; + const int output_elements_count2 = tensors[4].dims->size; + const int output_elements_count3 = tensors[5].dims->size; + const int output_elements_count4 = tensors[6].dims->size; + + if (registration->free) { + registration->free(&context, user_data); + } + + for (int i = 0; i < output_elements_count1; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(golden1[i], output_data1[i], tolerance); + } + for (int i = 0; i < output_elements_count2; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(golden2[i], output_data2[i], tolerance); + } + for (int i = 0; i < output_elements_count3; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(golden3[i], output_data3[i], tolerance); + } + for (int i = 0; i < output_elements_count4; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(golden4[i], output_data4[i], tolerance); + } +} +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMS) { + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + + tflite::testing::TestDetectionPostprocess(tflite::testing::kInputShape1, tflite::testing::kInputData1, + tflite::testing::kInputShape2, tflite::testing::kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + tflite::testing::kOutputShape1, output_data1, + tflite::testing::kOutputShape2, output_data2, + tflite::testing::kOutputShape3, output_data3, + tflite::testing::kOutputShape4, output_data4, + tflite::testing::kGolden1, tflite::testing::kGolden2, + tflite::testing::kGolden3, tflite::testing::kGolden4, + /* 
tolerance */ 0, /* Use regular NMS: */ false); +} + +TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedFastNMS) { + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + const int kInputElements1 = tflite::testing::kInputShape1[1] * tflite::testing::kInputShape1[2] * + tflite::testing::kInputShape1[3]; + const int kInputElements2 = tflite::testing::kInputShape2[1] * tflite::testing::kInputShape2[2] * + tflite::testing::kInputShape2[3]; + const int kInputElements3 = tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2]; + + uint8_t input_data_quantized1[kInputElements1+10]; + uint8_t input_data_quantized2[kInputElements1+10]; + uint8_t input_data_quantized3[kInputElements1+10]; + + tflite::testing::TestDetectionPostprocess(tflite::testing::kInputShape1, tflite::testing::kInputData1, + tflite::testing::kInputShape2, tflite::testing::kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + tflite::testing::kOutputShape1, output_data1, + tflite::testing::kOutputShape2, output_data2, + tflite::testing::kOutputShape3, output_data3, + tflite::testing::kOutputShape4, output_data4, + tflite::testing::kGolden1, tflite::testing::kGolden2, + tflite::testing::kGolden3, tflite::testing::kGolden4, + /* tolerance */ 3e-1, /* Use regular NMS: */ false, + input_data_quantized1, input_data_quantized2, input_data_quantized3, + /* input1 min/max*/-1.0, 1.0, /* input2 min/max */ 0.0, 1.0, + /* input3 min/max */ 0.0, 100.5); +} + +TF_LITE_MICRO_TEST(DetectionPostprocessFloatRegularNMS) { + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + const float kGolden1[] = {0.0, 10.0, 1.0, 11.0, 0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 0.0, 0.0}; + const float kGolden3[] = {0.95, 0.9, 0.0}; + const float kGolden4[] = {2.0}; + + tflite::testing::TestDetectionPostprocess(tflite::testing::kInputShape1, tflite::testing::kInputData1, + tflite::testing::kInputShape2, 
tflite::testing::kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + tflite::testing::kOutputShape1, output_data1, + tflite::testing::kOutputShape2, output_data2, + tflite::testing::kOutputShape3, output_data3, + tflite::testing::kOutputShape4, output_data4, + kGolden1, tflite::testing::kGolden2, kGolden3, kGolden4, + /* tolerance */ 1e-1, /* Use regular NMS: */ true); +} + +TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedRegularNMS) { + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + const int kInputElements1 = tflite::testing::kInputShape1[1] * tflite::testing::kInputShape1[2] * + tflite::testing::kInputShape1[3]; + const int kInputElements2 = tflite::testing::kInputShape2[1] * tflite::testing::kInputShape2[2] * + tflite::testing::kInputShape2[3]; + const int kInputElements3 = tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2]; + + uint8_t input_data_quantized1[kInputElements1+10]; + uint8_t input_data_quantized2[kInputElements1+10]; + uint8_t input_data_quantized3[kInputElements1+10]; + + const float kGolden1[] = {0.0, 10.0, 1.0, 11.0, 0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 0.0, 0.0}; + const float kGolden3[] = {0.95, 0.9, 0.0}; + const float kGolden4[] = {2.0}; + + tflite::testing::TestDetectionPostprocess(tflite::testing::kInputShape1, tflite::testing::kInputData1, + tflite::testing::kInputShape2, tflite::testing::kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + tflite::testing::kOutputShape1, output_data1, + tflite::testing::kOutputShape2, output_data2, + tflite::testing::kOutputShape3, output_data3, + tflite::testing::kOutputShape4, output_data4, + kGolden1, tflite::testing::kGolden2, + kGolden3, kGolden4, + /* tolerance */ 3e-1, /* Use regular NMS: */ true, + input_data_quantized1, input_data_quantized2, input_data_quantized3, + /* input1 min/max*/-1.0, 1.0, /* input2 min/max */ 0.0, 1.0, + /* input3 min/max */ 0.0, 100.5); +} + 
+TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMSwithNoBackgroundClassAndKeypoints) { + const int kInputShape1[] = {3, 1, 6, 5}; + const int kInputShape2[] = {3, 1, 6, 2}; + + // six boxes in center-size encoding + const float kInputData1[] = { + 0.0, 0.0, 0.0, 0.0, 1.0, // box #1 + 0.0, 1.0, 0.0, 0.0, 1.0, // box #2 + 0.0, -1.0, 0.0, 0.0, 1.0, // box #3 + 0.0, 0.0, 0.0, 0.0, 1.0, // box #4 + 0.0, 1.0, 0.0, 0.0, 1.0, // box #5 + 0.0, 0.0, 0.0, 0.0, 1.0, // box #6 + }; + + // class scores - two classes without background + const float kInputData2[] = {.9, .8, .75, .72, .6, .5, .93, .95, .5, .4, .3, .2}; + + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + + tflite::testing::TestDetectionPostprocess(kInputShape1, kInputData1, kInputShape2, kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + tflite::testing::kOutputShape1, output_data1, + tflite::testing::kOutputShape2, output_data2, + tflite::testing::kOutputShape3, output_data3, + tflite::testing::kOutputShape4, output_data4, + tflite::testing::kGolden1, tflite::testing::kGolden2, + tflite::testing::kGolden3, tflite::testing::kGolden4, + /* tolerance */ 0, /* Use regular NMS: */ false); +} + +TF_LITE_MICRO_TEST(DetectionPostprocessFloatRegularNMSwithNoBackgroundClassAndKeypoints) { + const int kInputShape2[] = {3, 1, 6, 2}; + + // class scores - two classes without background + const float kInputData2[] = {.9, .8, .75, .72, .6, .5, .93, .95, .5, .4, .3, .2}; + + const float kGolden1[] = {0.0, 10.0, 1.0, 11.0, 0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 0.0, 0.0}; + const float kGolden3[] = {0.95, 0.9, 0.0}; + const float kGolden4[] = {2.0}; + + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + + tflite::testing::TestDetectionPostprocess(tflite::testing::kInputShape1, tflite::testing::kInputData1, + kInputShape2, kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + 
tflite::testing::kOutputShape1, output_data1, + tflite::testing::kOutputShape2, output_data2, + tflite::testing::kOutputShape3, output_data3, + tflite::testing::kOutputShape4, output_data4, + kGolden1, tflite::testing::kGolden2, kGolden3, kGolden4, + /* tolerance */ 1e-1, /* Use regular NMS: */ true); +} + +TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMSWithBackgroundClassAndKeypoints) { + const int kInputShape1[] = {3, 1, 6, 5}; + + // six boxes in center-size encoding + const float kInputData1[] = { + 0.0, 0.0, 0.0, 0.0, 1.0, // box #1 + 0.0, 1.0, 0.0, 0.0, 1.0, // box #2 + 0.0, -1.0, 0.0, 0.0, 1.0, // box #3 + 0.0, 0.0, 0.0, 0.0, 1.0, // box #4 + 0.0, 1.0, 0.0, 0.0, 1.0, // box #5 + 0.0, 0.0, 0.0, 0.0, 1.0, // box #6 + }; + + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + + tflite::testing::TestDetectionPostprocess(kInputShape1, kInputData1, + tflite::testing::kInputShape2, tflite::testing::kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + tflite::testing::kOutputShape1, output_data1, + tflite::testing::kOutputShape2, output_data2, + tflite::testing::kOutputShape3, output_data3, + tflite::testing::kOutputShape4, output_data4, + tflite::testing::kGolden1, tflite::testing::kGolden2, + tflite::testing::kGolden3, tflite::testing::kGolden4, + /* tolerance */ 0, /* Use regular NMS: */ false); +} + +TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedFastNMSwithNoBackgroundClassAndKeypoints) { + const int kInputShape1[] = {3, 1, 6, 5}; + const int kInputShape2[] = {3, 1, 6, 2}; + + // six boxes in center-size encoding + const float kInputData1[] = { + 0.0, 0.0, 0.0, 0.0, 1.0, // box #1 + 0.0, 1.0, 0.0, 0.0, 1.0, // box #2 + 0.0, -1.0, 0.0, 0.0, 1.0, // box #3 + 0.0, 0.0, 0.0, 0.0, 1.0, // box #4 + 0.0, 1.0, 0.0, 0.0, 1.0, // box #5 + 0.0, 0.0, 0.0, 0.0, 1.0, // box #6 + }; + + // class scores - two classes without background + const float kInputData2[] = {.9, .8, .75, .72, .6, .5, .93, 
.95, .5, .4, .3, .2}; + + const int kInputElements1 = tflite::testing::kInputShape1[1] * tflite::testing::kInputShape1[2] * + tflite::testing::kInputShape1[3]; + const int kInputElements2 = tflite::testing::kInputShape2[1] * tflite::testing::kInputShape2[2] * + tflite::testing::kInputShape2[3]; + const int kInputElements3 = tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2]; + + uint8_t input_data_quantized1[kInputElements1+10]; + uint8_t input_data_quantized2[kInputElements1+10]; + uint8_t input_data_quantized3[kInputElements1+10]; + + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + + tflite::testing::TestDetectionPostprocess(kInputShape1, kInputData1, kInputShape2, kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + tflite::testing::kOutputShape1, output_data1, + tflite::testing::kOutputShape2, output_data2, + tflite::testing::kOutputShape3, output_data3, + tflite::testing::kOutputShape4, output_data4, + tflite::testing::kGolden1, tflite::testing::kGolden2, + tflite::testing::kGolden3, tflite::testing::kGolden4, + /* tolerance */ 3e-1, /* Use regular NMS: */ false, + input_data_quantized1, input_data_quantized2, input_data_quantized3, + /* input1 min/max*/-1.0, 1.0, /* input2 min/max */ 0.0, 1.0, + /* input3 min/max */ 0.0, 100.5); +} + +TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMSUndefinedOutputDimensions) { + float output_data1[12]; + float output_data2[3]; + float output_data3[3]; + float output_data4[1]; + + tflite::testing::TestDetectionPostprocess(tflite::testing::kInputShape1, tflite::testing::kInputData1, + tflite::testing::kInputShape2, tflite::testing::kInputData2, + tflite::testing::kInputShape3, tflite::testing::kInputData3, + nullptr, output_data1, + nullptr, output_data2, + nullptr, output_data3, + nullptr, output_data4, + tflite::testing::kGolden1, tflite::testing::kGolden2, + tflite::testing::kGolden3, tflite::testing::kGolden4, + /* tolerance */ 
0, /* Use regular NMS: */ false); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 23c7ca96408..0ff627c03f5 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -853,7 +853,11 @@ TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims, TfLiteTensor result = CreateTensor(dims, is_variable); result.type = kTfLiteFloat32; result.data.f = const_cast(data); - result.bytes = ElementCount(*dims) * sizeof(float); + if (dims == nullptr) { + result.bytes = 0; + } else { + result.bytes = ElementCount(*dims) * sizeof(float); + } return result; } From bd774cf7dfe273b2acf14e8322c4fd8de4ea5a58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Mon, 29 Jun 2020 08:53:48 +0200 Subject: [PATCH 2/9] TFlu: replace std::vector in detection_postprocess --- .../micro/kernels/detection_postprocess.cc | 171 ++++++++++++------ tensorflow/lite/micro/testing/test_utils.cc | 2 +- 2 files changed, 121 insertions(+), 52 deletions(-) diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc index 5569adce8b3..67054287021 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -107,9 +107,25 @@ struct OpData { int active_candidate_idx; int decoded_boxes_idx; int scores_idx; + int score_buffer_idx; + int keep_scores_idx; + int scores_after_regular_non_max_suppression_idx; + int sorted_values_idx; + int keep_indices_idx; + int sorted_indices_idx; + int buffer_idx; + int selected_idx; uint8_t* active_box_candidate; float* decoded_boxes; float* scores; + float* score_buffer; + float* keep_scores; + float* scores_after_regular_non_max_suppression; + float* sorted_values; + int* keep_indices; + int* sorted_indices; + int* buffer; + int* selected; }; TfLiteStatus AllocateOutDimensions(TfLiteContext* context, 
TfLiteIntArray** dims, @@ -187,16 +203,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4); + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; - context->RequestScratchBufferInArena(context, input_box_encodings->dims->data[1], + // Scratch tensors + context->RequestScratchBufferInArena(context, num_boxes, &op_data->active_candidate_idx); - context->RequestScratchBufferInArena(context, input_box_encodings->dims->data[1]*kNumCoordBox * sizeof(float), + context->RequestScratchBufferInArena(context, num_boxes * kNumCoordBox * sizeof(float), &op_data->decoded_boxes_idx); context->RequestScratchBufferInArena(context, input_class_predictions->dims->data[1] * input_class_predictions->dims->data[2] * sizeof(float), &op_data->scores_idx); + // Additional buffers + context->RequestScratchBufferInArena(context, num_boxes * sizeof(float), &op_data->scores_idx); + context->RequestScratchBufferInArena(context, num_boxes * sizeof(float), &op_data->keep_scores_idx); + context->RequestScratchBufferInArena(context, op_data->max_detections * num_boxes * sizeof(float), + &op_data->scores_after_regular_non_max_suppression_idx); + context->RequestScratchBufferInArena(context, op_data->max_detections * num_boxes * sizeof(float), + &op_data->sorted_values_idx); + context->RequestScratchBufferInArena(context, num_boxes * sizeof(int), &op_data->keep_indices_idx); + context->RequestScratchBufferInArena(context, op_data->max_detections * num_boxes * sizeof(int), + &op_data->sorted_indices_idx); + int buffer_size = std::max(num_classes, op_data->max_detections); + context->RequestScratchBufferInArena(context, buffer_size * num_boxes * sizeof(int), + &op_data->buffer_idx); + buffer_size = std::min(num_boxes, op_data->max_detections); + context->RequestScratchBufferInArena(context, buffer_size * num_boxes * 
sizeof(int), + &op_data->selected_idx); + // number of detected boxes const int num_detected_boxes = op_data->max_detections * op_data->max_classes_per_detection; @@ -357,16 +393,35 @@ void DecreasingPartialArgSort(const float* values, int num_values, [&values](const int i, const int j) { return values[i] > values[j]; }); } -void SelectDetectionsAboveScoreThreshold(const std::vector& values, - const float threshold, - std::vector* keep_values, - std::vector* keep_indices) { - for (int i = 0; i < values.size(); i++) { +void DecreasingPartialArgSort2(const float* values, int num_values, + int num_to_sort, int* indices, int* ind) { + + std::iota(ind, ind + num_values, 0); + std::partial_sort( + ind, ind + num_to_sort, ind + num_values, + [&values](const int i, const int j) { return values[i] > values[j]; }); + + std::iota(indices, indices + num_values, 0); + + std::partial_sort( + indices, indices + num_to_sort, indices + num_values, + [&values](const int i, const int j) { return values[i] > values[j]; }); + +} + +int SelectDetectionsAboveScoreThreshold(const float* values, + int size, + const float threshold, + float* keep_values, + int* keep_indices) { + int counter = 0; + for (int i = 0; i < size; i++) { if (values[i] >= threshold) { - keep_values->emplace_back(values[i]); - keep_indices->emplace_back(i); + keep_values[counter++] = values[i]; + keep_indices[i] = i; } } + return counter; } bool ValidateBoxes(const float* decoded_boxes, const int num_boxes) { @@ -405,7 +460,7 @@ float ComputeIntersectionOverUnion(const float* decoded_boxes, // Complexity is O(N^2) pairwise comparison between boxes TfLiteStatus NonMaxSuppressionSingleClassHelper( TfLiteContext* context, TfLiteNode* node, OpData* op_data, - const std::vector& scores, std::vector* selected, + const float* scores, int* selected, int* selected_size, int max_detections) { const TfLiteTensor* input_box_encodings = @@ -425,32 +480,30 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( TF_LITE_ENSURE(context, 
ValidateBoxes(op_data->decoded_boxes, num_boxes)); // threshold scores - std::vector keep_indices; - // TODO (chowdhery): Remove the dynamic allocation and replace it - // with temporaries, esp for std::vector - std::vector keep_scores; - SelectDetectionsAboveScoreThreshold( - scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices); + int* keep_indices = op_data->keep_indices; + float* keep_scores = op_data->keep_scores; + int num_scores_kept = SelectDetectionsAboveScoreThreshold( + scores, num_boxes, non_max_suppression_score_threshold, keep_scores, keep_indices); + + int* sorted_indices = op_data->sorted_indices; + + DecreasingPartialArgSort(keep_scores, num_scores_kept, num_scores_kept, + sorted_indices); - int num_scores_kept = keep_scores.size(); - std::vector sorted_indices; - sorted_indices.resize(num_scores_kept); - DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept, - sorted_indices.data()); const int num_boxes_kept = num_scores_kept; const int output_size = std::min(num_boxes_kept, max_detections); - selected->clear(); + *selected_size = 0; int num_active_candidate = num_boxes_kept; uint8_t* active_box_candidate = op_data->active_box_candidate; + for (int row = 0; row < num_boxes_kept; row++) { active_box_candidate[row] = 1; } - for (int i = 0; i < num_boxes_kept; ++i) { - if (num_active_candidate == 0 || selected->size() >= output_size) break; + if (num_active_candidate == 0 || *selected_size >= output_size) break; if (active_box_candidate[i] == 1) { - selected->push_back(keep_indices[sorted_indices[i]]); + selected[(*selected_size)++] = keep_indices[sorted_indices[i]]; active_box_candidate[i] = 0; num_active_candidate--; } else { @@ -469,6 +522,7 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( } } } + return kTfLiteOk; } @@ -507,18 +561,14 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, TF_LITE_ENSURE(context, num_detections_per_class > 0); // For each class, perform 
non-max suppression. - std::vector class_scores(num_boxes); - - std::vector box_indices_after_regular_non_max_suppression( - num_boxes + max_detections); - std::vector scores_after_regular_non_max_suppression(num_boxes + - max_detections); + float* class_scores = op_data->score_buffer; + int* box_indices_after_regular_non_max_suppression = op_data->buffer; + float* scores_after_regular_non_max_suppression = + op_data->scores_after_regular_non_max_suppression; int size_of_sorted_indices = 0; - std::vector sorted_indices; - sorted_indices.resize(num_boxes + max_detections); - std::vector sorted_values; - sorted_values.resize(max_detections); + int* sorted_indices = op_data->sorted_indices; + float* sorted_values = op_data->sorted_values; for (int col = 0; col < num_classes; col++) { for (int row = 0; row < num_boxes; row++) { @@ -527,13 +577,16 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, *(scores + row * num_classes_with_background + col + label_offset); } // Perform non-maximal suppression on single class - std::vector selected; + int selected_size = 0; + int* selected = op_data->selected; TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper( - context, node, op_data, class_scores, &selected, + context, node, op_data, class_scores, selected, &selected_size, num_detections_per_class)); // Add selected indices from non-max suppression of boxes in this class int output_index = size_of_sorted_indices; - for (const auto& selected_index : selected) { + for (int i = 0; i < selected_size; i++) { + int selected_index = selected[i]; + box_indices_after_regular_non_max_suppression[output_index] = (selected_index * num_classes_with_background + col + label_offset); scores_after_regular_non_max_suppression[output_index] = @@ -543,9 +596,9 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, // Sort the max scores among the selected indices // Get the indices for top scores int num_indices_to_sort = 
std::min(output_index, max_detections); - DecreasingPartialArgSort(scores_after_regular_non_max_suppression.data(), + DecreasingPartialArgSort(scores_after_regular_non_max_suppression, output_index, num_indices_to_sort, - sorted_indices.data()); + sorted_indices); // Copy values to temporary vectors for (int row = 0; row < num_indices_to_sort; row++) { @@ -590,8 +643,7 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, } } GetTensorData(num_detections)[0] = size_of_sorted_indices; - box_indices_after_regular_non_max_suppression.clear(); - scores_after_regular_non_max_suppression.clear(); + return kTfLiteOk; } @@ -631,32 +683,33 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, TF_LITE_ENSURE(context, (max_categories_per_anchor > 0)); const int num_categories_per_anchor = std::min(max_categories_per_anchor, num_classes); - std::vector max_scores; - max_scores.resize(num_boxes); - std::vector sorted_class_indices; - sorted_class_indices.resize(num_boxes * num_classes); + float* max_scores = op_data->score_buffer; + int* sorted_class_indices = op_data->buffer; for (int row = 0; row < num_boxes; row++) { const float* box_scores = scores + row * num_classes_with_background + label_offset; - int* class_indices = sorted_class_indices.data() + row * num_classes; + int* class_indices = sorted_class_indices + row * num_classes; DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor, class_indices); max_scores[row] = box_scores[class_indices[0]]; } // Perform non-maximal suppression on max scores - std::vector selected; + int selected_size = 0; + int* selected = op_data->selected; TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper( - context, node, op_data, max_scores, &selected, op_data->max_detections)); + context, node, op_data, max_scores, selected, &selected_size, op_data->max_detections)); // Allocate output tensors int output_box_index = 0; - for (const auto& selected_index : selected) 
{ + for (int i = 0; i < selected_size; i++) { + int selected_index = selected[i]; + const float* box_scores = scores + selected_index * num_classes_with_background + label_offset; const int* class_indices = - sorted_class_indices.data() + selected_index * num_classes; + sorted_class_indices + selected_index * num_classes; for (int col = 0; col < num_categories_per_anchor; ++col) { int box_offset = num_categories_per_anchor * output_box_index + col; @@ -750,6 +803,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { op_data->decoded_boxes = reinterpret_cast(raw); raw = context->GetScratchBuffer(context, op_data->scores_idx); op_data->scores = reinterpret_cast(raw); + raw = context->GetScratchBuffer(context, op_data->score_buffer_idx); + op_data->score_buffer = reinterpret_cast(raw); + raw = context->GetScratchBuffer(context, op_data->keep_scores_idx); + op_data->keep_scores = reinterpret_cast(raw); + raw = context->GetScratchBuffer(context, op_data->scores_after_regular_non_max_suppression_idx); + op_data->scores_after_regular_non_max_suppression = reinterpret_cast(raw); + raw = context->GetScratchBuffer(context, op_data->sorted_values_idx); + op_data->sorted_values = reinterpret_cast(raw); + raw = context->GetScratchBuffer(context, op_data->keep_indices_idx); + op_data->keep_indices = reinterpret_cast(raw); + raw = context->GetScratchBuffer(context, op_data->sorted_indices_idx); + op_data->sorted_indices = reinterpret_cast(raw); + raw = context->GetScratchBuffer(context, op_data->buffer_idx); + op_data->buffer = reinterpret_cast(raw); + raw = context->GetScratchBuffer(context, op_data->selected_idx); + op_data->selected = reinterpret_cast(raw); // These two functions correspond to two blocks in the Object Detection model. 
// In future, we would like to break the custom op in two blocks, which is diff --git a/tensorflow/lite/micro/testing/test_utils.cc b/tensorflow/lite/micro/testing/test_utils.cc index 0bb97854a41..7cb1aa9e270 100644 --- a/tensorflow/lite/micro/testing/test_utils.cc +++ b/tensorflow/lite/micro/testing/test_utils.cc @@ -36,7 +36,7 @@ constexpr size_t kBufferAlignment = 16; // We store the pointer to the ith scratch buffer to implement the Request/Get // ScratchBuffer API for the tests. scratch_buffers_[i] will be the ith scratch // buffer and will still be allocated from within raw_arena_. -constexpr int kNumScratchBuffers = 5; +constexpr int kNumScratchBuffers = 12; uint8_t* scratch_buffers_[kNumScratchBuffers]; int scratch_buffer_count_ = 0; From 88cc269d3b7176b5830239f12f9782e1922325f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Tue, 30 Jun 2020 15:53:29 +0200 Subject: [PATCH 3/9] TFlu: Fix detection_postprocess output dimension handling A tensor whose shape is left undefined in the flatbuffer does not reach the kernel with a null dims pointer. Instead, its dims array is present with a size of zero, so the kernel must check dims->size == 0 rather than dims == nullptr. 
--- .../micro/kernels/detection_postprocess.cc | 8 +++---- .../kernels/detection_postprocess_test.cc | 23 +++++++++++++++---- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc index 67054287021..93410689a16 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -244,7 +244,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4) TfLiteTensor* detection_boxes = GetOutput(context, node, kOutputTensorDetectionBoxes); - if (detection_boxes->dims == nullptr) { + if (detection_boxes->dims->size == 0) { TF_LITE_ENSURE_STATUS(AllocateOutDimensions( context, &detection_boxes->dims, 1, num_detected_boxes, 4)); } @@ -252,7 +252,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Output Tensor detection_classes: size is set to (1, num_detected_boxes) TfLiteTensor* detection_classes = GetOutput(context, node, kOutputTensorDetectionClasses); - if (detection_classes->dims == nullptr) { + if (detection_classes->dims->size == 0) { TF_LITE_ENSURE_STATUS(AllocateOutDimensions( context, &detection_classes->dims, 1, num_detected_boxes)); } @@ -260,7 +260,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Output Tensor detection_scores: size is set to (1, num_detected_boxes) TfLiteTensor* detection_scores = GetOutput(context, node, kOutputTensorDetectionScores); - if (detection_scores->dims == nullptr) { + if (detection_scores->dims->size == 0) { TF_LITE_ENSURE_STATUS(AllocateOutDimensions( context, &detection_scores->dims, 1, num_detected_boxes)); } @@ -268,7 +268,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Output Tensor num_detections: size is set to 1 TfLiteTensor* num_detections = GetOutput(context, node, 
kOutputTensorNumDetections); - if (num_detections->dims == nullptr) { + if (num_detections->dims->size == 0) { TF_LITE_ENSURE_STATUS(AllocateOutDimensions( context, &num_detections->dims, 1)); } diff --git a/tensorflow/lite/micro/kernels/detection_postprocess_test.cc b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc index d4fa5af0dca..6ca72f4f644 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess_test.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc @@ -80,7 +80,7 @@ void TestDetectionPostprocess(const int* input_dims_data1, const float* input_da const int* output_dims_data4, float* output_data4, const float* golden1, const float* golden2, const float* golden3, const float* golden4, - const float tolerance, bool use_regular_nms, //TODO 4 tolerance + const float tolerance, bool use_regular_nms, uint8_t* input_data_quantized1=nullptr, uint8_t* input_data_quantized2=nullptr, uint8_t* input_data_quantized3=nullptr, @@ -90,10 +90,23 @@ void TestDetectionPostprocess(const int* input_dims_data1, const float* input_da TfLiteIntArray* input_dims1 = IntArrayFromInts(input_dims_data1); TfLiteIntArray* input_dims2 = IntArrayFromInts(input_dims_data2); TfLiteIntArray* input_dims3 = IntArrayFromInts(input_dims_data3); - TfLiteIntArray* output_dims1 = IntArrayFromInts(output_dims_data1); - TfLiteIntArray* output_dims2 = IntArrayFromInts(output_dims_data2); - TfLiteIntArray* output_dims3 = IntArrayFromInts(output_dims_data3); - TfLiteIntArray* output_dims4 = IntArrayFromInts(output_dims_data4); + TfLiteIntArray* output_dims1 = nullptr; + TfLiteIntArray* output_dims2 = nullptr; + TfLiteIntArray* output_dims3 = nullptr; + TfLiteIntArray* output_dims4 = nullptr; + + // Instance of a zero-length int to pass as tensor dims for a flatbuffer + // Tensor with no shape. + const TfLiteIntArray kZeroLengthIntArray = {0}; + + output_dims1 = output_dims_data1 == nullptr ? 
const_cast(&kZeroLengthIntArray) : + IntArrayFromInts(output_dims_data1); + output_dims2 = output_dims_data2 == nullptr ? const_cast(&kZeroLengthIntArray) : + IntArrayFromInts(output_dims_data2); + output_dims3 = output_dims_data3 == nullptr ? const_cast(&kZeroLengthIntArray) : + IntArrayFromInts(output_dims_data3); + output_dims4 = output_dims_data4 == nullptr ? const_cast(&kZeroLengthIntArray) : + IntArrayFromInts(output_dims_data4); constexpr int inputs_size = 3; constexpr int outputs_size = 4; From 634888a82f46694e2747ffde745d269b6cdf7c80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Wed, 8 Jul 2020 08:07:08 +0200 Subject: [PATCH 4/9] TFLu: detection_postprocess: fix review comments and build issues --- tensorflow/lite/micro/kernels/BUILD | 1 - .../lite/micro/kernels/detection_postprocess.cc | 16 ++++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 488c6cfe769..f8e302ab0f6 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -176,7 +176,6 @@ tflite_micro_cc_test( "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", - "@flatbuffers", ], ) diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc index 93410689a16..f55b64f0ba5 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -309,14 +309,12 @@ void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, template T ReInterpretTensor(const TfLiteTensor* tensor) { - // TODO (chowdhery): check float const float* tensor_base = GetTensorData(tensor); return reinterpret_cast(tensor_base); } template T ReInterpretTensor(TfLiteTensor* tensor) { - // TODO (chowdhery): check float float* tensor_base = 
GetTensorData(tensor); return reinterpret_cast(tensor_base); } @@ -791,7 +789,6 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - // TODO(chowdhery): Generalize for any batch size TF_LITE_ENSURE(context, (kBatchSize == 1)); // Set up scratch buffers @@ -837,17 +834,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // highest scoring non-overlapping boxes. TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data)); - // TODO(chowdhery): Generalize for any batch size - return kTfLiteOk; } } // namespace detection_postprocess TfLiteRegistration* Register_DETECTION_POSTPROCESS() { - static TfLiteRegistration r = { - detection_postprocess::Init, detection_postprocess::Free, - detection_postprocess::Prepare, detection_postprocess::Eval}; + static TfLiteRegistration r = {/*init=*/detection_postprocess::Init, + /*free=*/detection_postprocess::Free, + /*prepare=*/detection_postprocess::Prepare, + /*invoke=*/detection_postprocess::Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; return &r; } From 4f1b59c06d3755c2235a6c62ac31c5f2aef0c810 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Mon, 3 Aug 2020 10:56:47 +0200 Subject: [PATCH 5/9] TFLu: add missing bazel dependency for detection_pp --- tensorflow/lite/micro/kernels/BUILD | 2 ++ tensorflow/lite/micro/kernels/detection_postprocess.cc | 10 ++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index f8e302ab0f6..9709ff5e7c5 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -115,6 +115,7 @@ cc_library( "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/micro:memory_helpers", "//tensorflow/lite/micro:micro_utils", + "@flatbuffers", ] + select({ "//conditions:default": [], 
":xtensa_hifimini": [ @@ -176,6 +177,7 @@ tflite_micro_cc_test( "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", + "@flatbuffers", ], ) diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc index f55b64f0ba5..1165c35b568 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -136,9 +136,8 @@ TfLiteStatus AllocateOutDimensions(TfLiteContext* context, TfLiteIntArray** dims size = (y > 0) ? size * y : size; size = (z > 0) ? size * z : size; - TF_LITE_ENSURE_STATUS(context->AllocatePersistentBuffer( - context, TfLiteIntArrayGetSizeInBytes(size), - reinterpret_cast(dims))); + *dims = reinterpret_cast(context->AllocatePersistentBuffer( + context, TfLiteIntArrayGetSizeInBytes(size))); (*dims)->size = size; (*dims)->data[0] = x; @@ -153,14 +152,13 @@ TfLiteStatus AllocateOutDimensions(TfLiteContext* context, TfLiteIntArray** dims } void* Init(TfLiteContext* context, const char* buffer, size_t length) { - void* raw; OpData* op_data = nullptr; const uint8_t* buffer_t = reinterpret_cast(buffer); const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); - context->AllocatePersistentBuffer(context, sizeof(OpData), &raw); - op_data = reinterpret_cast(raw); + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + op_data = reinterpret_cast(context->AllocatePersistentBuffer(context, sizeof(OpData))); op_data->max_detections = m["max_detections"].AsInt32(); op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32(); From d5909ff1a03e990df2351b679aa26d8b5732e694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Tue, 4 Aug 2020 10:43:42 +0200 Subject: [PATCH 6/9] TFLu: Fix for internal build error in detection_pp --- tensorflow/lite/micro/kernels/detection_postprocess.cc | 8 ++++---- 1 file 
changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc index 1165c35b568..173766712d4 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -278,7 +278,7 @@ class Dequantizer { public: Dequantizer(int zero_point, float scale) : zero_point_(zero_point), scale_(scale) {} - float operator()(uint8 x) { + float operator()(uint8_t x) { return (static_cast(x) - zero_point_) * scale_; } @@ -291,8 +291,8 @@ void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, float quant_zero_point, float quant_scale, int length_box_encoding, CenterSizeEncoding* box_centersize) { - const uint8* boxes = - GetTensorData(input_box_encodings) + length_box_encoding * idx; + const uint8_t* boxes = + GetTensorData(input_box_encodings) + length_box_encoding * idx; Dequantizer dequantize(quant_zero_point, quant_scale); // See definition of the KeyPointBoxCoder at // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py @@ -737,7 +737,7 @@ void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions, static_cast(input_class_predictions->params.zero_point); float quant_scale = static_cast(input_class_predictions->params.scale); Dequantizer dequantize(quant_zero_point, quant_scale); - const uint8* scores_quant = GetTensorData(input_class_predictions); + const uint8_t* scores_quant = GetTensorData(input_class_predictions); for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) { scores[idx] = dequantize(scores_quant[idx]); } From 2394ff5924cc0365dd7de4a2b26387a65c6efcff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Thu, 6 Aug 2020 08:10:51 +0200 Subject: [PATCH 7/9] TFLu: Fix warnings in detection_pp test --- tensorflow/lite/micro/kernels/BUILD | 2 +- 
.../micro/kernels/detection_postprocess_test.cc | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 9709ff5e7c5..77f4ef32447 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -174,8 +174,8 @@ tflite_micro_cc_test( ], deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro/testing:micro_test", "@flatbuffers", ], diff --git a/tensorflow/lite/micro/kernels/detection_postprocess_test.cc b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc index 6ca72f4f644..e90144bfd39 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess_test.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc @@ -126,7 +126,7 @@ void TestDetectionPostprocess(const int* input_dims_data1, const float* input_da tensors[0] = CreateQuantizedTensor(input_data1, input_data_quantized1,input_dims1, input_scale1, input_zero_point1); tensors[1] = CreateQuantizedTensor(input_data2, input_data_quantized2, input_dims2, - input_scale1, input_zero_point2); + input_scale2, input_zero_point2); tensors[2] = CreateQuantizedTensor(input_data3, input_data_quantized3, input_dims3, input_scale3, input_zero_point3); } @@ -254,8 +254,8 @@ TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedFastNMS) { const int kInputElements3 = tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2]; uint8_t input_data_quantized1[kInputElements1+10]; - uint8_t input_data_quantized2[kInputElements1+10]; - uint8_t input_data_quantized3[kInputElements1+10]; + uint8_t input_data_quantized2[kInputElements2+10]; + uint8_t input_data_quantized3[kInputElements3+10]; tflite::testing::TestDetectionPostprocess(tflite::testing::kInputShape1, tflite::testing::kInputData1, tflite::testing::kInputShape2, 
tflite::testing::kInputData2, @@ -304,8 +304,8 @@ TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedRegularNMS) { const int kInputElements3 = tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2]; uint8_t input_data_quantized1[kInputElements1+10]; - uint8_t input_data_quantized2[kInputElements1+10]; - uint8_t input_data_quantized3[kInputElements1+10]; + uint8_t input_data_quantized2[kInputElements2+10]; + uint8_t input_data_quantized3[kInputElements3+10]; const float kGolden1[] = {0.0, 10.0, 1.0, 11.0, 0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 0.0, 0.0}; const float kGolden3[] = {0.95, 0.9, 0.0}; @@ -439,8 +439,8 @@ TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedFastNMSwithNoBackgroundClassAndK const int kInputElements3 = tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2]; uint8_t input_data_quantized1[kInputElements1+10]; - uint8_t input_data_quantized2[kInputElements1+10]; - uint8_t input_data_quantized3[kInputElements1+10]; + uint8_t input_data_quantized2[kInputElements2+10]; + uint8_t input_data_quantized3[kInputElements3+10]; float output_data1[12]; float output_data2[3]; From 4bfb202b677705455dcb12b3245be87ae7d326a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Fri, 7 Aug 2020 16:56:05 +0200 Subject: [PATCH 8/9] TFLu: Switch to TFLiteEvalTensor in detection_pp --- tensorflow/lite/micro/kernels/BUILD | 1 + .../micro/kernels/detection_postprocess.cc | 159 ++++++++++-------- .../kernels/detection_postprocess_test.cc | 48 ++---- .../lite/micro/kernels/kernel_runner.cc | 4 +- tensorflow/lite/micro/kernels/kernel_runner.h | 4 +- 5 files changed, 109 insertions(+), 107 deletions(-) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 77f4ef32447..f3983583642 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -173,6 +173,7 @@ tflite_micro_cc_test( "detection_postprocess_test.cc", ], deps = [ + ":kernel_runner", "//tensorflow/lite/c:common", 
"//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:op_resolvers", diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc index 173766712d4..f1bf89f6004 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" #include "tensorflow/lite/micro/micro_utils.h" @@ -103,7 +104,7 @@ struct OpData { bool use_regular_non_max_suppression; CenterSizeEncoding scale_values; - // Scratch buffers + // Scratch buffers indexes int active_candidate_idx; int decoded_boxes_idx; int scores_idx; @@ -115,6 +116,9 @@ struct OpData { int sorted_indices_idx; int buffer_idx; int selected_idx; + + // These are just temporary pointers to scratch buffers, set in each invocation of eval + // OpData can then be used to pass them around and the number of parameters can be reduced uint8_t* active_box_candidate; float* decoded_boxes; float* scores; @@ -126,6 +130,11 @@ struct OpData { int* sorted_indices; int* buffer; int* selected; + + // Cached tensor scale and zero point values for quantized operations + TfLiteQuantizationParams input_box_encodings; + TfLiteQuantizationParams input_class_predictions; + TfLiteQuantizationParams input_anchors; }; TfLiteStatus AllocateOutDimensions(TfLiteContext* context, TfLiteIntArray** dims, @@ -204,6 +213,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int num_boxes = input_box_encodings->dims->data[1]; const int num_classes = op_data->num_classes; + op_data->input_box_encodings.scale = input_box_encodings->params.scale; + op_data->input_box_encodings.zero_point = input_box_encodings->params.zero_point; + 
op_data->input_class_predictions.scale = input_class_predictions->params.scale; + op_data->input_class_predictions.zero_point = input_class_predictions->params.zero_point; + op_data->input_anchors.scale = input_anchors->params.scale; + op_data->input_anchors.zero_point = input_anchors->params.zero_point; + // Scratch tensors context->RequestScratchBufferInArena(context, num_boxes, &op_data->active_candidate_idx); @@ -287,12 +303,12 @@ class Dequantizer { float scale_; }; -void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, +void DequantizeBoxEncodings(const TfLiteEvalTensor* input_box_encodings, int idx, float quant_zero_point, float quant_scale, int length_box_encoding, CenterSizeEncoding* box_centersize) { const uint8_t* boxes = - GetTensorData(input_box_encodings) + length_box_encoding * idx; + tflite::micro::GetTensorData(input_box_encodings) + length_box_encoding * idx; Dequantizer dequantize(quant_zero_point, quant_scale); // See definition of the KeyPointBoxCoder at // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py @@ -306,27 +322,27 @@ void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, } template -T ReInterpretTensor(const TfLiteTensor* tensor) { - const float* tensor_base = GetTensorData(tensor); +T ReInterpretTensor(const TfLiteEvalTensor* tensor) { + const float* tensor_base = tflite::micro::GetTensorData(tensor); return reinterpret_cast(tensor_base); } template -T ReInterpretTensor(TfLiteTensor* tensor) { - float* tensor_base = GetTensorData(tensor); +T ReInterpretTensor(TfLiteEvalTensor* tensor) { + float* tensor_base = tflite::micro::GetTensorData(tensor); return reinterpret_cast(tensor_base); } TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, OpData* op_data) { // Parse input tensor boxencodings - const TfLiteTensor* input_box_encodings = - GetInput(context, node, kInputTensorBoxEncodings); + const 
TfLiteEvalTensor* input_box_encodings = + tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings); TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize); const int num_boxes = input_box_encodings->dims->data[1]; TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox); - const TfLiteTensor* input_anchors = - GetInput(context, node, kInputTensorAnchors); + const TfLiteEvalTensor* input_anchors = + tflite::micro::GetEvalInput(context, node, kInputTensorAnchors); // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors CenterSizeEncoding box_centersize; @@ -338,13 +354,13 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, case kTfLiteUInt8: DequantizeBoxEncodings( input_box_encodings, idx, - static_cast(input_box_encodings->params.zero_point), - static_cast(input_box_encodings->params.scale), + static_cast(op_data->input_box_encodings.zero_point), + static_cast(op_data->input_box_encodings.scale), input_box_encodings->dims->data[2], &box_centersize); DequantizeBoxEncodings( input_anchors, idx, - static_cast(input_anchors->params.zero_point), - static_cast(input_anchors->params.scale), kNumCoordBox, + static_cast(op_data->input_anchors.zero_point), + static_cast(op_data->input_anchors.scale), kNumCoordBox, &anchor); break; // Float @@ -352,7 +368,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, // Please see DequantizeBoxEncodings function for the support detail. 
const int box_encoding_idx = idx * input_box_encodings->dims->data[2]; const float* boxes = - &(GetTensorData(input_box_encodings)[box_encoding_idx]); + &(tflite::micro::GetTensorData(input_box_encodings)[box_encoding_idx]); box_centersize = *reinterpret_cast(boxes); anchor = ReInterpretTensor(input_anchors)[idx]; @@ -459,8 +475,8 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( const float* scores, int* selected, int* selected_size, int max_detections) { - const TfLiteTensor* input_box_encodings = - GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteEvalTensor* input_box_encodings = + tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings); const int num_boxes = input_box_encodings->dims->data[1]; const float non_max_suppression_score_threshold = op_data->non_max_suppression_score_threshold; @@ -533,18 +549,18 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, TfLiteNode* node, OpData* op_data, const float* scores) { - const TfLiteTensor* input_box_encodings = - GetInput(context, node, kInputTensorBoxEncodings); - const TfLiteTensor* input_class_predictions = - GetInput(context, node, kInputTensorClassPredictions); - TfLiteTensor* detection_boxes = - GetOutput(context, node, kOutputTensorDetectionBoxes); - TfLiteTensor* detection_classes = - GetOutput(context, node, kOutputTensorDetectionClasses); - TfLiteTensor* detection_scores = - GetOutput(context, node, kOutputTensorDetectionScores); - TfLiteTensor* num_detections = - GetOutput(context, node, kOutputTensorNumDetections); + const TfLiteEvalTensor* input_box_encodings = + tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings); + const TfLiteEvalTensor* input_class_predictions = + tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions); + TfLiteEvalTensor* detection_boxes = + tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes); + TfLiteEvalTensor* detection_classes = + 
tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionClasses); + TfLiteEvalTensor* detection_scores = + tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores); + TfLiteEvalTensor* num_detections = + tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections); const int num_boxes = input_box_encodings->dims->data[1]; const int num_classes = op_data->num_classes; @@ -626,19 +642,19 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, ReInterpretTensor(detection_boxes)[output_box_index] = reinterpret_cast(op_data->decoded_boxes)[anchor_index]; // detection_classes - GetTensorData(detection_classes)[output_box_index] = class_index; + tflite::micro::GetTensorData(detection_classes)[output_box_index] = class_index; // detection_scores - GetTensorData(detection_scores)[output_box_index] = selected_score; + tflite::micro::GetTensorData(detection_scores)[output_box_index] = selected_score; } else { ReInterpretTensor( detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f}; // detection_classes - GetTensorData(detection_classes)[output_box_index] = 0.0f; + tflite::micro::GetTensorData(detection_classes)[output_box_index] = 0.0f; // detection_scores - GetTensorData(detection_scores)[output_box_index] = 0.0f; + tflite::micro::GetTensorData(detection_scores)[output_box_index] = 0.0f; } } - GetTensorData(num_detections)[0] = size_of_sorted_indices; + tflite::micro::GetTensorData(num_detections)[0] = size_of_sorted_indices; return kTfLiteOk; } @@ -654,19 +670,19 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, TfLiteNode* node, OpData* op_data, const float* scores) { - const TfLiteTensor* input_box_encodings = - GetInput(context, node, kInputTensorBoxEncodings); - const TfLiteTensor* input_class_predictions = - GetInput(context, node, kInputTensorClassPredictions); - TfLiteTensor* detection_boxes = - GetOutput(context, node, kOutputTensorDetectionBoxes); + const 
TfLiteEvalTensor* input_box_encodings = + tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings); + const TfLiteEvalTensor* input_class_predictions = + tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions); + TfLiteEvalTensor* detection_boxes = + tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes); - TfLiteTensor* detection_classes = - GetOutput(context, node, kOutputTensorDetectionClasses); - TfLiteTensor* detection_scores = - GetOutput(context, node, kOutputTensorDetectionScores); - TfLiteTensor* num_detections = - GetOutput(context, node, kOutputTensorNumDetections); + TfLiteEvalTensor* detection_classes = + tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionClasses); + TfLiteEvalTensor* detection_scores = + tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores); + TfLiteEvalTensor* num_detections = + tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections); const int num_boxes = input_box_encodings->dims->data[1]; const int num_classes = op_data->num_classes; @@ -715,29 +731,30 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, reinterpret_cast(op_data->decoded_boxes)[selected_index]; // detection_classes - GetTensorData(detection_classes)[box_offset] = class_indices[col]; + tflite::micro::GetTensorData(detection_classes)[box_offset] = class_indices[col]; // detection_scores - GetTensorData(detection_scores)[box_offset] = + tflite::micro::GetTensorData(detection_scores)[box_offset] = box_scores[class_indices[col]]; output_box_index++; } } - GetTensorData(num_detections)[0] = output_box_index; + tflite::micro::GetTensorData(num_detections)[0] = output_box_index; return kTfLiteOk; } -void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions, +void DequantizeClassPredictions(const TfLiteEvalTensor* input_class_predictions, const int num_boxes, const int num_classes_with_background, - float* scores) { + float* scores, + 
OpData* op_data) { float quant_zero_point = - static_cast(input_class_predictions->params.zero_point); - float quant_scale = static_cast(input_class_predictions->params.scale); + static_cast(op_data->input_class_predictions.zero_point); + float quant_scale = static_cast(op_data->input_class_predictions.scale); Dequantizer dequantize(quant_zero_point, quant_scale); - const uint8_t* scores_quant = GetTensorData(input_class_predictions); + const uint8_t* scores_quant = tflite::micro::GetTensorData(input_class_predictions); for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) { scores[idx] = dequantize(scores_quant[idx]); } @@ -746,10 +763,10 @@ void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions, TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, TfLiteNode* node, OpData* op_data) { // Get the input tensors - const TfLiteTensor* input_box_encodings = - GetInput(context, node, kInputTensorBoxEncodings); - const TfLiteTensor* input_class_predictions = - GetInput(context, node, kInputTensorClassPredictions); + const TfLiteEvalTensor* input_box_encodings = + tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings); + const TfLiteEvalTensor* input_class_predictions = + tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions); const int num_boxes = input_box_encodings->dims->data[1]; const int num_classes = op_data->num_classes; @@ -767,11 +784,12 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, case kTfLiteUInt8: { float* temporary_scores = op_data->scores; DequantizeClassPredictions(input_class_predictions, num_boxes, - num_classes_with_background, temporary_scores); + num_classes_with_background, temporary_scores, + op_data); scores = temporary_scores; } break; case kTfLiteFloat32: - scores = GetTensorData(input_class_predictions); + scores = tflite::micro::GetTensorData(input_class_predictions); break; default: // Unsupported type. 
@@ -837,16 +855,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace detection_postprocess -TfLiteRegistration* Register_DETECTION_POSTPROCESS() { - static TfLiteRegistration r = {/*init=*/detection_postprocess::Init, - /*free=*/detection_postprocess::Free, - /*prepare=*/detection_postprocess::Prepare, - /*invoke=*/detection_postprocess::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; - return &r; +TfLiteRegistration Register_DETECTION_POSTPROCESS() { + return {/*init=*/detection_postprocess::Init, + /*free=*/detection_postprocess::Free, + /*prepare=*/detection_postprocess::Prepare, + /*invoke=*/detection_postprocess::Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; } } // namespace custom diff --git a/tensorflow/lite/micro/kernels/detection_postprocess_test.cc b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc index e90144bfd39..57df940f011 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess_test.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess_test.cc @@ -16,7 +16,8 @@ limitations under the License. #include "flatbuffers/flexbuffers.h" // TF:flatbuffers #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/micro/all_ops_resolver.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" + #include "tensorflow/lite/micro/testing/micro_test.h" #include "tensorflow/lite/micro/testing/test_utils.h" @@ -95,17 +96,16 @@ void TestDetectionPostprocess(const int* input_dims_data1, const float* input_da TfLiteIntArray* output_dims3 = nullptr; TfLiteIntArray* output_dims4 = nullptr; - // Instance of a zero-length int to pass as tensor dims for a flatbuffer - // Tensor with no shape. 
- const TfLiteIntArray kZeroLengthIntArray = {0}; + const int zero_length_int_array_data[] = {0}; + TfLiteIntArray *zero_length_int_array = IntArrayFromInts(zero_length_int_array_data); - output_dims1 = output_dims_data1 == nullptr ? const_cast(&kZeroLengthIntArray) : + output_dims1 = output_dims_data1 == nullptr ? const_cast(zero_length_int_array) : IntArrayFromInts(output_dims_data1); - output_dims2 = output_dims_data2 == nullptr ? const_cast(&kZeroLengthIntArray) : + output_dims2 = output_dims_data2 == nullptr ? const_cast(zero_length_int_array) : IntArrayFromInts(output_dims_data2); - output_dims3 = output_dims_data3 == nullptr ? const_cast(&kZeroLengthIntArray) : + output_dims3 = output_dims_data3 == nullptr ? const_cast(zero_length_int_array) : IntArrayFromInts(output_dims_data3); - output_dims4 = output_dims_data4 == nullptr ? const_cast(&kZeroLengthIntArray) : + output_dims4 = output_dims_data4 == nullptr ? const_cast(zero_length_int_array) : IntArrayFromInts(output_dims_data4); constexpr int inputs_size = 3; @@ -162,33 +162,19 @@ void TestDetectionPostprocess(const int* input_dims_data1, const float* input_da ::tflite::AllOpsResolver resolver; const TfLiteRegistration* registration = resolver.FindOp("TFLite_Detection_PostProcess"); - TF_LITE_MICRO_EXPECT_NE(nullptr, registration); - void* user_data = nullptr; - if (registration->init) { - user_data = registration->init(&context, (const char*)fbb.GetBuffer().data(), fbb.GetBuffer().size()); - } - int inputs_array_data[] = {3, 0, 1, 2}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {4, 3, 4, 5, 6}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - TfLiteNode node; - node.inputs = inputs_array; - node.outputs = outputs_array; - node.temporaries = nullptr; - node.user_data = user_data; - node.builtin_data = nullptr; - node.custom_initial_data = nullptr; - node.custom_initial_data_size = 0; - node.delegate = nullptr; - if 
(registration->prepare) { - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); - } - TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + micro::KernelRunner runner( + *registration, tensors, tensors_size, inputs_array, outputs_array, + nullptr, micro_test::reporter); + + const char* init_data = reinterpret_cast(fbb.GetBuffer().data()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare(init_data, fbb.GetBuffer().size())); // Output dimensions should not be undefined after Prepare TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[3].dims); @@ -196,15 +182,13 @@ void TestDetectionPostprocess(const int* input_dims_data1, const float* input_da TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[5].dims); TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[6].dims); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + const int output_elements_count1 = tensors[3].dims->size; const int output_elements_count2 = tensors[4].dims->size; const int output_elements_count3 = tensors[5].dims->size; const int output_elements_count4 = tensors[6].dims->size; - if (registration->free) { - registration->free(&context, user_data); - } - for (int i = 0; i < output_elements_count1; ++i) { TF_LITE_MICRO_EXPECT_NEAR(golden1[i], output_data1[i], tolerance); } diff --git a/tensorflow/lite/micro/kernels/kernel_runner.cc b/tensorflow/lite/micro/kernels/kernel_runner.cc index cef6c01cf45..fc8a161b9d9 100644 --- a/tensorflow/lite/micro/kernels/kernel_runner.cc +++ b/tensorflow/lite/micro/kernels/kernel_runner.cc @@ -52,9 +52,9 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration, node_.builtin_data = builtin_data; } -TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data) { +TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data, size_t length) { if (registration_.init) { - node_.user_data = registration_.init(&context_, init_data, /*length=*/0); + node_.user_data = 
registration_.init(&context_, init_data, length); } if (registration_.prepare) { TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_)); diff --git a/tensorflow/lite/micro/kernels/kernel_runner.h b/tensorflow/lite/micro/kernels/kernel_runner.h index 45d107e7a37..119150c9ff5 100644 --- a/tensorflow/lite/micro/kernels/kernel_runner.h +++ b/tensorflow/lite/micro/kernels/kernel_runner.h @@ -39,7 +39,7 @@ class KernelRunner { // Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any // exceptions will be reported through the error_reporter and returned as a // status code here. - TfLiteStatus InitAndPrepare(const char* init_data = nullptr); + TfLiteStatus InitAndPrepare(const char* init_data = nullptr, size_t length = 0); // Calls init, prepare, and invoke on a given TfLiteRegistration pointer. // After successful invoke, results will be available in the output tensor as @@ -60,7 +60,7 @@ class KernelRunner { ...); private: - static constexpr int kNumScratchBuffers_ = 5; + static constexpr int kNumScratchBuffers_ = 12; static constexpr int kKernelRunnerBufferSize_ = 10000; static uint8_t kKernelRunnerBuffer_[kKernelRunnerBufferSize_]; From 10cfb4a14bdc1d2145fe50cdddd94819e4226bdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Tue, 11 Aug 2020 09:47:00 +0200 Subject: [PATCH 9/9] TFLu: Remove temporary scratch pointers in OpData in detection_pp --- .../micro/kernels/detection_postprocess.cc | 103 ++++++++---------- 1 file changed, 43 insertions(+), 60 deletions(-) diff --git a/tensorflow/lite/micro/kernels/detection_postprocess.cc b/tensorflow/lite/micro/kernels/detection_postprocess.cc index f1bf89f6004..039afa2661e 100644 --- a/tensorflow/lite/micro/kernels/detection_postprocess.cc +++ b/tensorflow/lite/micro/kernels/detection_postprocess.cc @@ -117,20 +117,6 @@ struct OpData { int buffer_idx; int selected_idx; - // These are just temporary pointers to scratch buffers, set in each invocation of eval - // OpData can then 
be used to pass them around and the number of parameters can be reduced - uint8_t* active_box_candidate; - float* decoded_boxes; - float* scores; - float* score_buffer; - float* keep_scores; - float* scores_after_regular_non_max_suppression; - float* sorted_values; - int* keep_indices; - int* sorted_indices; - int* buffer; - int* selected; - // Cached tensor scale and zero point values for quantized operations TfLiteQuantizationParams input_box_encodings; TfLiteQuantizationParams input_class_predictions; @@ -388,7 +374,9 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, 0.5f * static_cast(std::exp(box_centersize.w / scale_values.w)) * anchor.w; - auto& box = reinterpret_cast(op_data->decoded_boxes)[idx]; + float* decoded_boxes = reinterpret_cast( + context->GetScratchBuffer(context, op_data->decoded_boxes_idx)); + auto& box = reinterpret_cast(decoded_boxes)[idx]; box.ymin = ycenter - half_h; box.xmin = xcenter - half_w; box.ymax = ycenter + half_h; @@ -489,15 +477,20 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) && (intersection_over_union_threshold <= 1.0f)); // Validate boxes - TF_LITE_ENSURE(context, ValidateBoxes(op_data->decoded_boxes, num_boxes)); + float* decoded_boxes = reinterpret_cast( + context->GetScratchBuffer(context, op_data->decoded_boxes_idx)); + + TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes)); // threshold scores - int* keep_indices = op_data->keep_indices; - float* keep_scores = op_data->keep_scores; + int* keep_indices = reinterpret_cast( + context->GetScratchBuffer(context, op_data->keep_indices_idx)); + float* keep_scores = reinterpret_cast( + context->GetScratchBuffer(context, op_data->keep_scores_idx)); int num_scores_kept = SelectDetectionsAboveScoreThreshold( scores, num_boxes, non_max_suppression_score_threshold, keep_scores, keep_indices); - - int* sorted_indices = op_data->sorted_indices; + int* sorted_indices = 
reinterpret_cast( + context->GetScratchBuffer(context, op_data->sorted_indices_idx)); DecreasingPartialArgSort(keep_scores, num_scores_kept, num_scores_kept, sorted_indices); @@ -507,7 +500,8 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( *selected_size = 0; int num_active_candidate = num_boxes_kept; - uint8_t* active_box_candidate = op_data->active_box_candidate; + uint8_t* active_box_candidate = reinterpret_cast( + context->GetScratchBuffer(context, op_data->active_candidate_idx)); for (int row = 0; row < num_boxes_kept; row++) { active_box_candidate[row] = 1; @@ -524,7 +518,7 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( for (int j = i + 1; j < num_boxes_kept; ++j) { if (active_box_candidate[j] == 1) { float intersection_over_union = ComputeIntersectionOverUnion( - op_data->decoded_boxes, keep_indices[sorted_indices[i]], + decoded_boxes, keep_indices[sorted_indices[i]], keep_indices[sorted_indices[j]]); if (intersection_over_union > intersection_over_union_threshold) { @@ -573,14 +567,18 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, TF_LITE_ENSURE(context, num_detections_per_class > 0); // For each class, perform non-max suppression. 
- float* class_scores = op_data->score_buffer; - int* box_indices_after_regular_non_max_suppression = op_data->buffer; - float* scores_after_regular_non_max_suppression = - op_data->scores_after_regular_non_max_suppression; + float* class_scores = reinterpret_cast( + context->GetScratchBuffer(context, op_data->score_buffer_idx)); + int* box_indices_after_regular_non_max_suppression = reinterpret_cast( + context->GetScratchBuffer(context, op_data->buffer_idx)); + float* scores_after_regular_non_max_suppression = reinterpret_cast( + context->GetScratchBuffer(context, op_data->scores_after_regular_non_max_suppression_idx)); int size_of_sorted_indices = 0; - int* sorted_indices = op_data->sorted_indices; - float* sorted_values = op_data->sorted_values; + int* sorted_indices = reinterpret_cast( + context->GetScratchBuffer(context, op_data->sorted_indices_idx)); + float* sorted_values = reinterpret_cast( + context->GetScratchBuffer(context, op_data->sorted_values_idx)); for (int col = 0; col < num_classes; col++) { for (int row = 0; row < num_boxes; row++) { @@ -590,7 +588,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, } // Perform non-maximal suppression on single class int selected_size = 0; - int* selected = op_data->selected; + int* selected = reinterpret_cast( + context->GetScratchBuffer(context, op_data->selected_idx)); TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper( context, node, op_data, class_scores, selected, &selected_size, num_detections_per_class)); @@ -639,8 +638,10 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, const float selected_score = scores_after_regular_non_max_suppression[output_box_index]; // detection_boxes + float* decoded_boxes = reinterpret_cast( + context->GetScratchBuffer(context, op_data->decoded_boxes_idx)); ReInterpretTensor(detection_boxes)[output_box_index] = - reinterpret_cast(op_data->decoded_boxes)[anchor_index]; + 
reinterpret_cast(decoded_boxes)[anchor_index]; // detection_classes tflite::micro::GetTensorData(detection_classes)[output_box_index] = class_index; // detection_scores @@ -695,8 +696,11 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, TF_LITE_ENSURE(context, (max_categories_per_anchor > 0)); const int num_categories_per_anchor = std::min(max_categories_per_anchor, num_classes); - float* max_scores = op_data->score_buffer; - int* sorted_class_indices = op_data->buffer; + float* max_scores = reinterpret_cast( + context->GetScratchBuffer(context, op_data->score_buffer_idx)); + int* sorted_class_indices = reinterpret_cast( + context->GetScratchBuffer(context, op_data->buffer_idx)); + for (int row = 0; row < num_boxes; row++) { const float* box_scores = scores + row * num_classes_with_background + label_offset; @@ -708,7 +712,8 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, // Perform non-maximal suppression on max scores int selected_size = 0; - int* selected = op_data->selected; + int* selected = reinterpret_cast( + context->GetScratchBuffer(context, op_data->selected_idx)); TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper( context, node, op_data, max_scores, selected, &selected_size, op_data->max_detections)); @@ -727,8 +732,10 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, int box_offset = num_categories_per_anchor * output_box_index + col; // detection_boxes + float* decoded_boxes = reinterpret_cast( + context->GetScratchBuffer(context, op_data->decoded_boxes_idx)); ReInterpretTensor(detection_boxes)[box_offset] = - reinterpret_cast(op_data->decoded_boxes)[selected_index]; + reinterpret_cast(decoded_boxes)[selected_index]; // detection_classes tflite::micro::GetTensorData(detection_classes)[box_offset] = class_indices[col]; @@ -782,7 +789,8 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, const float* scores; switch (input_class_predictions->type) 
{ case kTfLiteUInt8: { - float* temporary_scores = op_data->scores; + float* temporary_scores = reinterpret_cast( + context->GetScratchBuffer(context, op_data->scores_idx)); DequantizeClassPredictions(input_class_predictions, num_boxes, num_classes_with_background, temporary_scores, op_data); @@ -806,32 +814,7 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, (kBatchSize == 1)); - - // Set up scratch buffers - void *raw; auto* op_data = static_cast(node->user_data); - raw = context->GetScratchBuffer(context, op_data->active_candidate_idx); - op_data->active_box_candidate = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->decoded_boxes_idx); - op_data->decoded_boxes = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->scores_idx); - op_data->scores = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->score_buffer_idx); - op_data->score_buffer = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->keep_scores_idx); - op_data->keep_scores = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->scores_after_regular_non_max_suppression_idx); - op_data->scores_after_regular_non_max_suppression = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->sorted_values_idx); - op_data->sorted_values = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->keep_indices_idx); - op_data->keep_indices = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->sorted_indices_idx); - op_data->sorted_indices = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->buffer_idx); - op_data->buffer = reinterpret_cast(raw); - raw = context->GetScratchBuffer(context, op_data->selected_idx); - op_data->selected = reinterpret_cast(raw); // These two functions correspond to two blocks in the Object 
Detection model. // In future, we would like to break the custom op in two blocks, which is