Add support for int8 input type in tflite_driver.
PiperOrigin-RevId: 241600181
This commit is contained in:
parent
ebf22aecde
commit
d2ebdd72af
BIN
tensorflow/lite/testdata/add_quantized_int8.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/add_quantized_int8.bin
vendored
Normal file
Binary file not shown.
123
tensorflow/lite/testdata/add_quantized_int8.json
vendored
Normal file
123
tensorflow/lite/testdata/add_quantized_int8.json
vendored
Normal file
@ -0,0 +1,123 @@
|
||||
{
|
||||
version: 3,
|
||||
operator_codes: [
|
||||
{
|
||||
}
|
||||
],
|
||||
subgraphs: [
|
||||
{
|
||||
tensors: [
|
||||
{
|
||||
shape: [
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
3
|
||||
],
|
||||
name: "add",
|
||||
quantization: {
|
||||
min: [
|
||||
0.0
|
||||
],
|
||||
max: [
|
||||
1.0
|
||||
],
|
||||
scale: [
|
||||
0.003922
|
||||
],
|
||||
zero_point: [
|
||||
0
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
shape: [
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
3
|
||||
],
|
||||
type: "INT8",
|
||||
name: "input",
|
||||
quantization: {
|
||||
min: [
|
||||
0.0
|
||||
],
|
||||
max: [
|
||||
1.0
|
||||
],
|
||||
scale: [
|
||||
0.003922
|
||||
],
|
||||
zero_point: [
|
||||
0
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
shape: [
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
3
|
||||
],
|
||||
type: "INT8",
|
||||
name: "output",
|
||||
quantization: {
|
||||
min: [
|
||||
0.0
|
||||
],
|
||||
max: [
|
||||
1.0
|
||||
],
|
||||
scale: [
|
||||
0.003922
|
||||
],
|
||||
zero_point: [
|
||||
0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
inputs: [
|
||||
1
|
||||
],
|
||||
outputs: [
|
||||
2
|
||||
],
|
||||
operators: [
|
||||
{
|
||||
inputs: [
|
||||
1,
|
||||
1
|
||||
],
|
||||
outputs: [
|
||||
0
|
||||
],
|
||||
builtin_options_type: "AddOptions",
|
||||
builtin_options: {
|
||||
}
|
||||
},
|
||||
{
|
||||
inputs: [
|
||||
0,
|
||||
1
|
||||
],
|
||||
outputs: [
|
||||
2
|
||||
],
|
||||
builtin_options_type: "AddOptions",
|
||||
builtin_options: {
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
buffers: [
|
||||
{
|
||||
data: [
|
||||
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
@ -193,7 +193,10 @@ tf_cc_test(
|
||||
name = "tflite_driver_test",
|
||||
size = "small",
|
||||
srcs = ["tflite_driver_test.cc"],
|
||||
data = ["//tensorflow/lite:testdata/multi_add.bin"],
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/add_quantized_int8.bin",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
],
|
||||
tags = [
|
||||
"tflite_not_portable_android",
|
||||
"tflite_not_portable_ios",
|
||||
|
@ -68,6 +68,21 @@ inline string Join<uint8_t>(uint8_t* data, size_t len,
|
||||
return result.str();
|
||||
}
|
||||
|
||||
// Join a list of int8 data separated by a delimiter. Cast data to int before
|
||||
// placing it in the string to prevent values from being treated like chars.
|
||||
template <>
|
||||
inline string Join<int8_t>(int8_t* data, size_t len, const string& delimiter) {
|
||||
if (len == 0 || data == nullptr) {
|
||||
return "";
|
||||
}
|
||||
std::stringstream result;
|
||||
result << static_cast<int>(data[0]);
|
||||
for (int i = 1; i < len; i++) {
|
||||
result << delimiter << static_cast<int>(data[i]);
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -80,6 +80,15 @@ inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
|
||||
return fields;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<int8_t> Split(const string& s, const string& delimiter) {
|
||||
std::vector<int8_t> fields;
|
||||
for (const auto& p : SplitToPos(s, delimiter)) {
|
||||
fields.push_back(strtol(s.data() + p.first, nullptr, 10));
|
||||
}
|
||||
return fields;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<bool> Split(const string& s, const string& delimiter) {
|
||||
std::vector<bool> fields;
|
||||
|
@ -14,8 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/testing/tflite_driver.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "absl/strings/escaping.h"
|
||||
#include "tensorflow/lite/builtin_op_data.h"
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
@ -50,6 +48,10 @@ uint8_t Value(const TfLitePtrUnion& data, int index) {
|
||||
return data.uint8[index];
|
||||
}
|
||||
template <>
|
||||
int8_t Value(const TfLitePtrUnion& data, int index) {
|
||||
return data.int8[index];
|
||||
}
|
||||
template <>
|
||||
bool Value(const TfLitePtrUnion& data, int index) {
|
||||
return data.b[index];
|
||||
}
|
||||
@ -184,6 +186,8 @@ bool TfLiteDriver::Expectation::Check(bool verbose,
|
||||
return TypedCheck<int64_t>(verbose, tensor);
|
||||
case kTfLiteUInt8:
|
||||
return TypedCheck<uint8_t>(verbose, tensor);
|
||||
case kTfLiteInt8:
|
||||
return TypedCheck<int8_t>(verbose, tensor);
|
||||
case kTfLiteBool:
|
||||
return TypedCheck<bool>(verbose, tensor);
|
||||
case kTfLiteString:
|
||||
@ -295,6 +299,12 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) {
|
||||
SetTensorData(values, &tensor->data);
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt8: {
|
||||
const auto& values = testing::Split<int8_t>(csv_values, ",");
|
||||
if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
|
||||
SetTensorData(values, &tensor->data);
|
||||
break;
|
||||
}
|
||||
case kTfLiteBool: {
|
||||
const auto& values = testing::Split<bool>(csv_values, ",");
|
||||
if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
|
||||
@ -338,6 +348,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
|
||||
case kTfLiteUInt8:
|
||||
expected_output_[id]->SetData<uint8_t>(csv_values);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
expected_output_[id]->SetData<int8_t>(csv_values);
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
expected_output_[id]->SetData<bool>(csv_values);
|
||||
break;
|
||||
@ -402,7 +415,7 @@ string TfLiteDriver::ReadOutput(int id) {
|
||||
case kTfLiteUInt8:
|
||||
return Join(tensor->data.uint8, num_elements, ",");
|
||||
case kTfLiteInt8:
|
||||
return JoinDefault(tensor->data.int8, num_elements, ",");
|
||||
return Join(tensor->data.int8, num_elements, ",");
|
||||
case kTfLiteBool:
|
||||
return JoinDefault(tensor->data.b, num_elements, ",");
|
||||
default:
|
||||
|
@ -94,6 +94,32 @@ TEST(TfliteDriverTest, SingleAddOpTest) {
|
||||
EXPECT_EQ(runner->ReadOutput(6), "0.011,0.022,0.033,0.044");
|
||||
}
|
||||
|
||||
TEST(TfliteDriverTest, AddQuantizedInt8Test) {
|
||||
std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
|
||||
|
||||
runner->SetModelBaseDir("tensorflow/lite");
|
||||
runner->LoadModel("testdata/add_quantized_int8.bin");
|
||||
ASSERT_TRUE(runner->IsValid());
|
||||
|
||||
ASSERT_THAT(runner->GetInputs(), ElementsAre(1));
|
||||
ASSERT_THAT(runner->GetOutputs(), ElementsAre(2));
|
||||
|
||||
runner->ReshapeTensor(1, "1,2,2,1");
|
||||
ASSERT_TRUE(runner->IsValid());
|
||||
|
||||
runner->AllocateTensors();
|
||||
|
||||
runner->SetInput(1, "1,1,1,1");
|
||||
|
||||
runner->SetExpectation(2, "3,3,3,3");
|
||||
|
||||
runner->Invoke();
|
||||
ASSERT_TRUE(runner->IsValid());
|
||||
|
||||
ASSERT_TRUE(runner->CheckResults());
|
||||
EXPECT_EQ(runner->ReadOutput(2), "3,3,3,3");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user