Merge pull request #40274 from mansnils:detection_postprocess
PiperOrigin-RevId: 333329643 Change-Id: I2c30b99dcc738970ff37c6fa08832d7cd93980e8
This commit is contained in:
commit
87b784f413
@ -36,6 +36,7 @@ cc_library(
|
||||
"comparisons.cc",
|
||||
"concatenation.cc",
|
||||
"dequantize.cc",
|
||||
"detection_postprocess.cc",
|
||||
"elementwise.cc",
|
||||
"ethosu.cc",
|
||||
"floor.cc",
|
||||
@ -115,6 +116,7 @@ cc_library(
|
||||
"//tensorflow/lite/kernels/internal:types",
|
||||
"//tensorflow/lite/micro:memory_helpers",
|
||||
"//tensorflow/lite/micro:micro_utils",
|
||||
"@flatbuffers",
|
||||
] + select({
|
||||
"//conditions:default": [],
|
||||
":xtensa_hifimini": [
|
||||
@ -166,6 +168,20 @@ tflite_micro_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Test target built from detection_postprocess_test.cc.
tflite_micro_cc_test(
    name = "detection_postprocess_test",
    srcs = ["detection_postprocess_test.cc"],
    deps = [
        ":kernel_runner",
        ":micro_ops",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/micro/testing:micro_test",
        "@flatbuffers",
    ],
)
|
||||
|
||||
tflite_micro_cc_test(
|
||||
name = "fully_connected_test",
|
||||
srcs = [
|
||||
|
868
tensorflow/lite/micro/kernels/detection_postprocess.cc
Normal file
868
tensorflow/lite/micro/kernels/detection_postprocess.cc
Normal file
@ -0,0 +1,868 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
#include <cmath>
#include <numeric>

#define FLATBUFFERS_LOCALE_INDEPENDENT 0
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace detection_postprocess {
|
||||
|
||||
/**
 * This version of detection_postprocess is specific to TFLite Micro. It
 * differs from the TFLite version in the following ways:
 *
 * 1.) Temporaries (temporary tensors) - Micro uses the scratch buffer API
 * instead.
 * 2.) Output dimensions - the TFLite version determines the output size
 * and resizes the output tensor. The Micro runtime does not support tensor
 * resizing. However, if the output dimensions are undefined, the TFLu memory
 * API is used to allocate the new dimensions.
 */
|
||||
|
||||
// Positions of the operator's input tensors.
constexpr int kInputTensorBoxEncodings = 0;
constexpr int kInputTensorClassPredictions = 1;
constexpr int kInputTensorAnchors = 2;

// Positions of the operator's output tensors.
constexpr int kOutputTensorDetectionBoxes = 0;
constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;

// Number of float coordinates that make up one box.
constexpr int kNumCoordBox = 4;
// Batch size the box encodings input is checked against.
constexpr int kBatchSize = 1;

// Value used for detections_per_class when the custom options omit it.
constexpr int kNumDetectionsPerClass = 100;
|
||||
|
||||
// Object Detection model produces axis-aligned boxes in two formats:
// BoxCorner represents the lower left corner (xmin, ymin) and
// the upper right corner (xmax, ymax).
// CenterSize represents the center (xcenter, ycenter), height and width.
// BoxCornerEncoding and CenterSizeEncoding are related as follows:
// ycenter = y / y_scale * anchor.h + anchor.y;
// xcenter = x / x_scale * anchor.w + anchor.x;
// half_h = 0.5 * exp(h / h_scale) * anchor.h;
// half_w = 0.5 * exp(w / w_scale) * anchor.w;
// ymin = ycenter - half_h
// ymax = ycenter + half_h
// xmin = xcenter - half_w
// xmax = xcenter + half_w
|
||||
// Axis-aligned box described by two opposite corners.
//
// The member order (ymin, xmin, ymax, xmax) is the in-memory layout that raw
// float buffers of decoded boxes are reinterpreted as, so it must not change.
struct BoxCornerEncoding {
  float ymin;  // Lower bound along y.
  float xmin;  // Lower bound along x.
  float ymax;  // Upper bound along y.
  float xmax;  // Upper bound along x.
};
|
||||
|
||||
// Axis-aligned box described by its center and its extent.
//
// The member order (y, x, h, w) is the in-memory layout that raw float
// buffers of box encodings and anchors are reinterpreted as, so it must not
// change. The same struct also carries the four scale options.
struct CenterSizeEncoding {
  float y;  // Center along y.
  float x;  // Center along x.
  float h;  // Height.
  float w;  // Width.
};
|
||||
// We make sure that the memory allocations are contiguous with static assert.
|
||||
static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
|
||||
"Size of BoxCornerEncoding is 4 float values");
|
||||
static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
|
||||
"Size of CenterSizeEncoding is 4 float values");
|
||||
|
||||
struct OpData {
|
||||
int max_detections;
|
||||
int max_classes_per_detection; // Fast Non-Max-Suppression
|
||||
int detections_per_class; // Regular Non-Max-Suppression
|
||||
float non_max_suppression_score_threshold;
|
||||
float intersection_over_union_threshold;
|
||||
int num_classes;
|
||||
bool use_regular_non_max_suppression;
|
||||
CenterSizeEncoding scale_values;
|
||||
|
||||
// Scratch buffers indexes
|
||||
int active_candidate_idx;
|
||||
int decoded_boxes_idx;
|
||||
int scores_idx;
|
||||
int score_buffer_idx;
|
||||
int keep_scores_idx;
|
||||
int scores_after_regular_non_max_suppression_idx;
|
||||
int sorted_values_idx;
|
||||
int keep_indices_idx;
|
||||
int sorted_indices_idx;
|
||||
int buffer_idx;
|
||||
int selected_idx;
|
||||
|
||||
// Cached tensor scale and zero point values for quantized operations
|
||||
TfLiteQuantizationParams input_box_encodings;
|
||||
TfLiteQuantizationParams input_class_predictions;
|
||||
TfLiteQuantizationParams input_anchors;
|
||||
};
|
||||
|
||||
// Allocates a persistent TfLiteIntArray describing an output shape of one to
// three dimensions and stores it in *dims.
//
// x is always the first extent; y and z are appended, in that order, only when
// they are positive. Returns kTfLiteError (leaving *dims untouched) when the
// persistent allocation fails, kTfLiteOk otherwise.
TfLiteStatus AllocateOutDimensions(TfLiteContext* context,
                                   TfLiteIntArray** dims, int x, int y = 0,
                                   int z = 0) {
  // Gather the extents that were actually supplied. TfLiteIntArray::size is
  // the number of dimensions, not the number of elements in the tensor, so
  // the array must be sized by the extent count rather than by x * y * z.
  int extents[3];
  int rank = 0;
  extents[rank++] = x;
  if (y > 0) {
    extents[rank++] = y;
  }
  if (z > 0) {
    extents[rank++] = z;
  }

  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  TfLiteIntArray* shape =
      reinterpret_cast<TfLiteIntArray*>(context->AllocatePersistentBuffer(
          context, TfLiteIntArrayGetSizeInBytes(rank)));
  TF_LITE_ENSURE(context, shape != nullptr);

  shape->size = rank;
  for (int i = 0; i < rank; ++i) {
    shape->data[i] = extents[i];
  }
  *dims = shape;

  return kTfLiteOk;
}
|
||||
|
||||
// Parses the flexbuffer-encoded custom options in `buffer` (of `length`
// bytes) into an OpData placed in persistent memory and returns it.
//
// "detections_per_class" falls back to kNumDetectionsPerClass and
// "use_regular_nms" falls back to false when the option is absent.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const flexbuffers::Map& options =
      flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
          .AsMap();

  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  auto* data = reinterpret_cast<OpData*>(
      context->AllocatePersistentBuffer(context, sizeof(OpData)));

  // Detection limits.
  data->max_detections = options["max_detections"].AsInt32();
  data->max_classes_per_detection =
      options["max_classes_per_detection"].AsInt32();

  // Optional settings with defaults.
  data->detections_per_class =
      options["detections_per_class"].IsNull()
          ? kNumDetectionsPerClass
          : options["detections_per_class"].AsInt32();
  data->use_regular_non_max_suppression =
      options["use_regular_nms"].IsNull()
          ? false
          : options["use_regular_nms"].AsBool();

  // Thresholds and class count.
  data->non_max_suppression_score_threshold =
      options["nms_score_threshold"].AsFloat();
  data->intersection_over_union_threshold =
      options["nms_iou_threshold"].AsFloat();
  data->num_classes = options["num_classes"].AsInt32();

  // Scale factors applied when decoding center-size boxes.
  CenterSizeEncoding& scales = data->scale_values;
  scales.y = options["y_scale"].AsFloat();
  scales.x = options["x_scale"].AsFloat();
  scales.h = options["h_scale"].AsFloat();
  scales.w = options["w_scale"].AsFloat();

  return data;
}
|
||||
|
||||
// Validates the node's tensors, caches the input quantization parameters,
// requests every scratch buffer the kernel uses, and allocates output
// dimensions for outputs whose shape is still undefined.
//
// Returns kTfLiteOk on success; any failed check, scratch buffer request or
// dimension allocation is propagated as an error status.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  auto* op_data = static_cast<OpData*>(node->user_data);

  // Inputs: box_encodings, scores, anchors
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* input_class_predictions =
      GetInput(context, node, kInputTensorClassPredictions);
  const TfLiteTensor* input_anchors =
      GetInput(context, node, kInputTensorAnchors);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);

  // Outputs: detection_boxes, detection_classes, detection_scores,
  // num_detections
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;

  // Cache the quantization parameters for use during evaluation.
  op_data->input_box_encodings.scale = input_box_encodings->params.scale;
  op_data->input_box_encodings.zero_point =
      input_box_encodings->params.zero_point;
  op_data->input_class_predictions.scale =
      input_class_predictions->params.scale;
  op_data->input_class_predictions.zero_point =
      input_class_predictions->params.zero_point;
  op_data->input_anchors.scale = input_anchors->params.scale;
  op_data->input_anchors.zero_point = input_anchors->params.zero_point;

  // Scratch tensors
  // One byte per box: active/suppressed flag during non-max suppression.
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, num_boxes, &op_data->active_candidate_idx));
  // kNumCoordBox floats per box: the decoded corner coordinates.
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, num_boxes * kNumCoordBox * sizeof(float),
      &op_data->decoded_boxes_idx));
  // One float per element of the class predictions input.
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context,
      input_class_predictions->dims->data[1] *
          input_class_predictions->dims->data[2] * sizeof(float),
      &op_data->scores_idx));

  // Additional buffers
  // One float per box. This must be stored in score_buffer_idx: writing it to
  // scores_idx would overwrite the handle requested above and leave
  // score_buffer_idx, which the NMS helpers read, uninitialized.
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, num_boxes * sizeof(float), &op_data->score_buffer_idx));
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, num_boxes * sizeof(float), &op_data->keep_scores_idx));
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(float),
      &op_data->scores_after_regular_non_max_suppression_idx));
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(float),
      &op_data->sorted_values_idx));
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, num_boxes * sizeof(int), &op_data->keep_indices_idx));
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(int),
      &op_data->sorted_indices_idx));
  int buffer_size = std::max(num_classes, op_data->max_detections);
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, buffer_size * num_boxes * sizeof(int), &op_data->buffer_idx));
  buffer_size = std::min(num_boxes, op_data->max_detections);
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, buffer_size * num_boxes * sizeof(int),
      &op_data->selected_idx));

  // number of detected boxes
  const int num_detected_boxes =
      op_data->max_detections * op_data->max_classes_per_detection;

  // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
  TfLiteTensor* detection_boxes =
      GetOutput(context, node, kOutputTensorDetectionBoxes);
  if (detection_boxes->dims->size == 0) {
    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(context, &detection_boxes->dims,
                                                1, num_detected_boxes, 4));
  }

  // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
  TfLiteTensor* detection_classes =
      GetOutput(context, node, kOutputTensorDetectionClasses);
  if (detection_classes->dims->size == 0) {
    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
        context, &detection_classes->dims, 1, num_detected_boxes));
  }

  // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
  TfLiteTensor* detection_scores =
      GetOutput(context, node, kOutputTensorDetectionScores);
  if (detection_scores->dims->size == 0) {
    TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
        context, &detection_scores->dims, 1, num_detected_boxes));
  }

  // Output Tensor num_detections: size is set to 1
  TfLiteTensor* num_detections =
      GetOutput(context, node, kOutputTensorNumDetections);
  if (num_detections->dims->size == 0) {
    TF_LITE_ENSURE_STATUS(
        AllocateOutDimensions(context, &num_detections->dims, 1));
  }

  return kTfLiteOk;
}
|
||||
|
||||
// Function object that maps a quantized uint8 value to float using
// (x - zero_point) * scale.
class Dequantizer {
 public:
  Dequantizer(int zero_point, float scale)
      : offset_(zero_point), multiplier_(scale) {}

