Add support for int8 input type in tflite_driver.
PiperOrigin-RevId: 241600181
This commit is contained in:
parent
ebf22aecde
commit
d2ebdd72af
BIN
tensorflow/lite/testdata/add_quantized_int8.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/add_quantized_int8.bin
vendored
Normal file
Binary file not shown.
123
tensorflow/lite/testdata/add_quantized_int8.json
vendored
Normal file
123
tensorflow/lite/testdata/add_quantized_int8.json
vendored
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
{
|
||||||
|
version: 3,
|
||||||
|
operator_codes: [
|
||||||
|
{
|
||||||
|
}
|
||||||
|
],
|
||||||
|
subgraphs: [
|
||||||
|
{
|
||||||
|
tensors: [
|
||||||
|
{
|
||||||
|
shape: [
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
8,
|
||||||
|
3
|
||||||
|
],
|
||||||
|
name: "add",
|
||||||
|
quantization: {
|
||||||
|
min: [
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
max: [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
scale: [
|
||||||
|
0.003922
|
||||||
|
],
|
||||||
|
zero_point: [
|
||||||
|
0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
shape: [
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
8,
|
||||||
|
3
|
||||||
|
],
|
||||||
|
type: "INT8",
|
||||||
|
name: "input",
|
||||||
|
quantization: {
|
||||||
|
min: [
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
max: [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
scale: [
|
||||||
|
0.003922
|
||||||
|
],
|
||||||
|
zero_point: [
|
||||||
|
0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
shape: [
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
8,
|
||||||
|
3
|
||||||
|
],
|
||||||
|
type: "INT8",
|
||||||
|
name: "output",
|
||||||
|
quantization: {
|
||||||
|
min: [
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
max: [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
scale: [
|
||||||
|
0.003922
|
||||||
|
],
|
||||||
|
zero_point: [
|
||||||
|
0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
inputs: [
|
||||||
|
1
|
||||||
|
],
|
||||||
|
outputs: [
|
||||||
|
2
|
||||||
|
],
|
||||||
|
operators: [
|
||||||
|
{
|
||||||
|
inputs: [
|
||||||
|
1,
|
||||||
|
1
|
||||||
|
],
|
||||||
|
outputs: [
|
||||||
|
0
|
||||||
|
],
|
||||||
|
builtin_options_type: "AddOptions",
|
||||||
|
builtin_options: {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
inputs: [
|
||||||
|
0,
|
||||||
|
1
|
||||||
|
],
|
||||||
|
outputs: [
|
||||||
|
2
|
||||||
|
],
|
||||||
|
builtin_options_type: "AddOptions",
|
||||||
|
builtin_options: {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
buffers: [
|
||||||
|
{
|
||||||
|
data: [
|
||||||
|
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
@ -193,7 +193,10 @@ tf_cc_test(
|
|||||||
name = "tflite_driver_test",
|
name = "tflite_driver_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["tflite_driver_test.cc"],
|
srcs = ["tflite_driver_test.cc"],
|
||||||
data = ["//tensorflow/lite:testdata/multi_add.bin"],
|
data = [
|
||||||
|
"//tensorflow/lite:testdata/add_quantized_int8.bin",
|
||||||
|
"//tensorflow/lite:testdata/multi_add.bin",
|
||||||
|
],
|
||||||
tags = [
|
tags = [
|
||||||
"tflite_not_portable_android",
|
"tflite_not_portable_android",
|
||||||
"tflite_not_portable_ios",
|
"tflite_not_portable_ios",
|
||||||
|
@ -68,6 +68,21 @@ inline string Join<uint8_t>(uint8_t* data, size_t len,
|
|||||||
return result.str();
|
return result.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Join a list of int8 data separated by a delimiter. Cast data to int before
|
||||||
|
// placing it in the string to prevent values from being treated like chars.
|
||||||
|
template <>
|
||||||
|
inline string Join<int8_t>(int8_t* data, size_t len, const string& delimiter) {
|
||||||
|
if (len == 0 || data == nullptr) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
std::stringstream result;
|
||||||
|
result << static_cast<int>(data[0]);
|
||||||
|
for (int i = 1; i < len; i++) {
|
||||||
|
result << delimiter << static_cast<int>(data[i]);
|
||||||
|
}
|
||||||
|
return result.str();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -80,6 +80,15 @@ inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
|
|||||||
return fields;
|
return fields;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline std::vector<int8_t> Split(const string& s, const string& delimiter) {
|
||||||
|
std::vector<int8_t> fields;
|
||||||
|
for (const auto& p : SplitToPos(s, delimiter)) {
|
||||||
|
fields.push_back(strtol(s.data() + p.first, nullptr, 10));
|
||||||
|
}
|
||||||
|
return fields;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline std::vector<bool> Split(const string& s, const string& delimiter) {
|
inline std::vector<bool> Split(const string& s, const string& delimiter) {
|
||||||
std::vector<bool> fields;
|
std::vector<bool> fields;
|
||||||
|
@ -14,8 +14,6 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/testing/tflite_driver.h"
|
#include "tensorflow/lite/testing/tflite_driver.h"
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "absl/strings/escaping.h"
|
#include "absl/strings/escaping.h"
|
||||||
#include "tensorflow/lite/builtin_op_data.h"
|
#include "tensorflow/lite/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||||
@ -50,6 +48,10 @@ uint8_t Value(const TfLitePtrUnion& data, int index) {
|
|||||||
return data.uint8[index];
|
return data.uint8[index];
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
|
int8_t Value(const TfLitePtrUnion& data, int index) {
|
||||||
|
return data.int8[index];
|
||||||
|
}
|
||||||
|
template <>
|
||||||
bool Value(const TfLitePtrUnion& data, int index) {
|
bool Value(const TfLitePtrUnion& data, int index) {
|
||||||
return data.b[index];
|
return data.b[index];
|
||||||
}
|
}
|
||||||
@ -184,6 +186,8 @@ bool TfLiteDriver::Expectation::Check(bool verbose,
|
|||||||
return TypedCheck<int64_t>(verbose, tensor);
|
return TypedCheck<int64_t>(verbose, tensor);
|
||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
return TypedCheck<uint8_t>(verbose, tensor);
|
return TypedCheck<uint8_t>(verbose, tensor);
|
||||||
|
case kTfLiteInt8:
|
||||||
|
return TypedCheck<int8_t>(verbose, tensor);
|
||||||
case kTfLiteBool:
|
case kTfLiteBool:
|
||||||
return TypedCheck<bool>(verbose, tensor);
|
return TypedCheck<bool>(verbose, tensor);
|
||||||
case kTfLiteString:
|
case kTfLiteString:
|
||||||
@ -295,6 +299,12 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) {
|
|||||||
SetTensorData(values, &tensor->data);
|
SetTensorData(values, &tensor->data);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case kTfLiteInt8: {
|
||||||
|
const auto& values = testing::Split<int8_t>(csv_values, ",");
|
||||||
|
if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
|
||||||
|
SetTensorData(values, &tensor->data);
|
||||||
|
break;
|
||||||
|
}
|
||||||
case kTfLiteBool: {
|
case kTfLiteBool: {
|
||||||
const auto& values = testing::Split<bool>(csv_values, ",");
|
const auto& values = testing::Split<bool>(csv_values, ",");
|
||||||
if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
|
if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
|
||||||
@ -338,6 +348,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
|
|||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
expected_output_[id]->SetData<uint8_t>(csv_values);
|
expected_output_[id]->SetData<uint8_t>(csv_values);
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteInt8:
|
||||||
|
expected_output_[id]->SetData<int8_t>(csv_values);
|
||||||
|
break;
|
||||||
case kTfLiteBool:
|
case kTfLiteBool:
|
||||||
expected_output_[id]->SetData<bool>(csv_values);
|
expected_output_[id]->SetData<bool>(csv_values);
|
||||||
break;
|
break;
|
||||||
@ -402,7 +415,7 @@ string TfLiteDriver::ReadOutput(int id) {
|
|||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
return Join(tensor->data.uint8, num_elements, ",");
|
return Join(tensor->data.uint8, num_elements, ",");
|
||||||
case kTfLiteInt8:
|
case kTfLiteInt8:
|
||||||
return JoinDefault(tensor->data.int8, num_elements, ",");
|
return Join(tensor->data.int8, num_elements, ",");
|
||||||
case kTfLiteBool:
|
case kTfLiteBool:
|
||||||
return JoinDefault(tensor->data.b, num_elements, ",");
|
return JoinDefault(tensor->data.b, num_elements, ",");
|
||||||
default:
|
default:
|
||||||
|
@ -94,6 +94,32 @@ TEST(TfliteDriverTest, SingleAddOpTest) {
|
|||||||
EXPECT_EQ(runner->ReadOutput(6), "0.011,0.022,0.033,0.044");
|
EXPECT_EQ(runner->ReadOutput(6), "0.011,0.022,0.033,0.044");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TfliteDriverTest, AddQuantizedInt8Test) {
|
||||||
|
std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
|
||||||
|
|
||||||
|
runner->SetModelBaseDir("tensorflow/lite");
|
||||||
|
runner->LoadModel("testdata/add_quantized_int8.bin");
|
||||||
|
ASSERT_TRUE(runner->IsValid());
|
||||||
|
|
||||||
|
ASSERT_THAT(runner->GetInputs(), ElementsAre(1));
|
||||||
|
ASSERT_THAT(runner->GetOutputs(), ElementsAre(2));
|
||||||
|
|
||||||
|
runner->ReshapeTensor(1, "1,2,2,1");
|
||||||
|
ASSERT_TRUE(runner->IsValid());
|
||||||
|
|
||||||
|
runner->AllocateTensors();
|
||||||
|
|
||||||
|
runner->SetInput(1, "1,1,1,1");
|
||||||
|
|
||||||
|
runner->SetExpectation(2, "3,3,3,3");
|
||||||
|
|
||||||
|
runner->Invoke();
|
||||||
|
ASSERT_TRUE(runner->IsValid());
|
||||||
|
|
||||||
|
ASSERT_TRUE(runner->CheckResults());
|
||||||
|
EXPECT_EQ(runner->ReadOutput(2), "3,3,3,3");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
Loading…
Reference in New Issue
Block a user