TFLu: Remove support for unknown output dims for detecion_pp
This commit is contained in:
parent
e51cec1a3a
commit
171fbdaaed
@ -15,16 +15,7 @@ limitations under the License.
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#define FLATBUFFERS_LOCALE_INDEPENDENT 0
|
#define FLATBUFFERS_LOCALE_INDEPENDENT 0
|
||||||
#if !defined(__GNUC__) || defined(__CC_ARM) || defined(__clang__)
|
|
||||||
// TODO: remove this once this PR is merged and part of tensorflow downloads:
|
|
||||||
// https://github.com/google/flatbuffers/pull/6132
|
|
||||||
#pragma clang diagnostic push
|
|
||||||
#pragma clang diagnostic ignored "-Wdouble-promotion"
|
|
||||||
#include "flatbuffers/flexbuffers.h"
|
#include "flatbuffers/flexbuffers.h"
|
||||||
#pragma clang diagnostic pop
|
|
||||||
#else
|
|
||||||
#include "flatbuffers/flexbuffers.h"
|
|
||||||
#endif
|
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/kernels/internal/common.h"
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
@ -129,49 +120,8 @@ struct OpData {
|
|||||||
TfLiteQuantizationParams input_box_encodings;
|
TfLiteQuantizationParams input_box_encodings;
|
||||||
TfLiteQuantizationParams input_class_predictions;
|
TfLiteQuantizationParams input_class_predictions;
|
||||||
TfLiteQuantizationParams input_anchors;
|
TfLiteQuantizationParams input_anchors;
|
||||||
|
|
||||||
// In case out dimensions need to be allocated.
|
|
||||||
TfLiteIntArray* detection_boxes_dims;
|
|
||||||
TfLiteIntArray* detection_classes_dims;
|
|
||||||
TfLiteIntArray* detection_scores_dims;
|
|
||||||
TfLiteIntArray* num_detections_dims;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TfLiteStatus AllocateOutDimensions(TfLiteContext* context,
|
|
||||||
TfLiteTensor* tensor,
|
|
||||||
TfLiteIntArray** dims, int x, int y = 0,
|
|
||||||
int z = 0) {
|
|
||||||
int size = 1;
|
|
||||||
int size_dim = 1;
|
|
||||||
size = size * x;
|
|
||||||
|
|
||||||
if (y > 0) {
|
|
||||||
size = size * y;
|
|
||||||
size_dim++;
|
|
||||||
if (z > 0) {
|
|
||||||
size = size * z;
|
|
||||||
size_dim++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
*dims = reinterpret_cast<TfLiteIntArray*>(context->AllocatePersistentBuffer(
|
|
||||||
context, TfLiteIntArrayGetSizeInBytes(size)));
|
|
||||||
|
|
||||||
(*dims)->size = size_dim;
|
|
||||||
(*dims)->data[0] = x;
|
|
||||||
if (y > 0) {
|
|
||||||
(*dims)->data[1] = y;
|
|
||||||
}
|
|
||||||
if (z > 0) {
|
|
||||||
(*dims)->data[2] = z;
|
|
||||||
}
|
|
||||||
|
|
||||||
TFLITE_DCHECK(tensor->type == kTfLiteFloat32);
|
|
||||||
tensor->bytes = size * sizeof(float);
|
|
||||||
|
|
||||||
return kTfLiteOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
OpData* op_data = nullptr;
|
OpData* op_data = nullptr;
|
||||||
|
|
||||||
@ -271,54 +221,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
context->RequestScratchBufferInArena(
|
context->RequestScratchBufferInArena(
|
||||||
context, buffer_size * num_boxes * sizeof(int), &op_data->selected_idx);
|
context, buffer_size * num_boxes * sizeof(int), &op_data->selected_idx);
|
||||||
|
|
||||||
// number of detected boxes
|
|
||||||
const int num_detected_boxes =
|
|
||||||
op_data->max_detections * op_data->max_classes_per_detection;
|
|
||||||
|
|
||||||
// Outputs: detection_boxes, detection_scores, detection_classes,
|
// Outputs: detection_boxes, detection_scores, detection_classes,
|
||||||
// num_detections
|
// num_detections
|
||||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
|
||||||
|
|
||||||
// Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
|
|
||||||
TfLiteTensor* detection_boxes =
|
|
||||||
GetOutput(context, node, kOutputTensorDetectionBoxes);
|
|
||||||
if (detection_boxes->dims->size == 0) {
|
|
||||||
TF_LITE_ENSURE_STATUS(AllocateOutDimensions(context, detection_boxes,
|
|
||||||
&detection_boxes->dims,
|
|
||||||
1, num_detected_boxes, 4));
|
|
||||||
op_data->detection_boxes_dims = detection_boxes->dims;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output Tensor detection_classes: size is set to (1, num_detected_boxes)
|
|
||||||
TfLiteTensor* detection_classes =
|
|
||||||
GetOutput(context, node, kOutputTensorDetectionClasses);
|
|
||||||
if (detection_classes->dims->size == 0) {
|
|
||||||
TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
|
|
||||||
context, detection_classes,
|
|
||||||
&detection_classes->dims, 1, num_detected_boxes));
|
|
||||||
op_data->detection_classes_dims = detection_classes->dims;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output Tensor detection_scores: size is set to (1, num_detected_boxes)
|
|
||||||
TfLiteTensor* detection_scores =
|
|
||||||
GetOutput(context, node, kOutputTensorDetectionScores);
|
|
||||||
if (detection_scores->dims->size == 0) {
|
|
||||||
TF_LITE_ENSURE_STATUS(AllocateOutDimensions(
|
|
||||||
context, detection_scores,
|
|
||||||
&detection_scores->dims, 1, num_detected_boxes));
|
|
||||||
op_data->detection_scores_dims = detection_scores->dims;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output Tensor num_detections: size is set to 1
|
|
||||||
TfLiteTensor* num_detections =
|
|
||||||
GetOutput(context, node, kOutputTensorNumDetections);
|
|
||||||
if (num_detections->dims->size == 0) {
|
|
||||||
TF_LITE_ENSURE_STATUS(
|
|
||||||
AllocateOutDimensions(context, num_detections,
|
|
||||||
&num_detections->dims, 1));
|
|
||||||
op_data->num_detections_dims = num_detections->dims;
|
|
||||||
}
|
|
||||||
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -870,28 +776,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE(context, (kBatchSize == 1));
|
TF_LITE_ENSURE(context, (kBatchSize == 1));
|
||||||
auto* op_data = static_cast<OpData*>(node->user_data);
|
auto* op_data = static_cast<OpData*>(node->user_data);
|
||||||
|
|
||||||
TfLiteEvalTensor* detection_boxes =
|
|
||||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
|
|
||||||
if (detection_boxes->dims->size == 0) {
|
|
||||||
detection_boxes->dims = op_data->detection_boxes_dims;
|
|
||||||
}
|
|
||||||
TfLiteEvalTensor* detection_classes =
|
|
||||||
tflite::micro::GetEvalOutput(context, node,
|
|
||||||
kOutputTensorDetectionClasses);
|
|
||||||
if (detection_classes->dims->size == 0) {
|
|
||||||
detection_classes->dims = op_data->detection_classes_dims;
|
|
||||||
}
|
|
||||||
TfLiteEvalTensor* detection_scores =
|
|
||||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
|
|
||||||
if (detection_scores->dims->size == 0) {
|
|
||||||
detection_scores->dims = op_data->detection_scores_dims;
|
|
||||||
}
|
|
||||||
TfLiteEvalTensor* num_detections =
|
|
||||||
tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);
|
|
||||||
if (num_detections->dims->size == 0) {
|
|
||||||
num_detections->dims = op_data->num_detections_dims;
|
|
||||||
}
|
|
||||||
|
|
||||||
// These two functions correspond to two blocks in the Object Detection model.
|
// These two functions correspond to two blocks in the Object Detection model.
|
||||||
// In future, we would like to break the custom op in two blocks, which is
|
// In future, we would like to break the custom op in two blocks, which is
|
||||||
// currently not feasible because we would like to input quantized inputs
|
// currently not feasible because we would like to input quantized inputs
|
||||||
|
@ -13,22 +13,12 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if !defined(__GNUC__) || defined(__CC_ARM) || defined(__clang__)
|
|
||||||
// TODO: remove this once this PR is merged and part of tensorflow downloads:
|
|
||||||
// https://github.com/google/flatbuffers/pull/6132
|
|
||||||
#pragma clang diagnostic push
|
|
||||||
#pragma clang diagnostic ignored "-Wdouble-promotion"
|
|
||||||
#include "flatbuffers/flexbuffers.h"
|
#include "flatbuffers/flexbuffers.h"
|
||||||
#pragma clang diagnostic pop
|
|
||||||
#else
|
|
||||||
#include "flatbuffers/flexbuffers.h"
|
|
||||||
#endif
|
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||||
|
|
||||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
#include "tensorflow/lite/micro/test_helpers.h"
|
||||||
|
|
||||||
// See: tensorflow/lite/micro/kernels/detection_postprocess_test/readme
|
// See: tensorflow/lite/micro/kernels/detection_postprocess_test/readme
|
||||||
#include "tensorflow/lite/micro/kernels/detection_postprocess_test/flexbuffers_generated_data.h"
|
#include "tensorflow/lite/micro/kernels/detection_postprocess_test/flexbuffers_generated_data.h"
|
||||||
@ -484,20 +474,4 @@ TF_LITE_MICRO_TEST(
|
|||||||
/* input3 min/max */ 0.0, 100.5);
|
/* input3 min/max */ 0.0, 100.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(DetectionPostprocessFloatFastNMSUndefinedOutputDimensions) {
|
|
||||||
float output_data1[12];
|
|
||||||
float output_data2[3];
|
|
||||||
float output_data3[3];
|
|
||||||
float output_data4[1];
|
|
||||||
|
|
||||||
tflite::testing::TestDetectionPostprocess(
|
|
||||||
tflite::testing::kInputShape1, tflite::testing::kInputData1,
|
|
||||||
tflite::testing::kInputShape2, tflite::testing::kInputData2,
|
|
||||||
tflite::testing::kInputShape3, tflite::testing::kInputData3, nullptr,
|
|
||||||
output_data1, nullptr, output_data2, nullptr, output_data3, nullptr,
|
|
||||||
output_data4, tflite::testing::kGolden1, tflite::testing::kGolden2,
|
|
||||||
tflite::testing::kGolden3, tflite::testing::kGolden4,
|
|
||||||
/* tolerance */ 0, /* Use regular NMS: */ false);
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_LITE_MICRO_TESTS_END
|
TF_LITE_MICRO_TESTS_END
|
||||||
|
Loading…
Reference in New Issue
Block a user