From 8e6ee2cfc60f454e5aeef535f86d44eedcf192fa Mon Sep 17 00:00:00 2001 From: Teng Lu Date: Tue, 17 Nov 2020 09:04:54 +0800 Subject: [PATCH] Improve CNMS performance by removing unnecessary allocation. --- .../core/kernels/image/non_max_suppression_op.cc | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/kernels/image/non_max_suppression_op.cc b/tensorflow/core/kernels/image/non_max_suppression_op.cc index 1c4166058fb..c66be3836a3 100644 --- a/tensorflow/core/kernels/image/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/image/non_max_suppression_op.cc @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -33,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace { @@ -320,12 +320,6 @@ void DoNMSPerClass(int batch_idx, int class_idx, const float* boxes_data, } } - // Copy class_boxes_data to a tensor - TensorShape boxesShape({num_boxes, 4}); - Tensor boxes(DT_FLOAT, boxesShape); - std::copy_n(class_boxes_data.begin(), class_boxes_data.size(), - boxes.unaligned_flat().data()); - // Do NMS, get the candidate indices of form vector // Data structure for selection candidate in NMS. struct Candidate { @@ -347,9 +341,10 @@ void DoNMSPerClass(int batch_idx, int class_idx, const float* boxes_data, Candidate next_candidate; std::sort(candidate_vector.begin(), candidate_vector.end(), cmp); - const Tensor const_boxes = boxes; - typename TTypes::ConstTensor boxes_data_t = - const_boxes.tensor(); + // Move class_boxes_data to a tensor + Eigen::array boxesShape = {num_boxes, 4}; + typename TTypes::ConstTensor boxes_data_t(class_boxes_data.data(), + boxesShape); int candidate_idx = 0; float iou; while (selected.size() < size_per_class &&