Add support for int8 input type in tflite_driver.

PiperOrigin-RevId: 241600181
Yunlu Li authored on 2019-04-02 14:07:37 -07:00; committed by TensorFlower Gardener
parent ebf22aecde
commit d2ebdd72af
7 changed files with 193 additions and 4 deletions

tensorflow/lite/testdata/add_quantized_int8.bin (binary file not shown)

New test model in JSON form (quantized int8 add):

@@ -0,0 +1,123 @@
+{
+  version: 3,
+  operator_codes: [
+    {
+    }
+  ],
+  subgraphs: [
+    {
+      tensors: [
+        {
+          shape: [
+            1,
+            8,
+            8,
+            3
+          ],
+          type: "INT8",
+          name: "add",
+          quantization: {
+            min: [
+              0.0
+            ],
+            max: [
+              1.0
+            ],
+            scale: [
+              0.003922
+            ],
+            zero_point: [
+              0
+            ]
+          }
+        },
+        {
+          shape: [
+            1,
+            8,
+            8,
+            3
+          ],
+          type: "INT8",
+          name: "input",
+          quantization: {
+            min: [
+              0.0
+            ],
+            max: [
+              1.0
+            ],
+            scale: [
+              0.003922
+            ],
+            zero_point: [
+              0
+            ]
+          }
+        },
+        {
+          shape: [
+            1,
+            8,
+            8,
+            3
+          ],
+          type: "INT8",
+          name: "output",
+          quantization: {
+            min: [
+              0.0
+            ],
+            max: [
+              1.0
+            ],
+            scale: [
+              0.003922
+            ],
+            zero_point: [
+              0
+            ]
+          }
+        }
+      ],
+      inputs: [
+        1
+      ],
+      outputs: [
+        2
+      ],
+      operators: [
+        {
+          inputs: [
+            1,
+            1
+          ],
+          outputs: [
+            0
+          ],
+          builtin_options_type: "AddOptions",
+          builtin_options: {
+          }
+        },
+        {
+          inputs: [
+            0,
+            1
+          ],
+          outputs: [
+            2
+          ],
+          builtin_options_type: "AddOptions",
+          builtin_options: {
+          }
+        }
+      ]
+    }
+  ],
+  buffers: [
+    {
+      data: [
+      ]
+    }
+  ]
+}
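
For reference, the quantization parameters above map a quantized int8 value q to a real value via real = scale * (q - zero_point), with scale 0.003922 (roughly 1/255) and zero_point 0. The graph runs two ADD ops (add = input + input, then output = add + input), so a quantized input of 1 should yield a quantized output of 3, which is exactly what the new driver test below expects for input "1,1,1,1". The standalone sketch below is illustrative only, not part of the commit; it just reproduces that arithmetic.

// Illustrative check of the expected quantized result for the
// add_quantized_int8 model, using the scale/zero_point from the JSON above.
#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.003922f;   // ~1/255, from the tensor quantization params
  const int32_t zero_point = 0;

  const int8_t q_input = 1;                                  // driver input "1"
  const float real_input = scale * (q_input - zero_point);   // ~0.003922
  const float real_output = 3.0f * real_input;               // (input+input)+input
  const int8_t q_output =
      static_cast<int8_t>(real_output / scale + zero_point + 0.5f);  // requantize

  std::printf("quantized output = %d\n", q_output);  // prints 3
  return 0;
}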

tensorflow/lite/testing/BUILD

@@ -193,7 +193,10 @@ tf_cc_test(
     name = "tflite_driver_test",
     size = "small",
     srcs = ["tflite_driver_test.cc"],
-    data = ["//tensorflow/lite:testdata/multi_add.bin"],
+    data = [
+        "//tensorflow/lite:testdata/add_quantized_int8.bin",
+        "//tensorflow/lite:testdata/multi_add.bin",
+    ],
     tags = [
         "tflite_not_portable_android",
         "tflite_not_portable_ios",

tensorflow/lite/testing/join.h

@@ -68,6 +68,21 @@ inline string Join<uint8_t>(uint8_t* data, size_t len,
   return result.str();
 }
+
+// Join a list of int8 data separated by a delimiter. Cast data to int before
+// placing it in the string to prevent values from being treated like chars.
+template <>
+inline string Join<int8_t>(int8_t* data, size_t len, const string& delimiter) {
+  if (len == 0 || data == nullptr) {
+    return "";
+  }
+  std::stringstream result;
+  result << static_cast<int>(data[0]);
+  for (int i = 1; i < len; i++) {
+    result << delimiter << static_cast<int>(data[i]);
+  }
+  return result.str();
+}
 
 }  // namespace testing
 }  // namespace tflite
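
The cast to int is what keeps int8 values from being streamed as raw characters; the ReadOutput change in tflite_driver.cc below likewise switches int8 from JoinDefault to this specialization so outputs render as numbers. A quick hypothetical illustration, not part of the commit, assuming only that join.h from this change is on the include path:

#include <cstdint>
#include <iostream>

#include "tensorflow/lite/testing/join.h"

int main() {
  int8_t data[] = {1, -2, 3};
  // With the specialization's int cast, values are joined as "1,-2,3".
  std::cout << tflite::testing::Join(data, 3, ",") << std::endl;
  return 0;
}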

tensorflow/lite/testing/split.h

@@ -80,6 +80,15 @@ inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
   return fields;
 }
 
+template <>
+inline std::vector<int8_t> Split(const string& s, const string& delimiter) {
+  std::vector<int8_t> fields;
+  for (const auto& p : SplitToPos(s, delimiter)) {
+    fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+  }
+  return fields;
+}
+
 template <>
 inline std::vector<bool> Split(const string& s, const string& delimiter) {
   std::vector<bool> fields;

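The new Split specialization is how TfLiteDriver::SetInput turns CSV text into int8 tensor data below. A hypothetical round-trip check, not part of the commit, assuming split.h from this change is on the include path:

#include <cassert>
#include <cstdint>
#include <vector>

#include "tensorflow/lite/testing/split.h"

int main() {
  // "1,-2,3" parses into the int8 values {1, -2, 3}.
  std::vector<int8_t> values = tflite::testing::Split<int8_t>("1,-2,3", ",");
  assert(values.size() == 3);
  assert(values[0] == 1 && values[1] == -2 && values[2] == 3);
  return 0;
}
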
tensorflow/lite/testing/tflite_driver.cc

@@ -14,8 +14,6 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/lite/testing/tflite_driver.h"
 
-#include <iostream>
-
 #include "absl/strings/escaping.h"
 #include "tensorflow/lite/builtin_op_data.h"
 #include "tensorflow/lite/delegates/flex/delegate.h"
@@ -50,6 +48,10 @@ uint8_t Value(const TfLitePtrUnion& data, int index) {
   return data.uint8[index];
 }
+template <>
+int8_t Value(const TfLitePtrUnion& data, int index) {
+  return data.int8[index];
+}
 template <>
 bool Value(const TfLitePtrUnion& data, int index) {
   return data.b[index];
 }
@@ -184,6 +186,8 @@ bool TfLiteDriver::Expectation::Check(bool verbose,
       return TypedCheck<int64_t>(verbose, tensor);
     case kTfLiteUInt8:
       return TypedCheck<uint8_t>(verbose, tensor);
+    case kTfLiteInt8:
+      return TypedCheck<int8_t>(verbose, tensor);
     case kTfLiteBool:
       return TypedCheck<bool>(verbose, tensor);
     case kTfLiteString:
@@ -295,6 +299,12 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) {
       SetTensorData(values, &tensor->data);
       break;
     }
+    case kTfLiteInt8: {
+      const auto& values = testing::Split<int8_t>(csv_values, ",");
+      if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
+      SetTensorData(values, &tensor->data);
+      break;
+    }
     case kTfLiteBool: {
       const auto& values = testing::Split<bool>(csv_values, ",");
       if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
@@ -338,6 +348,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
     case kTfLiteUInt8:
      expected_output_[id]->SetData<uint8_t>(csv_values);
      break;
+    case kTfLiteInt8:
+      expected_output_[id]->SetData<int8_t>(csv_values);
+      break;
     case kTfLiteBool:
       expected_output_[id]->SetData<bool>(csv_values);
       break;
@@ -402,7 +415,7 @@ string TfLiteDriver::ReadOutput(int id) {
     case kTfLiteUInt8:
       return Join(tensor->data.uint8, num_elements, ",");
     case kTfLiteInt8:
-      return JoinDefault(tensor->data.int8, num_elements, ",");
+      return Join(tensor->data.int8, num_elements, ",");
     case kTfLiteBool:
       return JoinDefault(tensor->data.b, num_elements, ",");
     default:

tensorflow/lite/testing/tflite_driver_test.cc

@@ -94,6 +94,32 @@ TEST(TfliteDriverTest, SingleAddOpTest) {
   EXPECT_EQ(runner->ReadOutput(6), "0.011,0.022,0.033,0.044");
 }
 
+TEST(TfliteDriverTest, AddQuantizedInt8Test) {
+  std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
+
+  runner->SetModelBaseDir("tensorflow/lite");
+  runner->LoadModel("testdata/add_quantized_int8.bin");
+  ASSERT_TRUE(runner->IsValid());
+
+  ASSERT_THAT(runner->GetInputs(), ElementsAre(1));
+  ASSERT_THAT(runner->GetOutputs(), ElementsAre(2));
+
+  runner->ReshapeTensor(1, "1,2,2,1");
+  ASSERT_TRUE(runner->IsValid());
+
+  runner->AllocateTensors();
+
+  runner->SetInput(1, "1,1,1,1");
+
+  runner->SetExpectation(2, "3,3,3,3");
+
+  runner->Invoke();
+  ASSERT_TRUE(runner->IsValid());
+
+  ASSERT_TRUE(runner->CheckResults());
+  EXPECT_EQ(runner->ReadOutput(2), "3,3,3,3");
+}
+
 }  // namespace
 }  // namespace testing
 }  // namespace tflite