diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc index f746ad19498..bbd47081804 100644 --- a/tensorflow/lite/kernels/detection_postprocess.cc +++ b/tensorflow/lite/kernels/detection_postprocess.cc @@ -198,7 +198,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { GetOutputSafe(context, node, kOutputTensorNumDetections, &num_detections)); num_detections->type = kTfLiteFloat32; - // TODO (chowdhery): Make it a scalar when available SetTensorSizes(context, num_detections, {1}); // Temporary tensors @@ -267,14 +266,12 @@ void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, template T ReInterpretTensor(const TfLiteTensor* tensor) { - // TODO (chowdhery): check float const float* tensor_base = GetTensorData(tensor); return reinterpret_cast(tensor_base); } template T ReInterpretTensor(TfLiteTensor* tensor) { - // TODO (chowdhery): check float float* tensor_base = GetTensorData(tensor); return reinterpret_cast(tensor_base); } @@ -319,6 +316,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, const float* boxes = &(GetTensorData(input_box_encodings)[box_encoding_idx]); box_centersize = *reinterpret_cast(boxes); + TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32); anchor = ReInterpretTensor(input_anchors)[idx]; break; @@ -351,6 +349,7 @@ TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index]; + TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); auto& box = ReInterpretTensor(decoded_boxes)[idx]; box.ymin = ycenter - half_h; box.xmin = xcenter - half_w; @@ -438,11 +437,12 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) && (intersection_over_union_threshold <= 1.0f)); // Validate boxes + TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes)); // threshold scores std::vector keep_indices; - // TODO (chowdhery): Remove the dynamic allocation and replace it + // TODO(b/177068807): Remove the dynamic allocation and replace it // with temporaries, esp for std::vector std::vector keep_scores; SelectDetectionsAboveScoreThreshold( @@ -476,6 +476,7 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( } for (int j = i + 1; j < num_boxes_kept; ++j) { if (active_box_candidate[j] == 1) { + TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); float intersection_over_union = ComputeIntersectionOverUnion( decoded_boxes, keep_indices[sorted_indices[i]], keep_indices[sorted_indices[j]]); @@ -607,6 +608,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, const float selected_score = scores_after_regular_non_max_suppression[output_box_index]; // detection_boxes + TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); ReInterpretTensor(detection_boxes)[output_box_index] = ReInterpretTensor( decoded_boxes)[anchor_index]; @@ -615,6 +618,7 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, // detection_scores GetTensorData(detection_scores)[output_box_index] = selected_score; } else { + TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32); ReInterpretTensor( detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f}; // detection_classes @@ -705,6 +709,8 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, for (int col = 0; col < num_categories_per_anchor; ++col) { int box_offset = num_categories_per_anchor * output_box_index + col; // detection_boxes + TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); ReInterpretTensor(detection_boxes)[box_offset] = ReInterpretTensor( decoded_boxes)[selected_index]; @@ -782,7 +788,7 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - // TODO(chowdhery): Generalize for any batch size + // TODO(b/177068051): Generalize for any batch size. TF_LITE_ENSURE(context, (kBatchSize == 1)); auto* op_data = static_cast(node->user_data); // These two functions correspond to two blocks in the Object Detection model.