  // Returns the real value represented by the quantized value `x`.
  float operator()(uint8_t x) {
    const float shifted = static_cast<float>(x) - offset_;
    return shifted * multiplier_;
  }

 private:
  int offset_;        // Zero point of the quantized tensor.
  float multiplier_;  // Scale of the quantized tensor.
};
|
||||
|
||||
void DequantizeBoxEncodings(const TfLiteEvalTensor* input_box_encodings,
|
||||
int idx, float quant_zero_point, float quant_scale,
|
||||
int length_box_encoding,
|
||||
CenterSizeEncoding* box_centersize) {
|
||||
const uint8_t* boxes =
|
||||
tflite::micro::GetTensorData<uint8_t>(input_box_encodings) +
|
||||
length_box_encoding * idx;
|
||||
Dequantizer dequantize(quant_zero_point, quant_scale);
|
||||
// See definition of the KeyPointBoxCoder at
|
||||
// https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
|
||||
// The first four elements are the box coordinates, which is the same as the
|
||||
// FastRnnBoxCoder at
|
||||
// https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
|
||||
box_centersize->y = dequantize(boxes[0]);
|
||||
box_centersize->x = dequantize(boxes[1]);
|
||||
box_centersize->h = dequantize(boxes[2]);
|
||||
box_centersize->w = dequantize(boxes[3]);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T ReInterpretTensor(const TfLiteEvalTensor* tensor) {
|
||||
const float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
|
||||
return reinterpret_cast<T>(tensor_base);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T ReInterpretTensor(TfLiteEvalTensor* tensor) {
|
||||
float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
|
||||
return reinterpret_cast<T>(tensor_base);
|
||||
}
|
||||
|
||||
// Decodes every box encoding against its anchor and stores the resulting
// (ymin, xmin, ymax, xmax) corners in the decoded_boxes scratch buffer.
//
// Both uint8 (dequantized with the cached quantization parameters) and
// float32 box encodings are handled. Returns kTfLiteError for any other
// tensor type or when the input shape checks fail.
TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
                                   OpData* op_data) {
  // Parse input tensor boxencodings
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
  const int num_boxes = input_box_encodings->dims->data[1];
  TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
  const TfLiteEvalTensor* input_anchors =
      tflite::micro::GetEvalInput(context, node, kInputTensorAnchors);

  // The destination buffer is the same for every box, so look it up once
  // instead of once per loop iteration.
  BoxCornerEncoding* decoded_boxes = reinterpret_cast<BoxCornerEncoding*>(
      context->GetScratchBuffer(context, op_data->decoded_boxes_idx));

  // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
  CenterSizeEncoding box_centersize;
  const CenterSizeEncoding scale_values = op_data->scale_values;
  CenterSizeEncoding anchor;
  for (int idx = 0; idx < num_boxes; ++idx) {
    switch (input_box_encodings->type) {
      // Quantized
      case kTfLiteUInt8:
        DequantizeBoxEncodings(
            input_box_encodings, idx,
            static_cast<float>(op_data->input_box_encodings.zero_point),
            static_cast<float>(op_data->input_box_encodings.scale),
            input_box_encodings->dims->data[2], &box_centersize);
        DequantizeBoxEncodings(
            input_anchors, idx,
            static_cast<float>(op_data->input_anchors.zero_point),
            static_cast<float>(op_data->input_anchors.scale), kNumCoordBox,
            &anchor);
        break;
      // Float
      case kTfLiteFloat32: {
        // Only the first four values of each encoding are the box
        // coordinates; an encoding may carry more values than that.
        const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
        const float* boxes = &(tflite::micro::GetTensorData<float>(
            input_box_encodings)[box_encoding_idx]);
        box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
        anchor =
            ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
        break;
      }
      default:
        // Unsupported type.
        return kTfLiteError;
    }

    const float ycenter =
        box_centersize.y / scale_values.y * anchor.h + anchor.y;
    const float xcenter =
        box_centersize.x / scale_values.x * anchor.w + anchor.x;
    const float half_h =
        0.5f * static_cast<float>(std::exp(box_centersize.h / scale_values.h)) *
        anchor.h;
    const float half_w =
        0.5f * static_cast<float>(std::exp(box_centersize.w / scale_values.w)) *
        anchor.w;

    BoxCornerEncoding& box = decoded_boxes[idx];
    box.ymin = ycenter - half_h;
    box.xmin = xcenter - half_w;
    box.ymax = ycenter + half_h;
    box.xmax = xcenter + half_w;
  }
  return kTfLiteOk;
}
|
||||
|
||||
// Fills `indices` with 0..num_values-1 and partially sorts it so that the
// first `num_to_sort` entries index the largest `values` in decreasing order.
// The order of the remaining entries is unspecified.
void DecreasingPartialArgSort(const float* values, int num_values,
                              int num_to_sort, int* indices) {
  int* const first = indices;
  int* const sorted_end = indices + num_to_sort;
  int* const last = indices + num_values;

  std::iota(first, last, 0);

  const auto higher_value_first = [values](const int lhs, const int rhs) {
    return values[lhs] > values[rhs];
  };
  std::partial_sort(first, sorted_end, last, higher_value_first);
}
|
||||
|
||||
// Runs the same decreasing partial arg-sort into two index arrays: first
// `ind`, then `indices`. Each array is filled with 0..num_values-1 and then
// partially sorted so its first `num_to_sort` entries index the largest
// `values` in decreasing order.
void DecreasingPartialArgSort2(const float* values, int num_values,
                               int num_to_sort, int* indices, int* ind) {
  const auto higher_value_first = [values](const int lhs, const int rhs) {
    return values[lhs] > values[rhs];
  };

  int* const destinations[] = {ind, indices};
  for (int* out : destinations) {
    std::iota(out, out + num_values, 0);
    std::partial_sort(out, out + num_to_sort, out + num_values,
                      higher_value_first);
  }
}
|
||||
|
||||
// Copies every value that is at least `threshold` into `keep_values` and
// records its position in `values` in the matching slot of `keep_indices`.
//
// Both outputs are filled densely from slot 0, so keep_indices[k] is the
// original position of keep_values[k]. Returns the number of values kept.
int SelectDetectionsAboveScoreThreshold(const float* values, int size,
                                        const float threshold,
                                        float* keep_values, int* keep_indices) {
  int counter = 0;
  for (int i = 0; i < size; i++) {
    if (values[i] >= threshold) {
      keep_values[counter] = values[i];
      // Index by `counter`, not `i`: callers look the original position up
      // through indices into the compacted keep_values array.
      keep_indices[counter] = i;
      counter++;
    }
  }
  return counter;
}
|
||||
|
||||
bool ValidateBoxes(const float* decoded_boxes, const int num_boxes) {
|
||||
for (int i = 0; i < num_boxes; ++i) {
|
||||
// ymax>=ymin, xmax>=xmin
|
||||
auto& box = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
|
||||
if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
float ComputeIntersectionOverUnion(const float* decoded_boxes, const int i,
|
||||
const int j) {
|
||||
auto& box_i = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
|
||||
auto& box_j = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[j];
|
||||
const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
|
||||
const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
|
||||
if (area_i <= 0 || area_j <= 0) return 0.0;
|
||||
const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
|
||||
const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
|
||||
const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
|
||||
const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
|
||||
const float intersection_area =
|
||||
std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
|
||||
std::max<float>(intersection_xmax - intersection_xmin, 0.0);
|
||||
return intersection_area / (area_i + area_j - intersection_area);
|
||||
}
|
||||
|
||||
// NonMaxSuppressionSingleClass() prunes out the box locations with high overlap
|
||||
// before selecting the highest scoring boxes (max_detections in number)
|
||||
// It assumes all boxes are good in beginning and sorts based on the scores.
|
||||
// If lower-scoring box has too much overlap with a higher-scoring box,
|
||||
// we get rid of the lower-scoring box.
|
||||
// Complexity is O(N^2) pairwise comparison between boxes
|
||||
TfLiteStatus NonMaxSuppressionSingleClassHelper(
|
||||
TfLiteContext* context, TfLiteNode* node, OpData* op_data,
|
||||
const float* scores, int* selected, int* selected_size,
|
||||
int max_detections) {
|
||||
const TfLiteEvalTensor* input_box_encodings =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
|
||||
const int num_boxes = input_box_encodings->dims->data[1];
|
||||
const float non_max_suppression_score_threshold =
|
||||
op_data->non_max_suppression_score_threshold;
|
||||
const float intersection_over_union_threshold =
|
||||
op_data->intersection_over_union_threshold;
|
||||
// Maximum detections should be positive.
|
||||
TF_LITE_ENSURE(context, (max_detections >= 0));
|
||||
// intersection_over_union_threshold should be positive
|
||||
// and should be less than 1.
|
||||
TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
|
||||
(intersection_over_union_threshold <= 1.0f));
|
||||
// Validate boxes
|
||||
float* decoded_boxes = reinterpret_cast<float*>(
|
||||
context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
|
||||
|
||||
TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
|
||||
|
||||
// threshold scores
|
||||
int* keep_indices = reinterpret_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->keep_indices_idx));
|
||||
float* keep_scores = reinterpret_cast<float*>(
|
||||
context->GetScratchBuffer(context, op_data->keep_scores_idx));
|
||||
int num_scores_kept = SelectDetectionsAboveScoreThreshold(
|
||||
scores, num_boxes, non_max_suppression_score_threshold, keep_scores,
|
||||
keep_indices);
|
||||
int* sorted_indices = reinterpret_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->sorted_indices_idx));
|
||||
|
||||
DecreasingPartialArgSort(keep_scores, num_scores_kept, num_scores_kept,
|
||||
sorted_indices);
|
||||
|
||||
const int num_boxes_kept = num_scores_kept;
|
||||
const int output_size = std::min(num_boxes_kept, max_detections);
|
||||
*selected_size = 0;
|
||||
|
||||
int num_active_candidate = num_boxes_kept;
|
||||
uint8_t* active_box_candidate = reinterpret_cast<uint8_t*>(
|
||||
context->GetScratchBuffer(context, op_data->active_candidate_idx));
|
||||
|
||||
for (int row = 0; row < num_boxes_kept; row++) {
|
||||
active_box_candidate[row] = 1;
|
||||
}
|
||||
for (int i = 0; i < num_boxes_kept; ++i) {
|
||||
if (num_active_candidate == 0 || *selected_size >= output_size) break;
|
||||
if (active_box_candidate[i] == 1) {
|
||||
selected[(*selected_size)++] = keep_indices[sorted_indices[i]];
|
||||
active_box_candidate[i] = 0;
|
||||
num_active_candidate--;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
for (int j = i + 1; j < num_boxes_kept; ++j) {
|
||||
if (active_box_candidate[j] == 1) {
|
||||
float intersection_over_union = ComputeIntersectionOverUnion(
|
||||
decoded_boxes, keep_indices[sorted_indices[i]],
|
||||
keep_indices[sorted_indices[j]]);
|
||||
|
||||
if (intersection_over_union > intersection_over_union_threshold) {
|
||||
active_box_candidate[j] = 0;
|
||||
num_active_candidate--;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// This function implements a regular version of Non Maximal Suppression (NMS)
|
||||
// for multiple classes where
|
||||
// 1) we do NMS separately for each class across all anchors and
|
||||
// 2) keep only the highest anchor scores across all classes
|
||||
// 3) The worst runtime of the regular NMS is O(K*N^2)
|
||||
// where N is the number of anchors and K the number of
|
||||
// classes.
|
||||
TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
OpData* op_data,
|
||||
const float* scores) {
|
||||
const TfLiteEvalTensor* input_box_encodings =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
|
||||
const TfLiteEvalTensor* input_class_predictions =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
|
||||
TfLiteEvalTensor* detection_boxes =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
|
||||
TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
|
||||
context, node, kOutputTensorDetectionClasses);
|
||||
TfLiteEvalTensor* detection_scores =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
|
||||
TfLiteEvalTensor* num_detections =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);
|
||||
|
||||
const int num_boxes = input_box_encodings->dims->data[1];
|
||||
const int num_classes = op_data->num_classes;
|
||||
const int num_detections_per_class = op_data->detections_per_class;
|
||||
const int max_detections = op_data->max_detections;
|
||||
const int num_classes_with_background =
|
||||
input_class_predictions->dims->data[2];
|
||||
// The row index offset is 1 if background class is included and 0 otherwise.
|
||||
int label_offset = num_classes_with_background - num_classes;
|
||||
TF_LITE_ENSURE(context, num_detections_per_class > 0);
|
||||
|
||||
// For each class, perform non-max suppression.
|
||||
float* class_scores = reinterpret_cast<float*>(
|
||||
context->GetScratchBuffer(context, op_data->score_buffer_idx));
|
||||
int* box_indices_after_regular_non_max_suppression = reinterpret_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->buffer_idx));
|
||||
float* scores_after_regular_non_max_suppression =
|
||||
reinterpret_cast<float*>(context->GetScratchBuffer(
|
||||
context, op_data->scores_after_regular_non_max_suppression_idx));
|
||||
|
||||
int size_of_sorted_indices = 0;
|
||||
int* sorted_indices = reinterpret_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->sorted_indices_idx));
|
||||
float* sorted_values = reinterpret_cast<float*>(
|
||||
context->GetScratchBuffer(context, op_data->sorted_values_idx));
|
||||
|
||||
for (int col = 0; col < num_classes; col++) {
|
||||
for (int row = 0; row < num_boxes; row++) {
|
||||
// Get scores of boxes corresponding to all anchors for single class
|
||||
class_scores[row] =
|
||||
*(scores + row * num_classes_with_background + col + label_offset);
|
||||
}
|
||||
// Perform non-maximal suppression on single class
|
||||
int selected_size = 0;
|
||||
int* selected = reinterpret_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->selected_idx));
|
||||
TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
|
||||
context, node, op_data, class_scores, selected, &selected_size,
|
||||
num_detections_per_class));
|
||||
// Add selected indices from non-max suppression of boxes in this class
|
||||
int output_index = size_of_sorted_indices;
|
||||
for (int i = 0; i < selected_size; i++) {
|
||||
int selected_index = selected[i];
|
||||
|
||||
box_indices_after_regular_non_max_suppression[output_index] =
|
||||
(selected_index * num_classes_with_background + col + label_offset);
|
||||
scores_after_regular_non_max_suppression[output_index] =
|
||||
class_scores[selected_index];
|
||||
output_index++;
|
||||
}
|
||||
// Sort the max scores among the selected indices
|
||||
// Get the indices for top scores
|
||||
int num_indices_to_sort = std::min(output_index, max_detections);
|
||||
DecreasingPartialArgSort(scores_after_regular_non_max_suppression,
|
||||
output_index, num_indices_to_sort, sorted_indices);
|
||||
|
||||
// Copy values to temporary vectors
|
||||
for (int row = 0; row < num_indices_to_sort; row++) {
|
||||
int temp = sorted_indices[row];
|
||||
sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
|
||||
sorted_values[row] = scores_after_regular_non_max_suppression[temp];
|
||||
}
|
||||
// Copy scores and indices from temporary vectors
|
||||
for (int row = 0; row < num_indices_to_sort; row++) {
|
||||
box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
|
||||
scores_after_regular_non_max_suppression[row] = sorted_values[row];
|
||||
}
|
||||
size_of_sorted_indices = num_indices_to_sort;
|
||||
}
|
||||
|
||||
// Allocate output tensors
|
||||
for (int output_box_index = 0; output_box_index < max_detections;
|
||||
output_box_index++) {
|
||||
if (output_box_index < size_of_sorted_indices) {
|
||||
const int anchor_index = floor(
|
||||
box_indices_after_regular_non_max_suppression[output_box_index] /
|
||||
num_classes_with_background);
|
||||
const int class_index =
|
||||
box_indices_after_regular_non_max_suppression[output_box_index] -
|
||||
anchor_index * num_classes_with_background - label_offset;
|
||||
const float selected_score =
|
||||
scores_after_regular_non_max_suppression[output_box_index];
|
||||
// detection_boxes
|
||||
float* decoded_boxes = reinterpret_cast<float*>(
|
||||
context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
|
||||
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
|
||||
reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[anchor_index];
|
||||
// detection_classes
|
||||
tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
|
||||
class_index;
|
||||
// detection_scores
|
||||
tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
|
||||
selected_score;
|
||||
} else {
|
||||
ReInterpretTensor<BoxCornerEncoding*>(
|
||||
detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
// detection_classes
|
||||
tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
|
||||
0.0f;
|
||||
// detection_scores
|
||||
tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
|
||||
0.0f;
|
||||
}
|
||||
}
|
||||
tflite::micro::GetTensorData<float>(num_detections)[0] =
|
||||
size_of_sorted_indices;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// This function implements a fast version of Non Maximal Suppression for
|
||||
// multiple classes where
|
||||
// 1) we keep the top-k scores for each anchor and
|
||||
// 2) during NMS, each anchor only uses the highest class score for sorting.
|
||||
// 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
|
||||
// instead of O(KN^2) where N is the number of anchors and K the number of
|
||||
// classes.
|
||||
TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
OpData* op_data,
|
||||
const float* scores) {
|
||||
const TfLiteEvalTensor* input_box_encodings =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
|
||||
const TfLiteEvalTensor* input_class_predictions =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
|
||||
TfLiteEvalTensor* detection_boxes =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
|
||||
|
||||
TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
|
||||
context, node, kOutputTensorDetectionClasses);
|
||||
TfLiteEvalTensor* detection_scores =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
|
||||
TfLiteEvalTensor* num_detections =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);
|
||||
|
||||
const int num_boxes = input_box_encodings->dims->data[1];
|
||||
const int num_classes = op_data->num_classes;
|
||||
const int max_categories_per_anchor = op_data->max_classes_per_detection;
|
||||
const int num_classes_with_background =
|
||||
input_class_predictions->dims->data[2];
|
||||
|
||||
// The row index offset is 1 if background class is included and 0 otherwise.
|
||||
int label_offset = num_classes_with_background - num_classes;
|
||||
TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
|
||||
const int num_categories_per_anchor =
|
||||
std::min(max_categories_per_anchor, num_classes);
|
||||
float* max_scores = reinterpret_cast<float*>(
|
||||
context->GetScratchBuffer(context, op_data->score_buffer_idx));
|
||||
int* sorted_class_indices = reinterpret_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->buffer_idx));
|
||||
|
||||
for (int row = 0; row < num_boxes; row++) {
|
||||
const float* box_scores =
|
||||
scores + row * num_classes_with_background + label_offset;
|
||||
int* class_indices = sorted_class_indices + row * num_classes;
|
||||
DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
|
||||
class_indices);
|
||||
max_scores[row] = box_scores[class_indices[0]];
|
||||
}
|
||||
|
||||
// Perform non-maximal suppression on max scores
|
||||
int selected_size = 0;
|
||||
int* selected = reinterpret_cast<int*>(
|
||||
context->GetScratchBuffer(context, op_data->selected_idx));
|
||||
TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
|
||||
context, node, op_data, max_scores, selected, &selected_size,
|
||||
op_data->max_detections));
|
||||
|
||||
// Allocate output tensors
|
||||
int output_box_index = 0;
|
||||
|
||||
for (int i = 0; i < selected_size; i++) {
|
||||
int selected_index = selected[i];
|
||||
|
||||
const float* box_scores =
|
||||
scores + selected_index * num_classes_with_background + label_offset;
|
||||
const int* class_indices =
|
||||
sorted_class_indices + selected_index * num_classes;
|
||||
|
||||
for (int col = 0; col < num_categories_per_anchor; ++col) {
|
||||
int box_offset = num_categories_per_anchor * output_box_index + col;
|
||||
|
||||
// detection_boxes
|
||||
float* decoded_boxes = reinterpret_cast<float*>(
|
||||
context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
|
||||
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
|
||||
reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[selected_index];
|
||||
|
||||
// detection_classes
|
||||
tflite::micro::GetTensorData<float>(detection_classes)[box_offset] =
|
||||
class_indices[col];
|
||||
|
||||
// detection_scores
|
||||
tflite::micro::GetTensorData<float>(detection_scores)[box_offset] =
|
||||
box_scores[class_indices[col]];
|
||||
|
||||
output_box_index++;
|
||||
}
|
||||
}
|
||||
|
||||
tflite::micro::GetTensorData<float>(num_detections)[0] = output_box_index;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void DequantizeClassPredictions(const TfLiteEvalTensor* input_class_predictions,
|
||||
const int num_boxes,
|
||||
const int num_classes_with_background,
|
||||
float* scores, OpData* op_data) {
|
||||
float quant_zero_point =
|
||||
static_cast<float>(op_data->input_class_predictions.zero_point);
|
||||
float quant_scale =
|
||||
static_cast<float>(op_data->input_class_predictions.scale);
|
||||
Dequantizer dequantize(quant_zero_point, quant_scale);
|
||||
const uint8_t* scores_quant =
|
||||
tflite::micro::GetTensorData<uint8_t>(input_class_predictions);
|
||||
for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
|
||||
scores[idx] = dequantize(scores_quant[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
// Validates the class prediction tensor, obtains float class scores
// (dequantizing uint8 input into a scratch buffer when needed) and dispatches
// to the regular or the fast multi-class NMS helper, depending on
// op_data->use_regular_non_max_suppression.
//
// Returns the status of the selected helper, or kTfLiteError when a shape
// check fails or the class predictions are neither uint8 nor float32.
TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
                                         TfLiteNode* node, OpData* op_data) {
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const TfLiteEvalTensor* input_class_predictions =
      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;

  // Class predictions are read as [batch, box, class (+ background)].
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
                    kBatchSize);
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  // At most one extra (background) column is allowed.
  TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
  TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));

  const float* scores = nullptr;
  const TfLiteType prediction_type = input_class_predictions->type;
  if (prediction_type == kTfLiteUInt8) {
    float* dequantized = reinterpret_cast<float*>(
        context->GetScratchBuffer(context, op_data->scores_idx));
    DequantizeClassPredictions(input_class_predictions, num_boxes,
                               num_classes_with_background, dequantized,
                               op_data);
    scores = dequantized;
  } else if (prediction_type == kTfLiteFloat32) {
    scores = tflite::micro::GetTensorData<float>(input_class_predictions);
  } else {
    // Unsupported type.
    return kTfLiteError;
  }

  if (op_data->use_regular_non_max_suppression) {
    return NonMaxSuppressionMultiClassRegularHelper(context, node, op_data,
                                                    scores);
  }
  return NonMaxSuppressionMultiClassFastHelper(context, node, op_data, scores);
}
|
||||
|
||||
// Kernel entry point: decodes the boxes, then runs multi-class NMS.
//
// The two stages correspond to two blocks of the Object Detection model. They
// are kept in a single custom op because the op takes quantized inputs but
// does all calculations in float, and mixed quantized/float calculations are
// currently not supported in TFLite.
//
// Returns kTfLiteOk on success, otherwise the status of the failing stage.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, (kBatchSize == 1));
  OpData* params = static_cast<OpData*>(node->user_data);

  // Stage 1: fill the temporary decoded boxes by transforming the input box
  // encodings and anchors from CenterSizeEncoding to BoxCornerEncoding.
  const TfLiteStatus decode_status =
      DecodeCenterSizeBoxes(context, node, params);
  if (decode_status != kTfLiteOk) {
    return decode_status;
  }

  // Stage 2: fill the output tensors with the highest scoring,
  // non-overlapping decoded boxes chosen by Non-Maximal Suppression.
  return NonMaxSuppressionMultiClass(context, node, params);
}
|
||||
|
||||
} // namespace detection_postprocess
|
||||
|
||||
// Builds the registration for the detection post-process op: Init, Prepare
// and Eval from the detection_postprocess namespace are wired in, every other
// field is left null / zero.
TfLiteRegistration Register_DETECTION_POSTPROCESS() {
  TfLiteRegistration registration = {};
  registration.init = detection_postprocess::Init;
  registration.free = nullptr;
  registration.prepare = detection_postprocess::Prepare;
  registration.invoke = detection_postprocess::Eval;
  registration.profiling_string = nullptr;
  registration.builtin_code = 0;
  registration.custom_name = nullptr;
  registration.version = 0;
  return registration;
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
496
tensorflow/lite/micro/kernels/detection_postprocess_test.cc
Normal file
496
tensorflow/lite/micro/kernels/detection_postprocess_test.cc
Normal file
@ -0,0 +1,496 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||
#include "tensorflow/lite/micro/kernels/micro_ops.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
// Fixtures shared by the test cases below. Every shape array is handed to
// IntArrayFromInts: its first entry is the number of dimensions that follow.

static const int kInputShape1[] = {3, 1, 6, 4};   // box encodings
static const int kInputShape2[] = {3, 1, 6, 3};   // class scores
static const int kInputShape3[] = {2, 6, 4};      // anchors
static const int kOutputShape1[] = {3, 1, 3, 4};  // compared with kGolden1
static const int kOutputShape2[] = {2, 1, 3};     // compared with kGolden2
static const int kOutputShape3[] = {2, 1, 3};     // compared with kGolden3
static const int kOutputShape4[] = {1, 1};        // compared with kGolden4

// Six boxes in center-size encoding.
static const float kInputData1[] = {
    0.0, 0.0,  0.0, 0.0,  // box #1
    0.0, 1.0,  0.0, 0.0,  // box #2
    0.0, -1.0, 0.0, 0.0,  // box #3
    0.0, 0.0,  0.0, 0.0,  // box #4
    0.0, 1.0,  0.0, 0.0,  // box #5
    0.0, 0.0,  0.0, 0.0   // box #6
};

// Class scores: background plus two classes, one row per box.
static const float kInputData2[] = {
    0., .9,  .8,   // box #1
    0., .75, .72,  // box #2
    0., .6,  .5,   // box #3
    0., .93, .95,  // box #4
    0., .5,  .4,   // box #5
    0., .3,  .2    // box #6
};

// Six anchors in center-size encoding.
static const float kInputData3[] = {
    0.5, 0.5,   1.0, 1.0,  // anchor #1
    0.5, 0.5,   1.0, 1.0,  // anchor #2
    0.5, 0.5,   1.0, 1.0,  // anchor #3
    0.5, 10.5,  1.0, 1.0,  // anchor #4
    0.5, 10.5,  1.0, 1.0,  // anchor #5
    0.5, 100.5, 1.0, 1.0   // anchor #6
};
// The same six boxes expressed in box-corner encoding:
// { 0.0,   0.0, 1.0,   1.0,
//   0.0,   0.1, 1.0,   1.1,
//   0.0,  -0.1, 1.0,   0.9,
//   0.0,  10.0, 1.0,  11.0,
//   0.0,  10.1, 1.0,  11.1,
//   0.0, 100.0, 1.0, 101.0}

// Expected outputs for the default (fast NMS) test cases.
static const float kGolden1[] = {
    0.0, 10.0,  1.0, 11.0,  // first detection
    0.0, 0.0,   1.0, 1.0,   // second detection
    0.0, 100.0, 1.0, 101.0  // third detection
};
static const float kGolden2[] = {1, 0, 0};
static const float kGolden3[] = {0.95, 0.9, 0.3};
static const float kGolden4[] = {3.0};
|
||||
|
||||
void TestDetectionPostprocess(
|
||||
const int* input_dims_data1, const float* input_data1,
|
||||
const int* input_dims_data2, const float* input_data2,
|
||||
const int* input_dims_data3, const float* input_data3,
|
||||
const int* output_dims_data1, float* output_data1,
|
||||
const int* output_dims_data2, float* output_data2,
|
||||
const int* output_dims_data3, float* output_data3,
|
||||
const int* output_dims_data4, float* output_data4, const float* golden1,
|
||||
const float* golden2, const float* golden3, const float* golden4,
|
||||
const float tolerance, bool use_regular_nms,
|
||||
uint8_t* input_data_quantized1 = nullptr,
|
||||
uint8_t* input_data_quantized2 = nullptr,
|
||||
uint8_t* input_data_quantized3 = nullptr, const float input_min1 = 0,
|
||||
const float input_max1 = 0, const float input_min2 = 0,
|
||||
const float input_max2 = 0, const float input_min3 = 0,
|
||||
const float input_max3 = 0) {
|
||||
TfLiteIntArray* input_dims1 = IntArrayFromInts(input_dims_data1);
|
||||
TfLiteIntArray* input_dims2 = IntArrayFromInts(input_dims_data2);
|
||||
TfLiteIntArray* input_dims3 = IntArrayFromInts(input_dims_data3);
|
||||
TfLiteIntArray* output_dims1 = nullptr;
|
||||
TfLiteIntArray* output_dims2 = nullptr;
|
||||
TfLiteIntArray* output_dims3 = nullptr;
|
||||
TfLiteIntArray* output_dims4 = nullptr;
|
||||
|
||||
const int zero_length_int_array_data[] = {0};
|
||||
TfLiteIntArray* zero_length_int_array =
|
||||
IntArrayFromInts(zero_length_int_array_data);
|
||||
|
||||
output_dims1 = output_dims_data1 == nullptr
|
||||
? const_cast<TfLiteIntArray*>(zero_length_int_array)
|
||||
: IntArrayFromInts(output_dims_data1);
|
||||
output_dims2 = output_dims_data2 == nullptr
|
||||
? const_cast<TfLiteIntArray*>(zero_length_int_array)
|
||||
: IntArrayFromInts(output_dims_data2);
|
||||
output_dims3 = output_dims_data3 == nullptr
|
||||
? const_cast<TfLiteIntArray*>(zero_length_int_array)
|
||||
: IntArrayFromInts(output_dims_data3);
|
||||
output_dims4 = output_dims_data4 == nullptr
|
||||
? const_cast<TfLiteIntArray*>(zero_length_int_array)
|
||||
: IntArrayFromInts(output_dims_data4);
|
||||
|
||||
constexpr int inputs_size = 3;
|
||||
constexpr int outputs_size = 4;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
|
||||
TfLiteTensor tensors[tensors_size];
|
||||
if (input_min1 != 0 || input_max1 != 0 || input_min2 != 0 ||
|
||||
input_max2 != 0 || input_min3 != 0 || input_max3 != 0) {
|
||||
const float input_scale1 = ScaleFromMinMax<uint8_t>(input_min1, input_max1);
|
||||
const int input_zero_point1 =
|
||||
ZeroPointFromMinMax<uint8_t>(input_min1, input_max1);
|
||||
const float input_scale2 = ScaleFromMinMax<uint8_t>(input_min2, input_max2);
|
||||
const int input_zero_point2 =
|
||||
ZeroPointFromMinMax<uint8_t>(input_min2, input_max2);
|
||||
const float input_scale3 = ScaleFromMinMax<uint8_t>(input_min3, input_max3);
|
||||
const int input_zero_point3 =
|
||||
ZeroPointFromMinMax<uint8_t>(input_min3, input_max3);
|
||||
|
||||
tensors[0] =
|
||||
CreateQuantizedTensor(input_data1, input_data_quantized1, input_dims1,
|
||||
input_scale1, input_zero_point1);
|
||||
tensors[1] =
|
||||
CreateQuantizedTensor(input_data2, input_data_quantized2, input_dims2,
|
||||
input_scale2, input_zero_point2);
|
||||
tensors[2] =
|
||||
CreateQuantizedTensor(input_data3, input_data_quantized3, input_dims3,
|
||||
input_scale3, input_zero_point3);
|
||||
} else {
|
||||
tensors[0] = CreateFloatTensor(input_data1, input_dims1);
|
||||
tensors[1] = CreateFloatTensor(input_data2, input_dims2);
|
||||
tensors[2] = CreateFloatTensor(input_data3, input_dims3);
|
||||
}
|
||||
tensors[3] = CreateFloatTensor(output_data1, output_dims1);
|
||||
tensors[4] = CreateFloatTensor(output_data2, output_dims2);
|
||||
tensors[5] = CreateFloatTensor(output_data3, output_dims3);
|
||||
tensors[6] = CreateFloatTensor(output_data4, output_dims4);
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
|
||||
flexbuffers::Builder fbb;
|
||||
fbb.Map([&]() {
|
||||
fbb.Int("max_detections", 3);
|
||||
fbb.Int("max_classes_per_detection", 1);
|
||||
fbb.Int("detections_per_class", 1);
|
||||
fbb.Bool("use_regular_nms", use_regular_nms);
|
||||
fbb.Float("nms_score_threshold", 0.0);
|
||||
fbb.Float("nms_iou_threshold", 0.5);
|
||||
fbb.Int("num_classes", 2);
|
||||
fbb.Float("y_scale", 10.0);
|
||||
fbb.Float("x_scale", 10.0);
|
||||
fbb.Float("h_scale", 5.0);
|
||||
fbb.Float("w_scale", 5.0);
|
||||
});
|
||||
fbb.Finish();
|
||||
|
||||
const TfLiteRegistration& registration =
|
||||
tflite::ops::micro::Register_DETECTION_POSTPROCESS();
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, ®istration);
|
||||
|
||||
int inputs_array_data[] = {3, 0, 1, 2};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {4, 3, 4, 5, 6};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||
outputs_array, nullptr, micro_test::reporter);
|
||||
|
||||
const char* init_data = reinterpret_cast<const char*>(fbb.GetBuffer().data());
|
||||
TF_LITE_MICRO_EXPECT_EQ(
|
||||
kTfLiteOk, runner.InitAndPrepare(init_data, fbb.GetBuffer().size()));
|
||||
|
||||
// Output dimensions should not be undefined after Prepare
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[3].dims);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[4].dims);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[5].dims);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, tensors[6].dims);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
const int output_elements_count1 = tensors[3].dims->size;
|
||||
const int output_elements_count2 = tensors[4].dims->size;
|
||||
const int output_elements_count3 = tensors[5].dims->size;
|
||||
const int output_elements_count4 = tensors[6].dims->size;
|
||||
|
||||
for (int i = 0; i < output_elements_count1; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(golden1[i], output_data1[i], tolerance);
|
||||
}
|
||||
for (int i = 0; i < output_elements_count2; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(golden2[i], output_data2[i], tolerance);
|
||||
}
|
||||
for (int i = 0; i < output_elements_count3; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(golden3[i], output_data3[i], tolerance);
|
||||
}
|
||||
for (int i = 0; i < output_elements_count4; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(golden4[i], output_data4[i], tolerance);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMS) {
  namespace fixture = tflite::testing;

  // Float inputs, fast NMS, shared fixtures and goldens, exact comparison.
  float boxes_out[12];
  float classes_out[3];
  float scores_out[3];
  float count_out[1];

  fixture::TestDetectionPostprocess(
      fixture::kInputShape1, fixture::kInputData1,  //
      fixture::kInputShape2, fixture::kInputData2,  //
      fixture::kInputShape3, fixture::kInputData3,  //
      fixture::kOutputShape1, boxes_out,            //
      fixture::kOutputShape2, classes_out,          //
      fixture::kOutputShape3, scores_out,           //
      fixture::kOutputShape4, count_out,            //
      fixture::kGolden1, fixture::kGolden2, fixture::kGolden3,
      fixture::kGolden4,
      /* tolerance */ 0, /* Use regular NMS: */ false);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedFastNMS) {
  namespace fixture = tflite::testing;

  // Same data as the float fast-NMS case, but fed through uint8 tensors; the
  // tolerance is widened accordingly.
  float boxes_out[12];
  float classes_out[3];
  float scores_out[3];
  float count_out[1];

  // Element counts of the three inputs, taken from the shared shapes.
  const int kInputElements1 = fixture::kInputShape1[1] *
                              fixture::kInputShape1[2] *
                              fixture::kInputShape1[3];
  const int kInputElements2 = fixture::kInputShape2[1] *
                              fixture::kInputShape2[2] *
                              fixture::kInputShape2[3];
  const int kInputElements3 =
      fixture::kInputShape3[1] * fixture::kInputShape3[2];

  // Backing storage for the quantized inputs.
  uint8_t quantized_boxes[kInputElements1 + 10];
  uint8_t quantized_scores[kInputElements2 + 10];
  uint8_t quantized_anchors[kInputElements3 + 10];

  fixture::TestDetectionPostprocess(
      fixture::kInputShape1, fixture::kInputData1,  //
      fixture::kInputShape2, fixture::kInputData2,  //
      fixture::kInputShape3, fixture::kInputData3,  //
      fixture::kOutputShape1, boxes_out,            //
      fixture::kOutputShape2, classes_out,          //
      fixture::kOutputShape3, scores_out,           //
      fixture::kOutputShape4, count_out,            //
      fixture::kGolden1, fixture::kGolden2, fixture::kGolden3,
      fixture::kGolden4,
      /* tolerance */ 3e-1, /* Use regular NMS: */ false, quantized_boxes,
      quantized_scores, quantized_anchors,
      /* input1 min/max*/ -1.0, 1.0, /* input2 min/max */ 0.0, 1.0,
      /* input3 min/max */ 0.0, 100.5);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(DetectionPostprocessFloatRegularNMS) {
  namespace fixture = tflite::testing;

  float boxes_out[12];
  float classes_out[3];
  float scores_out[3];
  float count_out[1];

  // Regular NMS goldens: two detections, the third slot is zero.
  const float kGolden1[] = {
      0.0, 10.0, 1.0, 11.0,  //
      0.0, 10.0, 1.0, 11.0,  //
      0.0, 0.0,  0.0, 0.0};
  const float kGolden3[] = {0.95, 0.9, 0.0};
  const float kGolden4[] = {2.0};

  fixture::TestDetectionPostprocess(
      fixture::kInputShape1, fixture::kInputData1,  //
      fixture::kInputShape2, fixture::kInputData2,  //
      fixture::kInputShape3, fixture::kInputData3,  //
      fixture::kOutputShape1, boxes_out,            //
      fixture::kOutputShape2, classes_out,          //
      fixture::kOutputShape3, scores_out,           //
      fixture::kOutputShape4, count_out,            //
      kGolden1, fixture::kGolden2, kGolden3, kGolden4,
      /* tolerance */ 1e-1, /* Use regular NMS: */ true);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(DetectionPostprocessQuantizedRegularNMS) {
  namespace fixture = tflite::testing;

  float boxes_out[12];
  float classes_out[3];
  float scores_out[3];
  float count_out[1];

  // Element counts of the three inputs, taken from the shared shapes.
  const int kInputElements1 = fixture::kInputShape1[1] *
                              fixture::kInputShape1[2] *
                              fixture::kInputShape1[3];
  const int kInputElements2 = fixture::kInputShape2[1] *
                              fixture::kInputShape2[2] *
                              fixture::kInputShape2[3];
  const int kInputElements3 =
      fixture::kInputShape3[1] * fixture::kInputShape3[2];

  // Backing storage for the quantized inputs.
  uint8_t quantized_boxes[kInputElements1 + 10];
  uint8_t quantized_scores[kInputElements2 + 10];
  uint8_t quantized_anchors[kInputElements3 + 10];

  // Regular NMS goldens: two detections, the third slot is zero.
  const float kGolden1[] = {
      0.0, 10.0, 1.0, 11.0,  //
      0.0, 10.0, 1.0, 11.0,  //
      0.0, 0.0,  0.0, 0.0};
  const float kGolden3[] = {0.95, 0.9, 0.0};
  const float kGolden4[] = {2.0};

  fixture::TestDetectionPostprocess(
      fixture::kInputShape1, fixture::kInputData1,  //
      fixture::kInputShape2, fixture::kInputData2,  //
      fixture::kInputShape3, fixture::kInputData3,  //
      fixture::kOutputShape1, boxes_out,            //
      fixture::kOutputShape2, classes_out,          //
      fixture::kOutputShape3, scores_out,           //
      fixture::kOutputShape4, count_out,            //
      kGolden1, fixture::kGolden2, kGolden3, kGolden4,
      /* tolerance */ 3e-1, /* Use regular NMS: */ true, quantized_boxes,
      quantized_scores, quantized_anchors,
      /* input1 min/max*/ -1.0, 1.0, /* input2 min/max */ 0.0, 1.0,
      /* input3 min/max */ 0.0, 100.5);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(
    DetectionPostprocessFloatFastNMSwithNoBackgroundClassAndKeypoints) {
  namespace fixture = tflite::testing;

  // Local inputs: five values per box instead of four, and two class scores
  // per box with no background column.
  const int kInputShape1[] = {3, 1, 6, 5};
  const int kInputShape2[] = {3, 1, 6, 2};

  // six boxes in center-size encoding, each followed by one extra value
  const float kInputData1[] = {
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #1
      0.0, 1.0,  0.0, 0.0, 1.0,  // box #2
      0.0, -1.0, 0.0, 0.0, 1.0,  // box #3
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #4
      0.0, 1.0,  0.0, 0.0, 1.0,  // box #5
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #6
  };

  // class scores - two classes without background
  const float kInputData2[] = {
      .9,  .8,   // box #1
      .75, .72,  // box #2
      .6,  .5,   // box #3
      .93, .95,  // box #4
      .5,  .4,   // box #5
      .3,  .2    // box #6
  };

  float boxes_out[12];
  float classes_out[3];
  float scores_out[3];
  float count_out[1];

  fixture::TestDetectionPostprocess(
      kInputShape1, kInputData1, kInputShape2, kInputData2,
      fixture::kInputShape3, fixture::kInputData3,  //
      fixture::kOutputShape1, boxes_out,            //
      fixture::kOutputShape2, classes_out,          //
      fixture::kOutputShape3, scores_out,           //
      fixture::kOutputShape4, count_out,            //
      fixture::kGolden1, fixture::kGolden2, fixture::kGolden3,
      fixture::kGolden4,
      /* tolerance */ 0, /* Use regular NMS: */ false);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(
    DetectionPostprocessFloatRegularNMSwithNoBackgroundClassAndKeypoints) {
  namespace fixture = tflite::testing;

  // Two class scores per box, no background column.
  const int kInputShape2[] = {3, 1, 6, 2};

  // class scores - two classes without background
  const float kInputData2[] = {
      .9,  .8,   // box #1
      .75, .72,  // box #2
      .6,  .5,   // box #3
      .93, .95,  // box #4
      .5,  .4,   // box #5
      .3,  .2    // box #6
  };

  // Regular NMS goldens: two detections, the third slot is zero.
  const float kGolden1[] = {
      0.0, 10.0, 1.0, 11.0,  //
      0.0, 10.0, 1.0, 11.0,  //
      0.0, 0.0,  0.0, 0.0};
  const float kGolden3[] = {0.95, 0.9, 0.0};
  const float kGolden4[] = {2.0};

  float boxes_out[12];
  float classes_out[3];
  float scores_out[3];
  float count_out[1];

  fixture::TestDetectionPostprocess(
      fixture::kInputShape1, fixture::kInputData1, kInputShape2, kInputData2,
      fixture::kInputShape3, fixture::kInputData3,  //
      fixture::kOutputShape1, boxes_out,            //
      fixture::kOutputShape2, classes_out,          //
      fixture::kOutputShape3, scores_out,           //
      fixture::kOutputShape4, count_out,            //
      kGolden1, fixture::kGolden2, kGolden3, kGolden4,
      /* tolerance */ 1e-1, /* Use regular NMS: */ true);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(
    DetectionPostprocessFloatFastNMSWithBackgroundClassAndKeypoints) {
  namespace fixture = tflite::testing;

  // Five values per box instead of four; class scores come from the shared
  // fixture, which includes the background column.
  const int kInputShape1[] = {3, 1, 6, 5};

  // six boxes in center-size encoding, each followed by one extra value
  const float kInputData1[] = {
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #1
      0.0, 1.0,  0.0, 0.0, 1.0,  // box #2
      0.0, -1.0, 0.0, 0.0, 1.0,  // box #3
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #4
      0.0, 1.0,  0.0, 0.0, 1.0,  // box #5
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #6
  };

  float boxes_out[12];
  float classes_out[3];
  float scores_out[3];
  float count_out[1];

  fixture::TestDetectionPostprocess(
      kInputShape1, kInputData1,                    //
      fixture::kInputShape2, fixture::kInputData2,  //
      fixture::kInputShape3, fixture::kInputData3,  //
      fixture::kOutputShape1, boxes_out,            //
      fixture::kOutputShape2, classes_out,          //
      fixture::kOutputShape3, scores_out,           //
      fixture::kOutputShape4, count_out,            //
      fixture::kGolden1, fixture::kGolden2, fixture::kGolden3,
      fixture::kGolden4,
      /* tolerance */ 0, /* Use regular NMS: */ false);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(
    DetectionPostprocessQuantizedFastNMSwithNoBackgroundClassAndKeypoints) {
  // Dimensions of the inputs that are local to this test: every box carries
  // one value on top of the four center-size values (presumably the keypoint
  // data the test name refers to), and there is no background class column.
  constexpr int kNumBoxes = 6;
  constexpr int kBoxEncodingSize = 5;
  constexpr int kNumClasses = 2;

  const int kInputShape1[] = {3, 1, kNumBoxes, kBoxEncodingSize};
  const int kInputShape2[] = {3, 1, kNumBoxes, kNumClasses};

  // six boxes in center-size encoding, each followed by one extra value
  const float kInputData1[] = {
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #1
      0.0, 1.0,  0.0, 0.0, 1.0,  // box #2
      0.0, -1.0, 0.0, 0.0, 1.0,  // box #3
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #4
      0.0, 1.0,  0.0, 0.0, 1.0,  // box #5
      0.0, 0.0,  0.0, 0.0, 1.0,  // box #6
  };

  // class scores - two classes without background
  const float kInputData2[] = {.9,  .8, .75, .72, .6, .5,
                               .93, .95, .5, .4,  .3, .2};

  // Size the quantized buffers from the shapes actually passed below. They
  // were previously derived from the shared tflite::testing shapes (6x4 and
  // 6x3), which only fit the 6x5 box input thanks to the extra slack.
  constexpr int kInputElements1 = kNumBoxes * kBoxEncodingSize;
  constexpr int kInputElements2 = kNumBoxes * kNumClasses;
  const int kInputElements3 =
      tflite::testing::kInputShape3[1] * tflite::testing::kInputShape3[2];

  uint8_t input_data_quantized1[kInputElements1 + 10];
  uint8_t input_data_quantized2[kInputElements2 + 10];
  uint8_t input_data_quantized3[kInputElements3 + 10];

  float output_data1[12];
  float output_data2[3];
  float output_data3[3];
  float output_data4[1];

  tflite::testing::TestDetectionPostprocess(
      kInputShape1, kInputData1, kInputShape2, kInputData2,
      tflite::testing::kInputShape3, tflite::testing::kInputData3,
      tflite::testing::kOutputShape1, output_data1,
      tflite::testing::kOutputShape2, output_data2,
      tflite::testing::kOutputShape3, output_data3,
      tflite::testing::kOutputShape4, output_data4, tflite::testing::kGolden1,
      tflite::testing::kGolden2, tflite::testing::kGolden3,
      tflite::testing::kGolden4,
      /* tolerance */ 3e-1, /* Use regular NMS: */ false, input_data_quantized1,
      input_data_quantized2, input_data_quantized3,
      /* input1 min/max*/ -1.0, 1.0, /* input2 min/max */ 0.0, 1.0,
      /* input3 min/max */ 0.0, 100.5);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMSUndefinedOutputDimensions) {
  namespace fixture = tflite::testing;

  float boxes_out[12];
  float classes_out[3];
  float scores_out[3];
  float count_out[1];

  // Every output shape is passed as nullptr, so the helper substitutes a
  // zero-length dims array for each output tensor.
  fixture::TestDetectionPostprocess(
      fixture::kInputShape1, fixture::kInputData1,  //
      fixture::kInputShape2, fixture::kInputData2,  //
      fixture::kInputShape3, fixture::kInputData3,  //
      nullptr, boxes_out,                           //
      nullptr, classes_out,                         //
      nullptr, scores_out,                          //
      nullptr, count_out,                           //
      fixture::kGolden1, fixture::kGolden2, fixture::kGolden3,
      fixture::kGolden4,
      /* tolerance */ 0, /* Use regular NMS: */ false);
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
@ -23,16 +23,16 @@ constexpr size_t kBufferAlignment = 16;
|
||||
} // namespace
|
||||
|
||||
// TODO(b/161841696): Consider moving away from global arena buffers:
|
||||
constexpr int KernelRunner::kNumScratchBuffers_;
|
||||
constexpr int KernelRunner::kKernelRunnerBufferSize_;
|
||||
uint8_t KernelRunner::kKernelRunnerBuffer_[];
|
||||
constexpr int KernelRunner::kNumScratchBuffers;
|
||||
constexpr int KernelRunner::kKernelRunnerBufferSize;
|
||||
uint8_t KernelRunner::kernel_runner_buffer_[];
|
||||
|
||||
KernelRunner::KernelRunner(const TfLiteRegistration& registration,
|
||||
TfLiteTensor* tensors, int tensors_size,
|
||||
TfLiteIntArray* inputs, TfLiteIntArray* outputs,
|
||||
void* builtin_data, ErrorReporter* error_reporter)
|
||||
: allocator_(SimpleMemoryAllocator::Create(
|
||||
error_reporter, kKernelRunnerBuffer_, kKernelRunnerBufferSize_)),
|
||||
error_reporter, kernel_runner_buffer_, kKernelRunnerBufferSize)),
|
||||
registration_(registration),
|
||||
tensors_(tensors),
|
||||
error_reporter_(error_reporter) {
|
||||
@ -52,9 +52,10 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
|
||||
node_.builtin_data = builtin_data;
|
||||
}
|
||||
|
||||
TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data) {
|
||||
TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
|
||||
size_t length) {
|
||||
if (registration_.init) {
|
||||
node_.user_data = registration_.init(&context_, init_data, /*length=*/0);
|
||||
node_.user_data = registration_.init(&context_, init_data, length);
|
||||
}
|
||||
if (registration_.prepare) {
|
||||
TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
|
||||
@ -117,11 +118,11 @@ TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
|
||||
KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
|
||||
TFLITE_DCHECK(runner != nullptr);
|
||||
|
||||
if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
|
||||
if (runner->scratch_buffer_count_ == kNumScratchBuffers) {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
runner->error_reporter_,
|
||||
"Exceeded the maximum number of scratch tensors allowed (%d).",
|
||||
kNumScratchBuffers_);
|
||||
kNumScratchBuffers);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@ -142,7 +143,7 @@ void* KernelRunner::GetScratchBuffer(TfLiteContext* context, int buffer_index) {
|
||||
KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
|
||||
TFLITE_DCHECK(runner != nullptr);
|
||||
|
||||
TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers_);
|
||||
TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers);
|
||||
if (buffer_index >= runner->scratch_buffer_count_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -39,7 +39,8 @@ class KernelRunner {
|
||||
// Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
|
||||
// exceptions will be reported through the error_reporter and returned as a
|
||||
// status code here.
|
||||
TfLiteStatus InitAndPrepare(const char* init_data = nullptr);
|
||||
TfLiteStatus InitAndPrepare(const char* init_data = nullptr,
|
||||
size_t length = 0);
|
||||
|
||||
// Calls init, prepare, and invoke on a given TfLiteRegistration pointer.
|
||||
// After successful invoke, results will be available in the output tensor as
|
||||
@ -60,10 +61,10 @@ class KernelRunner {
|
||||
...);
|
||||
|
||||
private:
|
||||
static constexpr int kNumScratchBuffers_ = 5;
|
||||
static constexpr int kNumScratchBuffers = 12;
|
||||
|
||||
static constexpr int kKernelRunnerBufferSize_ = 10000;
|
||||
static uint8_t kKernelRunnerBuffer_[kKernelRunnerBufferSize_];
|
||||
static constexpr int kKernelRunnerBufferSize = 10000;
|
||||
static uint8_t kernel_runner_buffer_[kKernelRunnerBufferSize];
|
||||
|
||||
SimpleMemoryAllocator* allocator_ = nullptr;
|
||||
const TfLiteRegistration& registration_;
|
||||
@ -74,7 +75,7 @@ class KernelRunner {
|
||||
TfLiteNode node_ = {};
|
||||
|
||||
int scratch_buffer_count_ = 0;
|
||||
uint8_t* scratch_buffers_[kNumScratchBuffers_];
|
||||
uint8_t* scratch_buffers_[kNumScratchBuffers];
|
||||
};
|
||||
|
||||
} // namespace micro
|
||||
|
@ -42,6 +42,7 @@ TfLiteRegistration Register_CONCATENATION();
|
||||
TfLiteRegistration Register_COS();
|
||||
TfLiteRegistration Register_DEPTHWISE_CONV_2D();
|
||||
TfLiteRegistration Register_DEQUANTIZE();
|
||||
TfLiteRegistration Register_DETECTION_POSTPROCESS();
|
||||
TfLiteRegistration Register_EQUAL();
|
||||
TfLiteRegistration Register_FLOOR();
|
||||
TfLiteRegistration Register_FULLY_CONNECTED();
|
||||
|
@ -884,7 +884,11 @@ TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims,
|
||||
TfLiteTensor result = CreateTensor(dims, is_variable);
|
||||
result.type = kTfLiteFloat32;
|
||||
result.data.f = const_cast<float*>(data);
|
||||
result.bytes = ElementCount(*dims) * sizeof(float);
|
||||
if (dims == nullptr) {
|
||||
result.bytes = 0;
|
||||
} else {
|
||||
result.bytes = ElementCount(*dims) * sizeof(float);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ constexpr size_t kBufferAlignment = 16;
|
||||
// We store the pointer to the ith scratch buffer to implement the Request/Get
|
||||
// ScratchBuffer API for the tests. scratch_buffers_[i] will be the ith scratch
|
||||
// buffer and will still be allocated from within raw_arena_.
|
||||
constexpr int kNumScratchBuffers = 5;
|
||||
constexpr int kNumScratchBuffers = 12;
|
||||
uint8_t* scratch_buffers_[kNumScratchBuffers];
|
||||
int scratch_buffer_count_ = 0;
|
||||
|
||||
|
@ -281,6 +281,8 @@ third_party/gemmlowp/LICENSE \
|
||||
third_party/flatbuffers/include/flatbuffers/base.h \
|
||||
third_party/flatbuffers/include/flatbuffers/stl_emulation.h \
|
||||
third_party/flatbuffers/include/flatbuffers/flatbuffers.h \
|
||||
third_party/flatbuffers/include/flatbuffers/flexbuffers.h \
|
||||
third_party/flatbuffers/include/flatbuffers/util.h \
|
||||
third_party/flatbuffers/LICENSE.txt \
|
||||
third_party/ruy/ruy/profiler/instrumentation.h
|
||||
|
||||
|
@ -65,7 +65,8 @@ ifeq ($(TARGET), bluepill)
|
||||
tensorflow/lite/micro/micro_allocator_test.cc \
|
||||
tensorflow/lite/micro/memory_helpers_test.cc \
|
||||
tensorflow/lite/micro/memory_arena_threshold_test.cc \
|
||||
tensorflow/lite/micro/kernels/circular_buffer_test.cc
|
||||
tensorflow/lite/micro/kernels/circular_buffer_test.cc \
|
||||
tensorflow/lite/micro/kernels/detection_postprocess_test.cc
|
||||
MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS))
|
||||
|
||||
EXCLUDED_EXAMPLE_TESTS := \
|
||||
|
@ -71,8 +71,8 @@ ifeq ($(TARGET), stm32f4)
|
||||
tensorflow/lite/micro/recording_micro_allocator_test.cc \
|
||||
tensorflow/lite/micro/kernels/circular_buffer_test.cc \
|
||||
tensorflow/lite/micro/kernels/conv_test.cc \
|
||||
tensorflow/lite/micro/kernels/fully_connected_test.cc
|
||||
|
||||
tensorflow/lite/micro/kernels/fully_connected_test.cc \
|
||||
tensorflow/lite/micro/kernels/detection_postprocess_test.cc
|
||||
MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS))
|
||||
|
||||
EXCLUDED_EXAMPLE_TESTS := \
|
||||
|
Loading…
x
Reference in New Issue
Block a user