Fix TODOs in detection_postprocess

PiperOrigin-RevId: 350801426
Change-Id: Ic7375cffcd55ddc3cd564be78c3eedd47a270c52
This commit is contained in:
T.J. Alumbaugh 2021-01-08 11:25:23 -08:00 committed by TensorFlower Gardener
parent 8004de62d2
commit a5ce810d42

View File

@ -198,7 +198,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
GetOutputSafe(context, node, kOutputTensorNumDetections, GetOutputSafe(context, node, kOutputTensorNumDetections,
&num_detections)); &num_detections));
num_detections->type = kTfLiteFloat32; num_detections->type = kTfLiteFloat32;
// TODO (chowdhery): Make it a scalar when available
SetTensorSizes(context, num_detections, {1}); SetTensorSizes(context, num_detections, {1});
// Temporary tensors // Temporary tensors
@ -267,14 +266,12 @@ void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
template <class T> template <class T>
T ReInterpretTensor(const TfLiteTensor* tensor) { T ReInterpretTensor(const TfLiteTensor* tensor) {
// TODO (chowdhery): check float
const float* tensor_base = GetTensorData<float>(tensor); const float* tensor_base = GetTensorData<float>(tensor);
return reinterpret_cast<T>(tensor_base); return reinterpret_cast<T>(tensor_base);
} }
template <class T> template <class T>
T ReInterpretTensor(TfLiteTensor* tensor) { T ReInterpretTensor(TfLiteTensor* tensor) {
// TODO (chowdhery): check float
float* tensor_base = GetTensorData<float>(tensor); float* tensor_base = GetTensorData<float>(tensor);
return reinterpret_cast<T>(tensor_base); return reinterpret_cast<T>(tensor_base);
} }
@ -319,6 +316,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
const float* boxes = const float* boxes =
&(GetTensorData<float>(input_box_encodings)[box_encoding_idx]); &(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes); box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32);
anchor = anchor =
ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx]; ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
break; break;
@ -351,6 +349,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* decoded_boxes = TfLiteTensor* decoded_boxes =
&context->tensors[op_data->decoded_boxes_index]; &context->tensors[op_data->decoded_boxes_index];
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx]; auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
box.ymin = ycenter - half_h; box.ymin = ycenter - half_h;
box.xmin = xcenter - half_w; box.xmin = xcenter - half_w;
@ -438,11 +437,12 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) && TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
(intersection_over_union_threshold <= 1.0f)); (intersection_over_union_threshold <= 1.0f));
// Validate boxes // Validate boxes
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes)); TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
// threshold scores // threshold scores
std::vector<int> keep_indices; std::vector<int> keep_indices;
// TODO (chowdhery): Remove the dynamic allocation and replace it // TODO(b/177068807): Remove the dynamic allocation and replace it
// with temporaries, esp for std::vector<float> // with temporaries, esp for std::vector<float>
std::vector<float> keep_scores; std::vector<float> keep_scores;
SelectDetectionsAboveScoreThreshold( SelectDetectionsAboveScoreThreshold(
@ -476,6 +476,7 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
} }
for (int j = i + 1; j < num_boxes_kept; ++j) { for (int j = i + 1; j < num_boxes_kept; ++j) {
if (active_box_candidate[j] == 1) { if (active_box_candidate[j] == 1) {
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
float intersection_over_union = ComputeIntersectionOverUnion( float intersection_over_union = ComputeIntersectionOverUnion(
decoded_boxes, keep_indices[sorted_indices[i]], decoded_boxes, keep_indices[sorted_indices[i]],
keep_indices[sorted_indices[j]]); keep_indices[sorted_indices[j]]);
@ -607,6 +608,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
const float selected_score = const float selected_score =
scores_after_regular_non_max_suppression[output_box_index]; scores_after_regular_non_max_suppression[output_box_index];
// detection_boxes // detection_boxes
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] = ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
ReInterpretTensor<const BoxCornerEncoding*>( ReInterpretTensor<const BoxCornerEncoding*>(
decoded_boxes)[anchor_index]; decoded_boxes)[anchor_index];
@ -615,6 +618,7 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
// detection_scores // detection_scores
GetTensorData<float>(detection_scores)[output_box_index] = selected_score; GetTensorData<float>(detection_scores)[output_box_index] = selected_score;
} else { } else {
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
ReInterpretTensor<BoxCornerEncoding*>( ReInterpretTensor<BoxCornerEncoding*>(
detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f}; detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
// detection_classes // detection_classes
@ -705,6 +709,8 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
for (int col = 0; col < num_categories_per_anchor; ++col) { for (int col = 0; col < num_categories_per_anchor; ++col) {
int box_offset = num_categories_per_anchor * output_box_index + col; int box_offset = num_categories_per_anchor * output_box_index + col;
// detection_boxes // detection_boxes
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] = ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
ReInterpretTensor<const BoxCornerEncoding*>( ReInterpretTensor<const BoxCornerEncoding*>(
decoded_boxes)[selected_index]; decoded_boxes)[selected_index];
@ -782,7 +788,7 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
} }
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(chowdhery): Generalize for any batch size // TODO(b/177068051): Generalize for any batch size.
TF_LITE_ENSURE(context, (kBatchSize == 1)); TF_LITE_ENSURE(context, (kBatchSize == 1));
auto* op_data = static_cast<OpData*>(node->user_data); auto* op_data = static_cast<OpData*>(node->user_data);
// These two functions correspond to two blocks in the Object Detection model. // These two functions correspond to two blocks in the Object Detection model.