Fix TODOs in detection_postprocess

PiperOrigin-RevId: 350801426
Change-Id: Ic7375cffcd55ddc3cd564be78c3eedd47a270c52
This commit is contained in:
T.J. Alumbaugh 2021-01-08 11:25:23 -08:00 committed by TensorFlower Gardener
parent 8004de62d2
commit a5ce810d42

View File

@ -198,7 +198,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
GetOutputSafe(context, node, kOutputTensorNumDetections,
&num_detections));
num_detections->type = kTfLiteFloat32;
// TODO (chowdhery): Make it a scalar when available
SetTensorSizes(context, num_detections, {1});
// Temporary tensors
@ -267,14 +266,12 @@ void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
template <class T>
T ReInterpretTensor(const TfLiteTensor* tensor) {
// TODO (chowdhery): check float
const float* tensor_base = GetTensorData<float>(tensor);
return reinterpret_cast<T>(tensor_base);
}
template <class T>
T ReInterpretTensor(TfLiteTensor* tensor) {
// TODO (chowdhery): check float
float* tensor_base = GetTensorData<float>(tensor);
return reinterpret_cast<T>(tensor_base);
}
@ -319,6 +316,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
const float* boxes =
&(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32);
anchor =
ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
break;
@ -351,6 +349,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* decoded_boxes =
&context->tensors[op_data->decoded_boxes_index];
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
box.ymin = ycenter - half_h;
box.xmin = xcenter - half_w;
@ -438,11 +437,12 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
(intersection_over_union_threshold <= 1.0f));
// Validate boxes
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
// threshold scores
std::vector<int> keep_indices;
// TODO (chowdhery): Remove the dynamic allocation and replace it
// TODO(b/177068807): Remove the dynamic allocation and replace it
// with temporaries, esp for std::vector<float>
std::vector<float> keep_scores;
SelectDetectionsAboveScoreThreshold(
@ -476,6 +476,7 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
}
for (int j = i + 1; j < num_boxes_kept; ++j) {
if (active_box_candidate[j] == 1) {
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
float intersection_over_union = ComputeIntersectionOverUnion(
decoded_boxes, keep_indices[sorted_indices[i]],
keep_indices[sorted_indices[j]]);
@ -607,6 +608,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
const float selected_score =
scores_after_regular_non_max_suppression[output_box_index];
// detection_boxes
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
ReInterpretTensor<const BoxCornerEncoding*>(
decoded_boxes)[anchor_index];
@ -615,6 +618,7 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
// detection_scores
GetTensorData<float>(detection_scores)[output_box_index] = selected_score;
} else {
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
ReInterpretTensor<BoxCornerEncoding*>(
detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
// detection_classes
@ -705,6 +709,8 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
for (int col = 0; col < num_categories_per_anchor; ++col) {
int box_offset = num_categories_per_anchor * output_box_index + col;
// detection_boxes
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
ReInterpretTensor<const BoxCornerEncoding*>(
decoded_boxes)[selected_index];
@ -782,7 +788,7 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(chowdhery): Generalize for any batch size
// TODO(b/177068051): Generalize for any batch size.
TF_LITE_ENSURE(context, (kBatchSize == 1));
auto* op_data = static_cast<OpData*>(node->user_data);
// These two functions correspond to two blocks in the Object Detection model.