[Intel MKL] Optimize combinedNMS performance
This commit is contained in:
parent
fe03adf6e6
commit
0be4b608c0
tensorflow/core/kernels
@ -3245,6 +3245,22 @@ tf_cc_tests(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "non_max_suppression_op_benchmark_test",
|
||||
srcs = ["non_max_suppression_op_benchmark_test.cc"],
|
||||
deps = [
|
||||
":image",
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "resize_bilinear_op_test",
|
||||
srcs = ["resize_bilinear_op_test.cc"],
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -33,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -284,14 +284,179 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
|
||||
}
|
||||
}
|
||||
|
||||
struct ResultCandidate {
|
||||
int box_index;
|
||||
float score;
|
||||
int class_idx;
|
||||
float box_coord[4];
|
||||
};
|
||||
|
||||
void DoNMS(int batch_idx, int class_idx, const float* boxes_data,
|
||||
const float* scores_data, int num_boxes, int q, int num_classes,
|
||||
const int size_per_class, const float score_threshold,
|
||||
const float iou_threshold,
|
||||
std::vector<ResultCandidate>& result_candidate_vec) {
|
||||
std::vector<float> class_scores_data;
|
||||
class_scores_data.reserve(num_boxes);
|
||||
std::vector<float> class_boxes_data;
|
||||
class_boxes_data.reserve(num_boxes * 4);
|
||||
|
||||
for (int box_idx = 0; box_idx < num_boxes; ++box_idx) {
|
||||
class_scores_data.push_back(scores_data[box_idx * num_classes + class_idx]);
|
||||
for (int cid = 0; cid < 4; ++cid) {
|
||||
if (q > 1) {
|
||||
class_boxes_data.push_back(
|
||||
boxes_data[(box_idx * q + class_idx) * 4 + cid]);
|
||||
} else {
|
||||
class_boxes_data.push_back(boxes_data[box_idx * 4 + cid]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy class_boxes_data to a tensor
|
||||
TensorShape boxesShape({num_boxes, 4});
|
||||
Tensor boxes(DT_FLOAT, boxesShape);
|
||||
std::copy_n(class_boxes_data.begin(), class_boxes_data.size(),
|
||||
boxes.unaligned_flat<float>().data());
|
||||
|
||||
// Do NMS, get the candidate indices of form vector<int>
|
||||
// Data structure for selection candidate in NMS.
|
||||
struct Candidate {
|
||||
int box_index;
|
||||
float score;
|
||||
};
|
||||
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
|
||||
return bs_i.score > bs_j.score;
|
||||
};
|
||||
std::vector<Candidate> candidate_vector;
|
||||
for (int i = 0; i < class_scores_data.size(); ++i) {
|
||||
if (class_scores_data[i] > score_threshold) {
|
||||
candidate_vector.emplace_back(Candidate({i, class_scores_data[i]}));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> selected;
|
||||
std::vector<float> selected_boxes;
|
||||
Candidate next_candidate;
|
||||
|
||||
std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
|
||||
const Tensor const_boxes = boxes;
|
||||
typename TTypes<float, 2>::ConstTensor boxes_data_1 =
|
||||
const_boxes.tensor<float, 2>();
|
||||
int candidate_idx = 0;
|
||||
float iou;
|
||||
while (selected.size() < size_per_class &&
|
||||
candidate_idx < candidate_vector.size()) {
|
||||
next_candidate = candidate_vector[candidate_idx++];
|
||||
|
||||
// Overlapping boxes are likely to have similar scores,
|
||||
// therefore we iterate through the previously selected boxes backwards
|
||||
// in order to see if `next_candidate` should be suppressed.
|
||||
bool should_select = true;
|
||||
for (int j = selected.size() - 1; j >= 0; --j) {
|
||||
iou = IOU<float>(boxes_data_1, next_candidate.box_index, selected[j]);
|
||||
if (iou > iou_threshold) {
|
||||
should_select = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (should_select) {
|
||||
// Add the selected box to the result candidate. Sorted by score
|
||||
int id = next_candidate.box_index;
|
||||
auto& rc =
|
||||
result_candidate_vec[selected.size() + size_per_class * class_idx];
|
||||
selected.push_back(next_candidate.box_index);
|
||||
rc.box_index = next_candidate.box_index;
|
||||
rc.score = next_candidate.score;
|
||||
rc.class_idx = class_idx;
|
||||
rc.box_coord[0] = boxes_data_1(id, 0);
|
||||
rc.box_coord[1] = boxes_data_1(id, 1);
|
||||
rc.box_coord[2] = boxes_data_1(id, 2);
|
||||
rc.box_coord[3] = boxes_data_1(id, 3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SelectResultPerBatch(std::vector<float>& nmsed_boxes,
|
||||
std::vector<float>& nmsed_scores,
|
||||
std::vector<float>& nmsed_classes,
|
||||
std::vector<ResultCandidate>& result_candidate_vec,
|
||||
std::vector<int>& final_valid_detections,
|
||||
const int batch_idx, int total_size_per_batch,
|
||||
bool pad_per_class, int max_size_per_batch,
|
||||
bool clip_boxes, int per_batch_size) {
|
||||
auto rc_cmp = [](const ResultCandidate rc_i, const ResultCandidate rc_j) {
|
||||
return rc_i.score > rc_j.score;
|
||||
};
|
||||
std::sort(result_candidate_vec.begin(), result_candidate_vec.end(), rc_cmp);
|
||||
|
||||
int max_detections = 0;
|
||||
int result_candidate_size =
|
||||
std::count_if(result_candidate_vec.begin(), result_candidate_vec.end(),
|
||||
[](ResultCandidate rc) { return rc.box_index > -1; });
|
||||
// If pad_per_class is false, we always pad to max_total_size
|
||||
if (!pad_per_class) {
|
||||
max_detections = std::min(result_candidate_size, total_size_per_batch);
|
||||
} else {
|
||||
max_detections = std::min(per_batch_size, result_candidate_size);
|
||||
}
|
||||
|
||||
final_valid_detections[batch_idx] = max_detections;
|
||||
|
||||
int curr_total_size = max_detections;
|
||||
int result_idx = 0;
|
||||
// Pick the top max_detections values
|
||||
while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
|
||||
ResultCandidate next_candidate = result_candidate_vec[result_idx++];
|
||||
// Add to final output vectors
|
||||
if (clip_boxes) {
|
||||
const float box_min = 0.0;
|
||||
const float box_max = 1.0;
|
||||
nmsed_boxes.push_back(
|
||||
std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
|
||||
nmsed_boxes.push_back(
|
||||
std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
|
||||
nmsed_boxes.push_back(
|
||||
std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
|
||||
nmsed_boxes.push_back(
|
||||
std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
|
||||
} else {
|
||||
nmsed_boxes.push_back(next_candidate.box_coord[0]);
|
||||
nmsed_boxes.push_back(next_candidate.box_coord[1]);
|
||||
nmsed_boxes.push_back(next_candidate.box_coord[2]);
|
||||
nmsed_boxes.push_back(next_candidate.box_coord[3]);
|
||||
}
|
||||
nmsed_scores.push_back(next_candidate.score);
|
||||
nmsed_classes.push_back(next_candidate.class_idx);
|
||||
curr_total_size--;
|
||||
}
|
||||
|
||||
nmsed_boxes.resize(per_batch_size * 4, 0);
|
||||
nmsed_scores.resize(per_batch_size, 0);
|
||||
nmsed_classes.resize(per_batch_size, 0);
|
||||
}
|
||||
|
||||
void BatchedNonMaxSuppressionOp(
|
||||
OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
|
||||
int num_boxes, const int max_size_per_class, const int total_size_per_batch,
|
||||
const float score_threshold, const float iou_threshold,
|
||||
bool pad_per_class = false, bool clip_boxes = true) {
|
||||
int q = inp_boxes.dim_size(2);
|
||||
int num_classes = inp_scores.dim_size(2);
|
||||
const int num_batches = inp_boxes.dim_size(0);
|
||||
int num_classes = inp_scores.dim_size(2);
|
||||
int q = inp_boxes.dim_size(2);
|
||||
|
||||
const float* scores_data =
|
||||
const_cast<float*>(inp_scores.flat<float>().data());
|
||||
const float* boxes_data = const_cast<float*>(inp_boxes.flat<float>().data());
|
||||
|
||||
int boxes_per_batch = num_boxes * q * 4;
|
||||
int scores_per_batch = num_boxes * num_classes;
|
||||
const int size_per_class = std::min(max_size_per_class, num_boxes);
|
||||
std::vector<std::vector<ResultCandidate>> result_candidate_vec(
|
||||
num_batches,
|
||||
std::vector<ResultCandidate>(size_per_class * num_classes,
|
||||
{-1, -1.0, -1, {0.0, 0.0, 0.0, 0.0}}));
|
||||
|
||||
// [num_batches, per_batch_size * 4]
|
||||
std::vector<std::vector<float>> nmsed_boxes(num_batches);
|
||||
@ -300,166 +465,72 @@ void BatchedNonMaxSuppressionOp(
|
||||
// [num_batches, per_batch_size]
|
||||
std::vector<std::vector<float>> nmsed_classes(num_batches);
|
||||
// [num_batches]
|
||||
std::vector<int> final_valid_detections;
|
||||
std::vector<int> final_valid_detections(num_batches);
|
||||
|
||||
auto shard_nms = [&](int begin, int end) {
|
||||
int boxes_per_batch = num_boxes * q * 4;
|
||||
int scores_per_batch = num_boxes * num_classes;
|
||||
for (int idx = begin; idx < end; ++idx) {
|
||||
int batch_idx = idx / num_classes;
|
||||
int class_idx = idx % num_classes;
|
||||
DoNMS(batch_idx, class_idx, boxes_data + boxes_per_batch * batch_idx,
|
||||
scores_data + scores_per_batch * batch_idx, num_boxes, q,
|
||||
num_classes, size_per_class, score_threshold, iou_threshold,
|
||||
result_candidate_vec[batch_idx]);
|
||||
}
|
||||
};
|
||||
|
||||
int length = num_batches * num_classes;
|
||||
// Input data boxes_data, scores_data
|
||||
int input_bytes = length * num_boxes * 5;
|
||||
int output_bytes = length * num_boxes * 5;
|
||||
int compute_cycles = (Eigen::TensorOpCost::AddCost<int>() * 5 +
|
||||
Eigen::TensorOpCost::MulCost<int>() * 2 +
|
||||
Eigen::TensorOpCost::AddCost<float>() * 10 +
|
||||
Eigen::TensorOpCost::MulCost<float>() * 6 +
|
||||
Eigen::TensorOpCost::DivCost<float>()) *
|
||||
length;
|
||||
const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
|
||||
const CPUDevice& d = context->eigen_device<CPUDevice>();
|
||||
d.parallelFor(length, cost, shard_nms);
|
||||
|
||||
int per_batch_size = total_size_per_batch;
|
||||
|
||||
// perform non_max_suppression operation for each batch independently
|
||||
for (int batch = 0; batch < num_batches; ++batch) {
|
||||
// dims of per_batch_boxes [num_boxes, q, 4]
|
||||
Tensor per_batch_boxes = inp_boxes.Slice(batch, batch + 1);
|
||||
// dims of per_batch_scores [num_boxes, num_classes]
|
||||
Tensor per_batch_scores = inp_scores.Slice(batch, batch + 1);
|
||||
|
||||
struct ResultCandidate {
|
||||
int box_index;
|
||||
float score;
|
||||
int class_idx;
|
||||
float box_coord[4];
|
||||
};
|
||||
|
||||
std::vector<ResultCandidate> result_candidate_vec;
|
||||
|
||||
float* scores_data = per_batch_scores.unaligned_flat<float>().data();
|
||||
float* boxes_data = per_batch_boxes.unaligned_flat<float>().data();
|
||||
|
||||
// Iterate through all classes
|
||||
for (int class_idx = 0; class_idx < num_classes; ++class_idx) {
|
||||
std::vector<float> class_scores_data;
|
||||
class_scores_data.reserve(num_boxes);
|
||||
std::vector<float> class_boxes_data;
|
||||
class_boxes_data.reserve(num_boxes * 4);
|
||||
|
||||
for (int box = 0; box < num_boxes; ++box) {
|
||||
// Get the scores per class
|
||||
// class_scores_data dim is [num_boxes].
|
||||
class_scores_data.push_back(scores_data[box * num_classes + class_idx]);
|
||||
for (int cid = 0; cid < 4; ++cid) {
|
||||
if (q > 1) {
|
||||
// Get the boxes per class. class_boxes_data dims is [num_boxes, 4]
|
||||
class_boxes_data.push_back(
|
||||
boxes_data[(box * q + class_idx) * 4 + cid]);
|
||||
} else {
|
||||
class_boxes_data.push_back(boxes_data[box * 4 + cid]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy class_boxes_data to a tensor
|
||||
TensorShape boxesShape({num_boxes, 4});
|
||||
Tensor boxes(per_batch_boxes.dtype(), boxesShape);
|
||||
std::copy_n(class_boxes_data.begin(), class_boxes_data.size(),
|
||||
boxes.unaligned_flat<float>().data());
|
||||
|
||||
const int size_per_class = std::min(max_size_per_class, num_boxes);
|
||||
// Do NMS, get the candidate indices of form vector<int>
|
||||
// Data structure for selection candidate in NMS.
|
||||
struct Candidate {
|
||||
int box_index;
|
||||
float score;
|
||||
};
|
||||
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
|
||||
return bs_i.score > bs_j.score;
|
||||
};
|
||||
std::vector<Candidate> candidate_vector;
|
||||
for (int i = 0; i < class_scores_data.size(); ++i) {
|
||||
if (class_scores_data[i] > score_threshold) {
|
||||
candidate_vector.emplace_back(Candidate({i, class_scores_data[i]}));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> selected;
|
||||
Candidate next_candidate;
|
||||
|
||||
std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
|
||||
const Tensor const_boxes = boxes;
|
||||
typename TTypes<float, 2>::ConstTensor boxes_data =
|
||||
const_boxes.tensor<float, 2>();
|
||||
int candidate_idx = 0;
|
||||
float iou;
|
||||
while (selected.size() < size_per_class &&
|
||||
candidate_idx < candidate_vector.size()) {
|
||||
next_candidate = candidate_vector[candidate_idx++];
|
||||
|
||||
// Overlapping boxes are likely to have similar scores,
|
||||
// therefore we iterate through the previously selected boxes backwards
|
||||
// in order to see if `next_candidate` should be suppressed.
|
||||
bool should_select = true;
|
||||
for (int j = selected.size() - 1; j >= 0; --j) {
|
||||
iou = IOU<float>(boxes_data, next_candidate.box_index, selected[j]);
|
||||
if (iou > iou_threshold) {
|
||||
should_select = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (should_select) {
|
||||
selected.push_back(next_candidate.box_index);
|
||||
// Add the selected box to the result candidate. Sorted by score
|
||||
int id = next_candidate.box_index;
|
||||
ResultCandidate rc = {next_candidate.box_index,
|
||||
next_candidate.score,
|
||||
class_idx,
|
||||
{boxes_data(id, 0), boxes_data(id, 1),
|
||||
boxes_data(id, 2), boxes_data(id, 3)}};
|
||||
result_candidate_vec.push_back(rc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto rc_cmp = [](const ResultCandidate rc_i, const ResultCandidate rc_j) {
|
||||
return rc_i.score > rc_j.score;
|
||||
};
|
||||
std::sort(result_candidate_vec.begin(), result_candidate_vec.end(), rc_cmp);
|
||||
|
||||
int max_detections = 0;
|
||||
// If pad_per_class is false, we always pad to max_total_size
|
||||
if (!pad_per_class) {
|
||||
max_detections =
|
||||
std::min((int)result_candidate_vec.size(), total_size_per_batch);
|
||||
per_batch_size = total_size_per_batch;
|
||||
} else {
|
||||
per_batch_size =
|
||||
std::min(total_size_per_batch, max_size_per_class * num_classes);
|
||||
max_detections =
|
||||
std::min(per_batch_size, (int)result_candidate_vec.size());
|
||||
}
|
||||
|
||||
final_valid_detections.push_back(max_detections);
|
||||
|
||||
int curr_total_size = max_detections;
|
||||
int result_idx = 0;
|
||||
// Pick the top max_detections values
|
||||
while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
|
||||
ResultCandidate next_candidate = result_candidate_vec[result_idx++];
|
||||
// Add to final output vectors
|
||||
if (clip_boxes) {
|
||||
const float box_min = 0.0;
|
||||
const float box_max = 1.0;
|
||||
nmsed_boxes[batch].push_back(
|
||||
std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
|
||||
nmsed_boxes[batch].push_back(
|
||||
std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
|
||||
nmsed_boxes[batch].push_back(
|
||||
std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
|
||||
nmsed_boxes[batch].push_back(
|
||||
std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
|
||||
} else {
|
||||
nmsed_boxes[batch].push_back(next_candidate.box_coord[0]);
|
||||
nmsed_boxes[batch].push_back(next_candidate.box_coord[1]);
|
||||
nmsed_boxes[batch].push_back(next_candidate.box_coord[2]);
|
||||
nmsed_boxes[batch].push_back(next_candidate.box_coord[3]);
|
||||
}
|
||||
nmsed_scores[batch].push_back(next_candidate.score);
|
||||
nmsed_classes[batch].push_back(next_candidate.class_idx);
|
||||
curr_total_size--;
|
||||
}
|
||||
|
||||
nmsed_boxes[batch].resize(per_batch_size * 4, 0);
|
||||
nmsed_scores[batch].resize(per_batch_size, 0);
|
||||
nmsed_classes[batch].resize(per_batch_size, 0);
|
||||
if (pad_per_class) {
|
||||
per_batch_size =
|
||||
std::min(total_size_per_batch, max_size_per_class * num_classes);
|
||||
}
|
||||
|
||||
Tensor* valid_detections_t = nullptr;
|
||||
TensorShape valid_detections_shape({num_batches});
|
||||
OP_REQUIRES_OK(context, context->allocate_output(3, valid_detections_shape,
|
||||
&valid_detections_t));
|
||||
auto valid_detections_flat = valid_detections_t->template flat<int>();
|
||||
|
||||
auto shard_result = [&](int begin, int end) {
|
||||
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
|
||||
SelectResultPerBatch(
|
||||
nmsed_boxes[batch_idx], nmsed_scores[batch_idx],
|
||||
nmsed_classes[batch_idx], result_candidate_vec[batch_idx],
|
||||
final_valid_detections, batch_idx, total_size_per_batch,
|
||||
pad_per_class, max_size_per_class * num_classes, clip_boxes,
|
||||
per_batch_size);
|
||||
valid_detections_flat(batch_idx) = final_valid_detections[batch_idx];
|
||||
}
|
||||
};
|
||||
length = num_batches;
|
||||
// Input data boxes_data, scores_data
|
||||
input_bytes = length * num_boxes * 5;
|
||||
output_bytes = length * num_boxes * 5;
|
||||
compute_cycles = (Eigen::TensorOpCost::AddCost<int>() * 5 +
|
||||
Eigen::TensorOpCost::MulCost<int>() * 2 +
|
||||
Eigen::TensorOpCost::AddCost<float>() * 10 +
|
||||
Eigen::TensorOpCost::MulCost<float>() * 6 +
|
||||
Eigen::TensorOpCost::DivCost<float>()) *
|
||||
length;
|
||||
const Eigen::TensorOpCost cost_result(input_bytes, output_bytes,
|
||||
compute_cycles);
|
||||
d.parallelFor(length, cost_result, shard_result);
|
||||
|
||||
Tensor* nmsed_boxes_t = nullptr;
|
||||
TensorShape boxes_shape({num_batches, per_batch_size, 4});
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -477,23 +548,30 @@ void BatchedNonMaxSuppressionOp(
|
||||
context->allocate_output(2, scores_shape, &nmsed_classes_t));
|
||||
auto nmsed_classes_flat = nmsed_classes_t->template flat<float>();
|
||||
|
||||
Tensor* valid_detections_t = nullptr;
|
||||
TensorShape valid_detections_shape({num_batches});
|
||||
OP_REQUIRES_OK(context, context->allocate_output(3, valid_detections_shape,
|
||||
&valid_detections_t));
|
||||
auto valid_detections_flat = valid_detections_t->template flat<int>();
|
||||
|
||||
for (int i = 0; i < num_batches; ++i) {
|
||||
valid_detections_flat(i) = final_valid_detections[i];
|
||||
for (int j = 0; j < per_batch_size; ++j) {
|
||||
nmsed_scores_flat(i * per_batch_size + j) = nmsed_scores[i][j];
|
||||
nmsed_classes_flat(i * per_batch_size + j) = nmsed_classes[i][j];
|
||||
auto shard_copy_result = [&](int begin, int end) {
|
||||
for (int idx = begin; idx < end; ++idx) {
|
||||
int batch_idx = idx / per_batch_size;
|
||||
int j = idx % per_batch_size;
|
||||
nmsed_scores_flat(idx) = nmsed_scores[batch_idx][j];
|
||||
nmsed_classes_flat(idx) = nmsed_classes[batch_idx][j];
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
nmsed_boxes_flat(i * per_batch_size * 4 + j * 4 + k) =
|
||||
nmsed_boxes[i][j * 4 + k];
|
||||
nmsed_boxes_flat(idx * 4 + k) = nmsed_boxes[batch_idx][j * 4 + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
length = num_batches * per_batch_size;
|
||||
// Input data boxes_data, scores_data
|
||||
input_bytes = length * per_batch_size * 6;
|
||||
output_bytes = length * per_batch_size * 6;
|
||||
compute_cycles = (Eigen::TensorOpCost::AddCost<int>() * 5 +
|
||||
Eigen::TensorOpCost::MulCost<int>() * 2 +
|
||||
Eigen::TensorOpCost::AddCost<float>() * 10 +
|
||||
Eigen::TensorOpCost::MulCost<float>() * 6 +
|
||||
Eigen::TensorOpCost::DivCost<float>()) *
|
||||
length;
|
||||
const Eigen::TensorOpCost cost_copy_result(input_bytes, output_bytes,
|
||||
compute_cycles);
|
||||
d.parallelFor(length, cost_copy_result, shard_copy_result);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -563,9 +641,8 @@ class NonMaxSuppressionV2Op : public OpKernel {
|
||||
iou_threshold.shape().DebugString()));
|
||||
const T iou_threshold_val = iou_threshold.scalar<T>()();
|
||||
|
||||
OP_REQUIRES(context,
|
||||
iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
iou_threshold_val <= static_cast<T>(1.0),
|
||||
OP_REQUIRES(context, iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
iou_threshold_val <= static_cast<T>(1.0),
|
||||
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
|
||||
int num_boxes = 0;
|
||||
ParseAndCheckBoxSizes(context, boxes, &num_boxes);
|
||||
@ -606,9 +683,8 @@ class NonMaxSuppressionV3Op : public OpKernel {
|
||||
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
|
||||
iou_threshold.shape().DebugString()));
|
||||
const T iou_threshold_val = iou_threshold.scalar<T>()();
|
||||
OP_REQUIRES(context,
|
||||
iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
iou_threshold_val <= static_cast<T>(1.0),
|
||||
OP_REQUIRES(context, iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
iou_threshold_val <= static_cast<T>(1.0),
|
||||
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
|
||||
// score_threshold: scalar
|
||||
const Tensor& score_threshold = context->input(4);
|
||||
@ -660,9 +736,8 @@ class NonMaxSuppressionV4Op : public OpKernel {
|
||||
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
|
||||
iou_threshold.shape().DebugString()));
|
||||
const T iou_threshold_val = iou_threshold.scalar<T>()();
|
||||
OP_REQUIRES(context,
|
||||
iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
iou_threshold_val <= static_cast<T>(1.0),
|
||||
OP_REQUIRES(context, iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
iou_threshold_val <= static_cast<T>(1.0),
|
||||
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
|
||||
// score_threshold: scalar
|
||||
const Tensor& score_threshold = context->input(4);
|
||||
@ -726,9 +801,8 @@ class NonMaxSuppressionV5Op : public OpKernel {
|
||||
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
|
||||
iou_threshold.shape().DebugString()));
|
||||
const T iou_threshold_val = iou_threshold.scalar<T>()();
|
||||
OP_REQUIRES(context,
|
||||
iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
iou_threshold_val <= static_cast<T>(1.0),
|
||||
OP_REQUIRES(context, iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
iou_threshold_val <= static_cast<T>(1.0),
|
||||
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
|
||||
// score_threshold: scalar
|
||||
const Tensor& score_threshold = context->input(4);
|
||||
|
@ -0,0 +1,64 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static Graph* BM_CombinedNonMaxSuppression(int batches, int box_num,
|
||||
int class_num, int q) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor boxes(DT_FLOAT, TensorShape({batches, box_num, q, 4}));
|
||||
boxes.flat<float>().setRandom();
|
||||
Tensor scores(DT_FLOAT, TensorShape({batches, box_num, class_num}));
|
||||
scores.flat<float>().setRandom();
|
||||
|
||||
Tensor max_output_size_per_class(100);
|
||||
Tensor max_total_size(9000);
|
||||
Tensor iou_threshold(float(0.3));
|
||||
Tensor score_threshold(float(0.25));
|
||||
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CombinedNonMaxSuppression")
|
||||
.Input(test::graph::Constant(g, boxes))
|
||||
.Input(test::graph::Constant(g, scores))
|
||||
.Input(test::graph::Constant(g, max_output_size_per_class))
|
||||
.Input(test::graph::Constant(g, max_total_size))
|
||||
.Input(test::graph::Constant(g, iou_threshold))
|
||||
.Input(test::graph::Constant(g, score_threshold))
|
||||
.Attr("pad_per_class", false)
|
||||
.Attr("clip_boxes", true)
|
||||
.Finalize(g, &ret));
|
||||
return g;
|
||||
}
|
||||
|
||||
#define BM_CombinedNonMaxSuppressionDev(DEVICE, B, BN, CN, Q) \
|
||||
static void BM_CombinedNMS_##DEVICE##_##B##_##BN##_##CN##_##Q(int iters) { \
|
||||
testing::ItemsProcessed(iters* B); \
|
||||
test::Benchmark(#DEVICE, BM_CombinedNonMaxSuppression(B, BN, CN, Q)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_CombinedNMS_##DEVICE##_##B##_##BN##_##CN##_##Q);
|
||||
|
||||
BM_CombinedNonMaxSuppressionDev(cpu, 1, 1917, 90, 1);
|
||||
BM_CombinedNonMaxSuppressionDev(cpu, 28, 1917, 90, 1);
|
||||
BM_CombinedNonMaxSuppressionDev(cpu, 32, 1917, 90, 1);
|
||||
BM_CombinedNonMaxSuppressionDev(cpu, 64, 1917, 90, 1);
|
||||
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user