Fix TODOs in detection_postprocess
PiperOrigin-RevId: 350801426 Change-Id: Ic7375cffcd55ddc3cd564be78c3eedd47a270c52
This commit is contained in:
parent
8004de62d2
commit
a5ce810d42
@ -198,7 +198,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetOutputSafe(context, node, kOutputTensorNumDetections,
|
||||
&num_detections));
|
||||
num_detections->type = kTfLiteFloat32;
|
||||
// TODO (chowdhery): Make it a scalar when available
|
||||
SetTensorSizes(context, num_detections, {1});
|
||||
|
||||
// Temporary tensors
|
||||
@ -267,14 +266,12 @@ void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
|
||||
|
||||
template <class T>
|
||||
T ReInterpretTensor(const TfLiteTensor* tensor) {
|
||||
// TODO (chowdhery): check float
|
||||
const float* tensor_base = GetTensorData<float>(tensor);
|
||||
return reinterpret_cast<T>(tensor_base);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T ReInterpretTensor(TfLiteTensor* tensor) {
|
||||
// TODO (chowdhery): check float
|
||||
float* tensor_base = GetTensorData<float>(tensor);
|
||||
return reinterpret_cast<T>(tensor_base);
|
||||
}
|
||||
@ -319,6 +316,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
|
||||
const float* boxes =
|
||||
&(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
|
||||
box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
|
||||
TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32);
|
||||
anchor =
|
||||
ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
|
||||
break;
|
||||
@ -351,6 +349,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
TfLiteTensor* decoded_boxes =
|
||||
&context->tensors[op_data->decoded_boxes_index];
|
||||
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||
auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
|
||||
box.ymin = ycenter - half_h;
|
||||
box.xmin = xcenter - half_w;
|
||||
@ -438,11 +437,12 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
|
||||
TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
|
||||
(intersection_over_union_threshold <= 1.0f));
|
||||
// Validate boxes
|
||||
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
|
||||
|
||||
// threshold scores
|
||||
std::vector<int> keep_indices;
|
||||
// TODO (chowdhery): Remove the dynamic allocation and replace it
|
||||
// TODO(b/177068807): Remove the dynamic allocation and replace it
|
||||
// with temporaries, esp for std::vector<float>
|
||||
std::vector<float> keep_scores;
|
||||
SelectDetectionsAboveScoreThreshold(
|
||||
@ -476,6 +476,7 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
|
||||
}
|
||||
for (int j = i + 1; j < num_boxes_kept; ++j) {
|
||||
if (active_box_candidate[j] == 1) {
|
||||
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||
float intersection_over_union = ComputeIntersectionOverUnion(
|
||||
decoded_boxes, keep_indices[sorted_indices[i]],
|
||||
keep_indices[sorted_indices[j]]);
|
||||
@ -607,6 +608,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
|
||||
const float selected_score =
|
||||
scores_after_regular_non_max_suppression[output_box_index];
|
||||
// detection_boxes
|
||||
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
|
||||
ReInterpretTensor<const BoxCornerEncoding*>(
|
||||
decoded_boxes)[anchor_index];
|
||||
@ -615,6 +618,7 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
|
||||
// detection_scores
|
||||
GetTensorData<float>(detection_scores)[output_box_index] = selected_score;
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
|
||||
ReInterpretTensor<BoxCornerEncoding*>(
|
||||
detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
// detection_classes
|
||||
@ -705,6 +709,8 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
|
||||
for (int col = 0; col < num_categories_per_anchor; ++col) {
|
||||
int box_offset = num_categories_per_anchor * output_box_index + col;
|
||||
// detection_boxes
|
||||
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
|
||||
ReInterpretTensor<const BoxCornerEncoding*>(
|
||||
decoded_boxes)[selected_index];
|
||||
@ -782,7 +788,7 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// TODO(chowdhery): Generalize for any batch size
|
||||
// TODO(b/177068051): Generalize for any batch size.
|
||||
TF_LITE_ENSURE(context, (kBatchSize == 1));
|
||||
auto* op_data = static_cast<OpData*>(node->user_data);
|
||||
// These two functions correspond to two blocks in the Object Detection model.
|
||||
|
Loading…
x
Reference in New Issue
Block a user