Fix TODOs in detection_postprocess
PiperOrigin-RevId: 350801426 Change-Id: Ic7375cffcd55ddc3cd564be78c3eedd47a270c52
This commit is contained in:
parent
8004de62d2
commit
a5ce810d42
@ -198,7 +198,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetOutputSafe(context, node, kOutputTensorNumDetections,
|
GetOutputSafe(context, node, kOutputTensorNumDetections,
|
||||||
&num_detections));
|
&num_detections));
|
||||||
num_detections->type = kTfLiteFloat32;
|
num_detections->type = kTfLiteFloat32;
|
||||||
// TODO (chowdhery): Make it a scalar when available
|
|
||||||
SetTensorSizes(context, num_detections, {1});
|
SetTensorSizes(context, num_detections, {1});
|
||||||
|
|
||||||
// Temporary tensors
|
// Temporary tensors
|
||||||
@ -267,14 +266,12 @@ void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
|
|||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
T ReInterpretTensor(const TfLiteTensor* tensor) {
|
T ReInterpretTensor(const TfLiteTensor* tensor) {
|
||||||
// TODO (chowdhery): check float
|
|
||||||
const float* tensor_base = GetTensorData<float>(tensor);
|
const float* tensor_base = GetTensorData<float>(tensor);
|
||||||
return reinterpret_cast<T>(tensor_base);
|
return reinterpret_cast<T>(tensor_base);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
T ReInterpretTensor(TfLiteTensor* tensor) {
|
T ReInterpretTensor(TfLiteTensor* tensor) {
|
||||||
// TODO (chowdhery): check float
|
|
||||||
float* tensor_base = GetTensorData<float>(tensor);
|
float* tensor_base = GetTensorData<float>(tensor);
|
||||||
return reinterpret_cast<T>(tensor_base);
|
return reinterpret_cast<T>(tensor_base);
|
||||||
}
|
}
|
||||||
@ -319,6 +316,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
|
|||||||
const float* boxes =
|
const float* boxes =
|
||||||
&(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
|
&(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
|
||||||
box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
|
box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32);
|
||||||
anchor =
|
anchor =
|
||||||
ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
|
ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
|
||||||
break;
|
break;
|
||||||
@ -351,6 +349,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
|
|||||||
|
|
||||||
TfLiteTensor* decoded_boxes =
|
TfLiteTensor* decoded_boxes =
|
||||||
&context->tensors[op_data->decoded_boxes_index];
|
&context->tensors[op_data->decoded_boxes_index];
|
||||||
|
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||||
auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
|
auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
|
||||||
box.ymin = ycenter - half_h;
|
box.ymin = ycenter - half_h;
|
||||||
box.xmin = xcenter - half_w;
|
box.xmin = xcenter - half_w;
|
||||||
@ -438,11 +437,12 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
|
|||||||
TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
|
TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
|
||||||
(intersection_over_union_threshold <= 1.0f));
|
(intersection_over_union_threshold <= 1.0f));
|
||||||
// Validate boxes
|
// Validate boxes
|
||||||
|
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||||
TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
|
TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
|
||||||
|
|
||||||
// threshold scores
|
// threshold scores
|
||||||
std::vector<int> keep_indices;
|
std::vector<int> keep_indices;
|
||||||
// TODO (chowdhery): Remove the dynamic allocation and replace it
|
// TODO(b/177068807): Remove the dynamic allocation and replace it
|
||||||
// with temporaries, esp for std::vector<float>
|
// with temporaries, esp for std::vector<float>
|
||||||
std::vector<float> keep_scores;
|
std::vector<float> keep_scores;
|
||||||
SelectDetectionsAboveScoreThreshold(
|
SelectDetectionsAboveScoreThreshold(
|
||||||
@ -476,6 +476,7 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
|
|||||||
}
|
}
|
||||||
for (int j = i + 1; j < num_boxes_kept; ++j) {
|
for (int j = i + 1; j < num_boxes_kept; ++j) {
|
||||||
if (active_box_candidate[j] == 1) {
|
if (active_box_candidate[j] == 1) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||||
float intersection_over_union = ComputeIntersectionOverUnion(
|
float intersection_over_union = ComputeIntersectionOverUnion(
|
||||||
decoded_boxes, keep_indices[sorted_indices[i]],
|
decoded_boxes, keep_indices[sorted_indices[i]],
|
||||||
keep_indices[sorted_indices[j]]);
|
keep_indices[sorted_indices[j]]);
|
||||||
@ -607,6 +608,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
|
|||||||
const float selected_score =
|
const float selected_score =
|
||||||
scores_after_regular_non_max_suppression[output_box_index];
|
scores_after_regular_non_max_suppression[output_box_index];
|
||||||
// detection_boxes
|
// detection_boxes
|
||||||
|
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||||
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
|
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
|
||||||
ReInterpretTensor<const BoxCornerEncoding*>(
|
ReInterpretTensor<const BoxCornerEncoding*>(
|
||||||
decoded_boxes)[anchor_index];
|
decoded_boxes)[anchor_index];
|
||||||
@ -615,6 +618,7 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
|
|||||||
// detection_scores
|
// detection_scores
|
||||||
GetTensorData<float>(detection_scores)[output_box_index] = selected_score;
|
GetTensorData<float>(detection_scores)[output_box_index] = selected_score;
|
||||||
} else {
|
} else {
|
||||||
|
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
|
||||||
ReInterpretTensor<BoxCornerEncoding*>(
|
ReInterpretTensor<BoxCornerEncoding*>(
|
||||||
detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
|
detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
// detection_classes
|
// detection_classes
|
||||||
@ -705,6 +709,8 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
|
|||||||
for (int col = 0; col < num_categories_per_anchor; ++col) {
|
for (int col = 0; col < num_categories_per_anchor; ++col) {
|
||||||
int box_offset = num_categories_per_anchor * output_box_index + col;
|
int box_offset = num_categories_per_anchor * output_box_index + col;
|
||||||
// detection_boxes
|
// detection_boxes
|
||||||
|
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
|
||||||
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
|
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
|
||||||
ReInterpretTensor<const BoxCornerEncoding*>(
|
ReInterpretTensor<const BoxCornerEncoding*>(
|
||||||
decoded_boxes)[selected_index];
|
decoded_boxes)[selected_index];
|
||||||
@ -782,7 +788,7 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
// TODO(chowdhery): Generalize for any batch size
|
// TODO(b/177068051): Generalize for any batch size.
|
||||||
TF_LITE_ENSURE(context, (kBatchSize == 1));
|
TF_LITE_ENSURE(context, (kBatchSize == 1));
|
||||||
auto* op_data = static_cast<OpData*>(node->user_data);
|
auto* op_data = static_cast<OpData*>(node->user_data);
|
||||||
// These two functions correspond to two blocks in the Object Detection model.
|
// These two functions correspond to two blocks in the Object Detection model.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user