[Intel MKL] Optimize combinedNMS performance

Li, Guizi 2020-02-15 13:55:35 +08:00
parent fe03adf6e6
commit 0be4b608c0
3 changed files with 339 additions and 185 deletions

tensorflow/core/kernels/BUILD

@@ -3245,6 +3245,22 @@ tf_cc_tests(
],
)
tf_cc_test(
name = "non_max_suppression_op_benchmark_test",
srcs = ["non_max_suppression_op_benchmark_test.cc"],
deps = [
":image",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
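Note: with this target in place, the benchmark should be runnable in the usual way for TF microbenchmarks, e.g. bazel test -c opt //tensorflow/core/kernels:non_max_suppression_op_benchmark_test --test_arg=--benchmarks=all. The //tensorflow/core/kernels package path is inferred from the :image and :ops_testutil deps above, not stated in this diff.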
tf_cuda_cc_test(
name = "resize_bilinear_op_test",
srcs = ["resize_bilinear_op_test.cc"],

tensorflow/core/kernels/non_max_suppression_op.cc

@@ -24,7 +24,6 @@ limitations under the License.
#include <queue>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -33,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
namespace {
@@ -284,14 +284,179 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
}
}
struct ResultCandidate {
int box_index;
float score;
int class_idx;
float box_coord[4];
};
void DoNMS(int batch_idx, int class_idx, const float* boxes_data,
const float* scores_data, int num_boxes, int q, int num_classes,
const int size_per_class, const float score_threshold,
const float iou_threshold,
std::vector<ResultCandidate>& result_candidate_vec) {
std::vector<float> class_scores_data;
class_scores_data.reserve(num_boxes);
std::vector<float> class_boxes_data;
class_boxes_data.reserve(num_boxes * 4);
for (int box_idx = 0; box_idx < num_boxes; ++box_idx) {
class_scores_data.push_back(scores_data[box_idx * num_classes + class_idx]);
for (int cid = 0; cid < 4; ++cid) {
if (q > 1) {
class_boxes_data.push_back(
boxes_data[(box_idx * q + class_idx) * 4 + cid]);
} else {
class_boxes_data.push_back(boxes_data[box_idx * 4 + cid]);
}
}
}
// Copy class_boxes_data to a tensor
TensorShape boxesShape({num_boxes, 4});
Tensor boxes(DT_FLOAT, boxesShape);
std::copy_n(class_boxes_data.begin(), class_boxes_data.size(),
boxes.unaligned_flat<float>().data());
// Do NMS and collect the selected candidates' indices as a vector<int>.
// Data structure for a selection candidate in NMS.
struct Candidate {
int box_index;
float score;
};
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
return bs_i.score > bs_j.score;
};
std::vector<Candidate> candidate_vector;
for (int i = 0; i < class_scores_data.size(); ++i) {
if (class_scores_data[i] > score_threshold) {
candidate_vector.emplace_back(Candidate({i, class_scores_data[i]}));
}
}
std::vector<int> selected;
std::vector<float> selected_boxes;
Candidate next_candidate;
std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
const Tensor const_boxes = boxes;
typename TTypes<float, 2>::ConstTensor boxes_data_1 =
const_boxes.tensor<float, 2>();
int candidate_idx = 0;
float iou;
while (selected.size() < size_per_class &&
candidate_idx < candidate_vector.size()) {
next_candidate = candidate_vector[candidate_idx++];
// Overlapping boxes are likely to have similar scores,
// therefore we iterate through the previously selected boxes backwards
// in order to see if `next_candidate` should be suppressed.
bool should_select = true;
for (int j = selected.size() - 1; j >= 0; --j) {
iou = IOU<float>(boxes_data_1, next_candidate.box_index, selected[j]);
if (iou > iou_threshold) {
should_select = false;
break;
}
}
if (should_select) {
// Record the selected box as a result candidate; candidates are sorted by score later.
int id = next_candidate.box_index;
auto& rc =
result_candidate_vec[selected.size() + size_per_class * class_idx];
selected.push_back(next_candidate.box_index);
rc.box_index = next_candidate.box_index;
rc.score = next_candidate.score;
rc.class_idx = class_idx;
rc.box_coord[0] = boxes_data_1(id, 0);
rc.box_coord[1] = boxes_data_1(id, 1);
rc.box_coord[2] = boxes_data_1(id, 2);
rc.box_coord[3] = boxes_data_1(id, 3);
}
}
}
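Note: the greedy loop above leans on this file's existing IOU<T> helper. For reference only, a minimal standalone sketch of that computation; the min/max normalization of each coordinate pair mirrors TF's NMS convention for possibly-flipped boxes, but treat the details as an assumption rather than a copy of this file:

#include <algorithm>

// Sketch of an intersection-over-union helper in the spirit of IOU<float>.
// `a` and `b` each point at 4 floats: (y1, x1, y2, x2), possibly flipped.
static float IouSketch(const float* a, const float* b) {
  const float ay0 = std::min(a[0], a[2]), ay1 = std::max(a[0], a[2]);
  const float ax0 = std::min(a[1], a[3]), ax1 = std::max(a[1], a[3]);
  const float by0 = std::min(b[0], b[2]), by1 = std::max(b[0], b[2]);
  const float bx0 = std::min(b[1], b[3]), bx1 = std::max(b[1], b[3]);
  const float area_a = (ay1 - ay0) * (ax1 - ax0);
  const float area_b = (by1 - by0) * (bx1 - bx0);
  if (area_a <= 0.0f || area_b <= 0.0f) return 0.0f;
  // Intersection extents clamp to zero when the boxes do not overlap.
  const float ih = std::max(0.0f, std::min(ay1, by1) - std::max(ay0, by0));
  const float iw = std::max(0.0f, std::min(ax1, bx1) - std::max(ax0, bx0));
  const float inter = ih * iw;
  return inter / (area_a + area_b - inter);
}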
void SelectResultPerBatch(std::vector<float>& nmsed_boxes,
std::vector<float>& nmsed_scores,
std::vector<float>& nmsed_classes,
std::vector<ResultCandidate>& result_candidate_vec,
std::vector<int>& final_valid_detections,
const int batch_idx, int total_size_per_batch,
bool pad_per_class, int max_size_per_batch,
bool clip_boxes, int per_batch_size) {
auto rc_cmp = [](const ResultCandidate rc_i, const ResultCandidate rc_j) {
return rc_i.score > rc_j.score;
};
std::sort(result_candidate_vec.begin(), result_candidate_vec.end(), rc_cmp);
int max_detections = 0;
int result_candidate_size =
std::count_if(result_candidate_vec.begin(), result_candidate_vec.end(),
[](ResultCandidate rc) { return rc.box_index > -1; });
// If pad_per_class is false, we always pad to max_total_size
if (!pad_per_class) {
max_detections = std::min(result_candidate_size, total_size_per_batch);
} else {
max_detections = std::min(per_batch_size, result_candidate_size);
}
final_valid_detections[batch_idx] = max_detections;
int curr_total_size = max_detections;
int result_idx = 0;
// Pick the top max_detections values
while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
ResultCandidate next_candidate = result_candidate_vec[result_idx++];
// Add to final output vectors
if (clip_boxes) {
const float box_min = 0.0;
const float box_max = 1.0;
nmsed_boxes.push_back(
std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
nmsed_boxes.push_back(
std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
nmsed_boxes.push_back(
std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
nmsed_boxes.push_back(
std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
} else {
nmsed_boxes.push_back(next_candidate.box_coord[0]);
nmsed_boxes.push_back(next_candidate.box_coord[1]);
nmsed_boxes.push_back(next_candidate.box_coord[2]);
nmsed_boxes.push_back(next_candidate.box_coord[3]);
}
nmsed_scores.push_back(next_candidate.score);
nmsed_classes.push_back(next_candidate.class_idx);
curr_total_size--;
}
nmsed_boxes.resize(per_batch_size * 4, 0);
nmsed_scores.resize(per_batch_size, 0);
nmsed_classes.resize(per_batch_size, 0);
}
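Note: a quick worked example of the sizing rule above (hypothetical numbers, chosen so the two pad_per_class branches actually differ):

#include <algorithm>
#include <cstdio>

int main() {
  const int num_classes = 90;
  const int max_size_per_class = 10;     // hypothetical
  const int total_size_per_batch = 9000;
  // pad_per_class = true pads each batch to min(total, per_class * classes).
  const int padded = std::min(total_size_per_batch,
                              max_size_per_class * num_classes);
  std::printf("pad_per_class=true  -> %d\n", padded);                // 900
  std::printf("pad_per_class=false -> %d\n", total_size_per_batch);  // 9000
  return 0;
}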
void BatchedNonMaxSuppressionOp(
OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
int num_boxes, const int max_size_per_class, const int total_size_per_batch,
const float score_threshold, const float iou_threshold,
bool pad_per_class = false, bool clip_boxes = true) {
const int num_batches = inp_boxes.dim_size(0);
int num_classes = inp_scores.dim_size(2);
int q = inp_boxes.dim_size(2);
const float* scores_data =
const_cast<float*>(inp_scores.flat<float>().data());
const float* boxes_data = const_cast<float*>(inp_boxes.flat<float>().data());
int boxes_per_batch = num_boxes * q * 4;
int scores_per_batch = num_boxes * num_classes;
const int size_per_class = std::min(max_size_per_class, num_boxes);
std::vector<std::vector<ResultCandidate>> result_candidate_vec(
num_batches,
std::vector<ResultCandidate>(size_per_class * num_classes,
{-1, -1.0, -1, {0.0, 0.0, 0.0, 0.0}}));
// [num_batches, per_batch_size * 4]
std::vector<std::vector<float>> nmsed_boxes(num_batches);
@@ -300,166 +465,72 @@ void BatchedNonMaxSuppressionOp(
// [num_batches, per_batch_size]
std::vector<std::vector<float>> nmsed_classes(num_batches);
// [num_batches]
std::vector<int> final_valid_detections(num_batches);
auto shard_nms = [&](int begin, int end) {
int boxes_per_batch = num_boxes * q * 4;
int scores_per_batch = num_boxes * num_classes;
for (int idx = begin; idx < end; ++idx) {
int batch_idx = idx / num_classes;
int class_idx = idx % num_classes;
DoNMS(batch_idx, class_idx, boxes_data + boxes_per_batch * batch_idx,
scores_data + scores_per_batch * batch_idx, num_boxes, q,
num_classes, size_per_class, score_threshold, iou_threshold,
result_candidate_vec[batch_idx]);
}
};
int length = num_batches * num_classes;
// Input data boxes_data, scores_data
int input_bytes = length * num_boxes * 5;
int output_bytes = length * num_boxes * 5;
int compute_cycles = (Eigen::TensorOpCost::AddCost<int>() * 5 +
Eigen::TensorOpCost::MulCost<int>() * 2 +
Eigen::TensorOpCost::AddCost<float>() * 10 +
Eigen::TensorOpCost::MulCost<float>() * 6 +
Eigen::TensorOpCost::DivCost<float>()) *
length;
const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
const CPUDevice& d = context->eigen_device<CPUDevice>();
d.parallelFor(length, cost, shard_nms);
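Note: parallelFor is Eigen's ThreadPoolDevice sharding primitive. A self-contained sketch of the pattern (the cost and size values here are made up; only the call shape matches the code above):

#define EIGEN_USE_THREADS
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// parallelFor splits [0, n) into contiguous [begin, end) shards, using the
// per-item TensorOpCost estimate to pick shard sizes, and runs the closure
// on the device's thread pool.
void ShardDemo(const Eigen::ThreadPoolDevice& d) {
  const Eigen::Index n = 32 * 90;  // e.g. num_batches * num_classes
  std::vector<int> visited(n, 0);
  const Eigen::TensorOpCost cost(/*bytes_loaded=*/9585.0,
                                 /*bytes_stored=*/9585.0,
                                 /*compute_cycles=*/100.0);
  d.parallelFor(n, cost, [&](Eigen::Index begin, Eigen::Index end) {
    for (Eigen::Index i = begin; i < end; ++i) visited[i] = 1;
  });
}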
int per_batch_size = total_size_per_batch;
// perform non_max_suppression operation for each batch independently
for (int batch = 0; batch < num_batches; ++batch) {
// dims of per_batch_boxes [num_boxes, q, 4]
Tensor per_batch_boxes = inp_boxes.Slice(batch, batch + 1);
// dims of per_batch_scores [num_boxes, num_classes]
Tensor per_batch_scores = inp_scores.Slice(batch, batch + 1);
struct ResultCandidate {
int box_index;
float score;
int class_idx;
float box_coord[4];
};
std::vector<ResultCandidate> result_candidate_vec;
float* scores_data = per_batch_scores.unaligned_flat<float>().data();
float* boxes_data = per_batch_boxes.unaligned_flat<float>().data();
// Iterate through all classes
for (int class_idx = 0; class_idx < num_classes; ++class_idx) {
std::vector<float> class_scores_data;
class_scores_data.reserve(num_boxes);
std::vector<float> class_boxes_data;
class_boxes_data.reserve(num_boxes * 4);
for (int box = 0; box < num_boxes; ++box) {
// Get the scores per class
// class_scores_data dim is [num_boxes].
class_scores_data.push_back(scores_data[box * num_classes + class_idx]);
for (int cid = 0; cid < 4; ++cid) {
if (q > 1) {
// Get the boxes per class. class_boxes_data dims is [num_boxes, 4]
class_boxes_data.push_back(
boxes_data[(box * q + class_idx) * 4 + cid]);
} else {
class_boxes_data.push_back(boxes_data[box * 4 + cid]);
}
}
}
// Copy class_boxes_data to a tensor
TensorShape boxesShape({num_boxes, 4});
Tensor boxes(per_batch_boxes.dtype(), boxesShape);
std::copy_n(class_boxes_data.begin(), class_boxes_data.size(),
boxes.unaligned_flat<float>().data());
const int size_per_class = std::min(max_size_per_class, num_boxes);
// Do NMS, get the candidate indices of form vector<int>
// Data structure for selection candidate in NMS.
struct Candidate {
int box_index;
float score;
};
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
return bs_i.score > bs_j.score;
};
std::vector<Candidate> candidate_vector;
for (int i = 0; i < class_scores_data.size(); ++i) {
if (class_scores_data[i] > score_threshold) {
candidate_vector.emplace_back(Candidate({i, class_scores_data[i]}));
}
}
std::vector<int> selected;
Candidate next_candidate;
std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
const Tensor const_boxes = boxes;
typename TTypes<float, 2>::ConstTensor boxes_data =
const_boxes.tensor<float, 2>();
int candidate_idx = 0;
float iou;
while (selected.size() < size_per_class &&
candidate_idx < candidate_vector.size()) {
next_candidate = candidate_vector[candidate_idx++];
// Overlapping boxes are likely to have similar scores,
// therefore we iterate through the previously selected boxes backwards
// in order to see if `next_candidate` should be suppressed.
bool should_select = true;
for (int j = selected.size() - 1; j >= 0; --j) {
iou = IOU<float>(boxes_data, next_candidate.box_index, selected[j]);
if (iou > iou_threshold) {
should_select = false;
break;
}
}
if (should_select) {
selected.push_back(next_candidate.box_index);
// Add the selected box to the result candidate. Sorted by score
int id = next_candidate.box_index;
ResultCandidate rc = {next_candidate.box_index,
next_candidate.score,
class_idx,
{boxes_data(id, 0), boxes_data(id, 1),
boxes_data(id, 2), boxes_data(id, 3)}};
result_candidate_vec.push_back(rc);
}
}
}
auto rc_cmp = [](const ResultCandidate rc_i, const ResultCandidate rc_j) {
return rc_i.score > rc_j.score;
};
std::sort(result_candidate_vec.begin(), result_candidate_vec.end(), rc_cmp);
int max_detections = 0;
// If pad_per_class is false, we always pad to max_total_size
if (!pad_per_class) {
max_detections =
std::min((int)result_candidate_vec.size(), total_size_per_batch);
per_batch_size = total_size_per_batch;
} else {
per_batch_size =
std::min(total_size_per_batch, max_size_per_class * num_classes);
max_detections =
std::min(per_batch_size, (int)result_candidate_vec.size());
}
final_valid_detections.push_back(max_detections);
int curr_total_size = max_detections;
int result_idx = 0;
// Pick the top max_detections values
while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
ResultCandidate next_candidate = result_candidate_vec[result_idx++];
// Add to final output vectors
if (clip_boxes) {
const float box_min = 0.0;
const float box_max = 1.0;
nmsed_boxes[batch].push_back(
std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
nmsed_boxes[batch].push_back(
std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
nmsed_boxes[batch].push_back(
std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
nmsed_boxes[batch].push_back(
std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
} else {
nmsed_boxes[batch].push_back(next_candidate.box_coord[0]);
nmsed_boxes[batch].push_back(next_candidate.box_coord[1]);
nmsed_boxes[batch].push_back(next_candidate.box_coord[2]);
nmsed_boxes[batch].push_back(next_candidate.box_coord[3]);
}
nmsed_scores[batch].push_back(next_candidate.score);
nmsed_classes[batch].push_back(next_candidate.class_idx);
curr_total_size--;
}
nmsed_boxes[batch].resize(per_batch_size * 4, 0);
nmsed_scores[batch].resize(per_batch_size, 0);
nmsed_classes[batch].resize(per_batch_size, 0);
if (pad_per_class) {
per_batch_size =
std::min(total_size_per_batch, max_size_per_class * num_classes);
}
Tensor* valid_detections_t = nullptr;
TensorShape valid_detections_shape({num_batches});
OP_REQUIRES_OK(context, context->allocate_output(3, valid_detections_shape,
&valid_detections_t));
auto valid_detections_flat = valid_detections_t->template flat<int>();
auto shard_result = [&](int begin, int end) {
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
SelectResultPerBatch(
nmsed_boxes[batch_idx], nmsed_scores[batch_idx],
nmsed_classes[batch_idx], result_candidate_vec[batch_idx],
final_valid_detections, batch_idx, total_size_per_batch,
pad_per_class, max_size_per_class * num_classes, clip_boxes,
per_batch_size);
valid_detections_flat(batch_idx) = final_valid_detections[batch_idx];
}
};
length = num_batches;
// Input data boxes_data, scores_data
input_bytes = length * num_boxes * 5;
output_bytes = length * num_boxes * 5;
compute_cycles = (Eigen::TensorOpCost::AddCost<int>() * 5 +
Eigen::TensorOpCost::MulCost<int>() * 2 +
Eigen::TensorOpCost::AddCost<float>() * 10 +
Eigen::TensorOpCost::MulCost<float>() * 6 +
Eigen::TensorOpCost::DivCost<float>()) *
length;
const Eigen::TensorOpCost cost_result(input_bytes, output_bytes,
compute_cycles);
d.parallelFor(length, cost_result, shard_result);
Tensor* nmsed_boxes_t = nullptr;
TensorShape boxes_shape({num_batches, per_batch_size, 4});
OP_REQUIRES_OK(context,
@@ -477,23 +548,30 @@ void BatchedNonMaxSuppressionOp(
context->allocate_output(2, scores_shape, &nmsed_classes_t));
auto nmsed_classes_flat = nmsed_classes_t->template flat<float>();
auto shard_copy_result = [&](int begin, int end) {
  for (int idx = begin; idx < end; ++idx) {
    int batch_idx = idx / per_batch_size;
    int j = idx % per_batch_size;
    nmsed_scores_flat(idx) = nmsed_scores[batch_idx][j];
    nmsed_classes_flat(idx) = nmsed_classes[batch_idx][j];
    for (int k = 0; k < 4; ++k) {
      nmsed_boxes_flat(idx * 4 + k) = nmsed_boxes[batch_idx][j * 4 + k];
    }
  }
};
length = num_batches * per_batch_size;
// Input data boxes_data, scores_data
input_bytes = length * per_batch_size * 6;
output_bytes = length * per_batch_size * 6;
compute_cycles = (Eigen::TensorOpCost::AddCost<int>() * 5 +
Eigen::TensorOpCost::MulCost<int>() * 2 +
Eigen::TensorOpCost::AddCost<float>() * 10 +
Eigen::TensorOpCost::MulCost<float>() * 6 +
Eigen::TensorOpCost::DivCost<float>()) *
length;
const Eigen::TensorOpCost cost_copy_result(input_bytes, output_bytes,
compute_cycles);
d.parallelFor(length, cost_copy_result, shard_copy_result);
}
} // namespace
@@ -563,9 +641,8 @@ class NonMaxSuppressionV2Op : public OpKernel {
iou_threshold.shape().DebugString()));
const T iou_threshold_val = iou_threshold.scalar<T>()();
OP_REQUIRES(context, iou_threshold_val >= static_cast<T>(0.0) &&
                         iou_threshold_val <= static_cast<T>(1.0),
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, &num_boxes);
@@ -606,9 +683,8 @@ class NonMaxSuppressionV3Op : public OpKernel {
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
const T iou_threshold_val = iou_threshold.scalar<T>()();
OP_REQUIRES(context, iou_threshold_val >= static_cast<T>(0.0) &&
                         iou_threshold_val <= static_cast<T>(1.0),
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
// score_threshold: scalar
const Tensor& score_threshold = context->input(4);
@@ -660,9 +736,8 @@ class NonMaxSuppressionV4Op : public OpKernel {
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
const T iou_threshold_val = iou_threshold.scalar<T>()();
OP_REQUIRES(context, iou_threshold_val >= static_cast<T>(0.0) &&
                         iou_threshold_val <= static_cast<T>(1.0),
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
// score_threshold: scalar
const Tensor& score_threshold = context->input(4);
@@ -726,9 +801,8 @@ class NonMaxSuppressionV5Op : public OpKernel {
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
const T iou_threshold_val = iou_threshold.scalar<T>()();
OP_REQUIRES(context, iou_threshold_val >= static_cast<T>(0.0) &&
                         iou_threshold_val <= static_cast<T>(1.0),
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
// score_threshold: scalar
const Tensor& score_threshold = context->input(4);

tensorflow/core/kernels/non_max_suppression_op_benchmark_test.cc

@@ -0,0 +1,64 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
static Graph* BM_CombinedNonMaxSuppression(int batches, int box_num,
int class_num, int q) {
Graph* g = new Graph(OpRegistry::Global());
Tensor boxes(DT_FLOAT, TensorShape({batches, box_num, q, 4}));
boxes.flat<float>().setRandom();
Tensor scores(DT_FLOAT, TensorShape({batches, box_num, class_num}));
scores.flat<float>().setRandom();
Tensor max_output_size_per_class(100);
Tensor max_total_size(9000);
Tensor iou_threshold(float(0.3));
Tensor score_threshold(float(0.25));
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CombinedNonMaxSuppression")
.Input(test::graph::Constant(g, boxes))
.Input(test::graph::Constant(g, scores))
.Input(test::graph::Constant(g, max_output_size_per_class))
.Input(test::graph::Constant(g, max_total_size))
.Input(test::graph::Constant(g, iou_threshold))
.Input(test::graph::Constant(g, score_threshold))
.Attr("pad_per_class", false)
.Attr("clip_boxes", true)
.Finalize(g, &ret));
return g;
}
#define BM_CombinedNonMaxSuppressionDev(DEVICE, B, BN, CN, Q) \
static void BM_CombinedNMS_##DEVICE##_##B##_##BN##_##CN##_##Q(int iters) { \
testing::ItemsProcessed(iters * B); \
test::Benchmark(#DEVICE, BM_CombinedNonMaxSuppression(B, BN, CN, Q)) \
.Run(iters); \
} \
BENCHMARK(BM_CombinedNMS_##DEVICE##_##B##_##BN##_##CN##_##Q);
BM_CombinedNonMaxSuppressionDev(cpu, 1, 1917, 90, 1);
BM_CombinedNonMaxSuppressionDev(cpu, 28, 1917, 90, 1);
BM_CombinedNonMaxSuppressionDev(cpu, 32, 1917, 90, 1);
BM_CombinedNonMaxSuppressionDev(cpu, 64, 1917, 90, 1);
} // namespace tensorflow
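Note: for readers unfamiliar with the benchmark macro, the first instantiation above expands (modulo whitespace) to:

static void BM_CombinedNMS_cpu_1_1917_90_1(int iters) {
  testing::ItemsProcessed(iters * 1);
  test::Benchmark("cpu", BM_CombinedNonMaxSuppression(1, 1917, 90, 1))
      .Run(iters);
}
BENCHMARK(BM_CombinedNMS_cpu_1_1917_90_1);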