This commit is contained in:
Koan-Sin Tan 2020-03-11 15:04:49 +08:00
parent 2c89d8cd5e
commit 06453ac4db
8 changed files with 64 additions and 26 deletions

View File

@ -29,6 +29,7 @@ cc_binary(
}),
deps = [
":bitmap_helpers",
"//tensorflow/lite/c:common",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
@ -89,6 +90,7 @@ cc_test(
],
deps = [
":bitmap_helpers",
"//tensorflow/lite/c:common",
"@com_google_googletest//:gtest",
],
)

View File

@ -31,10 +31,12 @@ void resize(T* out, uint8_t* in, int image_height, int image_width,
int wanted_channels, Settings* s);
// explicit instantiation
template void resize<uint8_t>(uint8_t*, unsigned char*, int, int, int, int, int,
int, Settings*);
template void resize<float>(float*, unsigned char*, int, int, int, int, int,
int, Settings*);
template void resize<int8_t>(int8_t*, unsigned char*, int, int, int, int, int,
int, Settings*);
template void resize<uint8_t>(uint8_t*, unsigned char*, int, int, int, int, int,
int, Settings*);
} // namespace label_image
} // namespace tflite

View File

@ -83,10 +83,19 @@ void resize(T* out, uint8_t* in, int image_height, int image_width,
auto output_number_of_pixels = wanted_height * wanted_width * wanted_channels;
for (int i = 0; i < output_number_of_pixels; i++) {
if (s->input_floating)
out[i] = (output[i] - s->input_mean) / s->input_std;
else
out[i] = (uint8_t)output[i];
switch (s->input_type) {
case kTfLiteFloat32:
out[i] = (output[i] - s->input_mean) / s->input_std;
break;
case kTfLiteInt8:
out[i] = static_cast<int8_t>(output[i] - 128);
break;
case kTfLiteUInt8:
out[i] = static_cast<uint8_t>(output[i]);
break;
default:
break;
}
}
}

View File

@ -24,13 +24,17 @@ namespace label_image {
template <class T>
void get_top_n(T* prediction, int prediction_size, size_t num_results,
float threshold, std::vector<std::pair<float, int>>* top_results,
bool input_floating);
TfLiteType input_type);
// explicit instantiation so that we can use them elsewhere
template void get_top_n<uint8_t>(uint8_t*, int, size_t, float,
std::vector<std::pair<float, int>>*, bool);
template void get_top_n<float>(float*, int, size_t, float,
std::vector<std::pair<float, int>>*, bool);
std::vector<std::pair<float, int>>*, TfLiteType);
template void get_top_n<int8_t>(int8_t*, int, size_t, float,
std::vector<std::pair<float, int>>*,
TfLiteType);
template void get_top_n<uint8_t>(uint8_t*, int, size_t, float,
std::vector<std::pair<float, int>>*,
TfLiteType);
} // namespace label_image
} // namespace tflite

View File

@ -20,6 +20,8 @@ limitations under the License.
#include <functional>
#include <queue>
#include "tensorflow/lite/c/common.h"
namespace tflite {
namespace label_image {
@ -30,19 +32,29 @@ extern bool input_floating;
template <class T>
void get_top_n(T* prediction, int prediction_size, size_t num_results,
float threshold, std::vector<std::pair<float, int>>* top_results,
bool input_floating) {
TfLiteType input_type) {
// Will contain top N results in ascending order.
std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
std::greater<std::pair<float, int>>>
top_result_pq;
const long count = prediction_size; // NOLINT(runtime/int)
float value = 0.0;
for (int i = 0; i < count; ++i) {
float value;
if (input_floating)
value = prediction[i];
else
value = prediction[i] / 255.0;
switch (input_type) {
case kTfLiteFloat32:
value = prediction[i];
break;
case kTfLiteInt8:
value = (prediction[i] + 128) / 256.0;
break;
case kTfLiteUInt8:
value = prediction[i] / 255.0;
break;
default:
break;
}
// Only add it if it beats the threshold and has a chance at being in
// the top N.
if (value < threshold) {

View File

@ -236,13 +236,18 @@ void RunInference(Settings* s) {
int wanted_width = dims->data[2];
int wanted_channels = dims->data[3];
switch (interpreter->tensor(input)->type) {
s->input_type = interpreter->tensor(input)->type;
switch (s->input_type) {
case kTfLiteFloat32:
s->input_floating = true;
resize<float>(interpreter->typed_tensor<float>(input), in.data(),
image_height, image_width, image_channels, wanted_height,
wanted_width, wanted_channels, s);
break;
case kTfLiteInt8:
resize<int8_t>(interpreter->typed_tensor<int8_t>(input), in.data(),
image_height, image_width, image_channels, wanted_height,
wanted_width, wanted_channels, s);
break;
case kTfLiteUInt8:
resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(),
image_height, image_width, image_channels, wanted_height,
@ -253,7 +258,6 @@ void RunInference(Settings* s) {
<< interpreter->tensor(input)->type << " yet";
exit(-1);
}
auto profiler =
absl::make_unique<profiling::Profiler>(s->max_profiling_buffer_entries);
interpreter->SetProfiler(profiler.get());
@ -305,12 +309,18 @@ void RunInference(Settings* s) {
switch (interpreter->tensor(output)->type) {
case kTfLiteFloat32:
get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
s->number_of_results, threshold, &top_results, true);
s->number_of_results, threshold, &top_results,
s->input_type);
break;
case kTfLiteInt8:
get_top_n<int8_t>(interpreter->typed_output_tensor<int8_t>(0),
output_size, s->number_of_results, threshold,
&top_results, s->input_type);
break;
case kTfLiteUInt8:
get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
output_size, s->number_of_results, threshold,
&top_results, false);
&top_results, s->input_type);
break;
default:
LOG(FATAL) << "cannot handle output type "

View File

@ -26,7 +26,7 @@ struct Settings {
bool verbose = false;
bool accel = false;
bool old_accel = false;
bool input_floating = false;
TfLiteType input_type = kTfLiteFloat32;
bool profiling = false;
bool allow_fp16 = false;
bool gl_backend = false;
@ -38,7 +38,6 @@ struct Settings {
tflite::FlatBufferModel* model;
string input_bmp_name = "./grace_hopper.bmp";
string labels_file_name = "./labels.txt";
string input_layer_type = "uint8_t";
int number_of_threads = 4;
int number_of_results = 5;
int max_profiling_buffer_entries = 1024;

View File

@ -37,15 +37,15 @@ TEST(LabelImageTest, GraceHopper) {
std::vector<uint8_t> output(606 * 517 * 3);
resize<uint8_t>(output.data(), input.data(), 606, 517, 3, 214, 214, 3, &s);
ASSERT_EQ(output[0], 0x15);
ASSERT_EQ(output[214 * 214 * 3 - 1], 0x11);
ASSERT_EQ(output[0], 0x0);
ASSERT_EQ(output[214 * 214 * 3 - 1], 0x0);
}
TEST(LabelImageTest, GetTopN) {
uint8_t in[] = {1, 1, 2, 2, 4, 4, 16, 32, 128, 64};
std::vector<std::pair<float, int>> top_results;
get_top_n<uint8_t>(in, 10, 5, 0.025, &top_results, false);
get_top_n<uint8_t>(in, 10, 5, 0.025, &top_results, kTfLiteUInt8);
ASSERT_EQ(top_results.size(), 4);
ASSERT_EQ(top_results[0].second, 8);
}