Add int16 support to Dequant.
PiperOrigin-RevId: 258917873
--- a/tensorflow/lite/kernels/dequantize.cc
+++ b/tensorflow/lite/kernels/dequantize.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include <string.h>
 
+#include <cstdint>
 #include <vector>
 
 #include "third_party/eigen3/Eigen/Core"
@@ -64,6 +65,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
                               op_context.input->type == kTfLiteInt8 ||
+                              op_context.input->type == kTfLiteInt16 ||
                               op_context.input->type == kTfLiteFloat16);
 
   op_context.output->type = kTfLiteFloat32;
@@ -95,12 +97,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           GetTensorData<float>(op_context.output));
       break;
     case kTfLiteInt8:
-      reference_integer_ops::Dequantize(
+      reference_integer_ops::Dequantize<int8_t>(
           op_params, GetTensorShape(op_context.input),
           GetTensorData<int8_t>(op_context.input),
           GetTensorShape(op_context.output),
           GetTensorData<float>(op_context.output));
       break;
+    case kTfLiteInt16:
+      reference_integer_ops::Dequantize<int16_t>(
+          op_params, GetTensorShape(op_context.input),
+          GetTensorData<int16_t>(op_context.input),
+          GetTensorShape(op_context.output),
+          GetTensorData<float>(op_context.output));
+      break;
     case kTfLiteFloat16: {
       const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
          GetTensorData<TfLiteFloat16>(op_context.input));
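
With the Prepare() check widened, Eval() now dispatches on four input types: uint8, int8, int16, and float16. The int8 and int16 cases share the single templated reference_integer_ops::Dequantize shown further below, which computes scale * (val - zero_point) per element; the explicit <int8_t> / <int16_t> arguments pin T to the same type passed to GetTensorData, making the dispatch explicit at the call site.
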
--- a/tensorflow/lite/kernels/dequantize_test.cc
+++ b/tensorflow/lite/kernels/dequantize_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
+#include <cstdint>
 
 #include <gtest/gtest.h>
 #include "absl/memory/memory.h"
 #include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/internal/types.h"
@@ -23,6 +24,15 @@ limitations under the License.
 #include "tensorflow/lite/model.h"
 
 namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_DEQUANTIZE();
+
+}  // namespace builtin
+}  // namespace ops
+
 namespace {
 
 using ::testing::ElementsAreArray;
@@ -30,13 +40,17 @@ using ::testing::ElementsAreArray;
 class DequantizeOpModel : public SingleOpModel {
  public:
   DequantizeOpModel(TensorType type, std::initializer_list<int> shape,
-                    float scale, int32_t zero_point) {
+                    float scale, int32_t zero_point, int version) {
     const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
     input_ = AddInput(input_tensor_data);
     output_ = AddOutput({TensorType_FLOAT32, shape});
     SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
                  CreateDequantizeOptions(builder_).Union());
+
+    resolver_ = absl::make_unique<SingleOpResolver>(
+        BuiltinOperator_DEQUANTIZE, ops::builtin::Register_DEQUANTIZE(),
+        version);
     BuildInterpreter({GetShape(input_)});
   }
@@ -54,7 +68,7 @@ class DequantizeOpModel : public SingleOpModel {
 
 TEST(DequantizeOpTest, Uint8) {
   // [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
-  DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127);
+  DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127, 1);
 
   m.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
   m.Invoke();
@@ -65,7 +79,7 @@ TEST(DequantizeOpTest, Uint8) {
 
 TEST(DequantizeOpTest, Int8) {
   // [-63.5, 64] -> scale=0.5, zero_point=-1 for INT8
-  DequantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1);
+  DequantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2);
 
   m.SetInput<int8_t>({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
   m.Invoke();
@@ -75,7 +89,7 @@ TEST(DequantizeOpTest, Int8) {
 }
 
 TEST(DequantizeOpTest, Float16) {
-  DequantizeOpModel m(TensorType_FLOAT16, {2, 3}, 1.0f, 0);
+  DequantizeOpModel m(TensorType_FLOAT16, {2, 3}, 1.0f, 0, 3);
 
   std::vector<Eigen::half> half{Eigen::half{-535.54f}, Eigen::half{-100.0f},
                                 Eigen::half{-1.0f},    Eigen::half{0.f},
@@ -88,5 +102,14 @@ TEST(DequantizeOpTest, Float16) {
                                  /*max_abs_error=*/0.1f)));
 }
 
+TEST(DequantizeOpTest, Int16) {
+  DequantizeOpModel m(TensorType_INT16, {2, 5}, 0.5, -1, 4);
+  m.SetInput<int16_t>({-130, -127, -126, -125, -124, 123, 124, 125, 126, 130});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {-64.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 65.5})));
+}
+
 }  // namespace
 }  // namespace tflite
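
The expected values follow from result = scale * (val - zero_point) with scale = 0.5 and zero_point = -1: the endpoints give 0.5 * (-130 - (-1)) = -64.5 and 0.5 * (130 - (-1)) = 65.5. Note that -130 and 130 fall outside int8's [-128, 127] range, so the test covers values only the 16-bit path can represent.
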
--- a/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h
@@ -22,15 +22,16 @@ limitations under the License.
 namespace tflite {
 namespace reference_integer_ops {
 
+template <typename T>
 inline void Dequantize(const tflite::DequantizationParams& op_params,
-                       const RuntimeShape& input_shape, const int8* input_data,
+                       const RuntimeShape& input_shape, const T* input_data,
                        const RuntimeShape& output_shape, float* output_data) {
   const int32 zero_point = op_params.zero_point;
   const double scale = op_params.scale;
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
-    const int32 val = input_data[i];
+    const int32 val = static_cast<int32>(input_data[i]);
     const float result = static_cast<float>(scale * (val - zero_point));
     output_data[i] = result;
   }
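
For a quick sanity check of the templated kernel, here is a minimal, self-contained sketch that reproduces the int16 test vector above. Plain vectors stand in for the RuntimeShape scaffolding, and the struct/function names merely mirror the TFLite ones:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors tflite::DequantizationParams for illustration.
struct DequantizationParams {
  std::int32_t zero_point;
  double scale;
};

// Same math as the templated reference kernel: widen to int32, then
// compute scale * (val - zero_point) in double and narrow to float.
template <typename T>
void Dequantize(const DequantizationParams& op_params,
                const std::vector<T>& input, std::vector<float>* output) {
  output->resize(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    const std::int32_t val = static_cast<std::int32_t>(input[i]);
    (*output)[i] =
        static_cast<float>(op_params.scale * (val - op_params.zero_point));
  }
}

int main() {
  // Parameters from DequantizeOpTest.Int16 above: scale=0.5, zero_point=-1.
  const DequantizationParams params = {-1, 0.5};
  const std::vector<std::int16_t> input = {-130, -127, -126, -125, -124,
                                           123,  124,  125,  126,  130};
  std::vector<float> output;
  Dequantize(params, input, &output);
  const std::vector<float> expected = {-64.5, -63,  -62.5, -62,  -61.5,
                                       62,    62.5, 63,    63.5, 65.5};
  for (std::size_t i = 0; i < expected.size(); ++i) {
    assert(std::fabs(output[i] - expected[i]) < 1e-5f);
  }
  return 0;
}
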
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -280,7 +280,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version */ 1,
-             /* max_version */ 2);
+             /* max_version */ 4);
   AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
   AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
              /* min_version */ 1,
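
The bump from 2 to 4 lines up with the kernel tests above, which pin uint8 to version 1, int8 to version 2, float16 to version 3, and int16 to version 4.
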
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -117,10 +117,11 @@ struct TensorData {
 
 class SingleOpResolver : public OpResolver {
  public:
-  SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
+  SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration,
+                   int version = 1)
       : op_(op), registration_(*registration) {
     registration_.builtin_code = static_cast<int32_t>(op);
-    registration_.version = 1;
+    registration_.version = version;
   }
   const TfLiteRegistration* FindOp(BuiltinOperator op,
                                    int version) const override {
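
Because version defaults to 1, existing SingleOpResolver call sites are untouched; a test opts in explicitly, as the updated DequantizeOpModel does:

// From DequantizeOpModel above: resolve a specific kernel version,
// e.g. 4 for the int16 test case.
resolver_ = absl::make_unique<SingleOpResolver>(
    BuiltinOperator_DEQUANTIZE, ops::builtin::Register_DEQUANTIZE(),
    /*version=*/4);
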
--- a/tensorflow/lite/toco/tflite/op_version.cc
+++ b/tensorflow/lite/toco/tflite/op_version.cc
@@ -155,6 +155,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
           {{OperatorType::kWhere, 1}, "1.14.0"},
           {{OperatorType::kDequantize, 1}, "1.13.1"},
           {{OperatorType::kDequantize, 2}, "1.14.0"},
+          {{OperatorType::kDequantize, 3}, kPendingReleaseOpVersion},
           {{OperatorType::kReverseSequence, 1}, "1.14.0"},
           {{OperatorType::kEqual, 1}, "1.14.0"},
           {{OperatorType::kEqual, 2}, "1.14.0"},
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -2237,6 +2237,11 @@ class Dequantize
   int GetVersion(const OperatorSignature& op_signature) const override {
     const string& input_name = op_signature.op->inputs[0];
     const Array& input_array = op_signature.model->GetArray(input_name);
+    // Version 3 supports signed int16 input types.
+    if (input_array.data_type == ArrayDataType::kInt16 ||
+        input_array.data_type == ArrayDataType::kFloat16) {
+      return 3;
+    }
     // Version 2 supports signed int8 input types.
     if (input_array.data_type == ArrayDataType::kInt8) {
       return 2;
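
Restated standalone, the resulting version selection is small; a minimal sketch with a stand-in enum for toco's ArrayDataType (the new unit test below checks exactly these cases):

#include <cassert>

// Stand-in for toco's ArrayDataType; only the cases relevant here.
enum class ArrayDataType { kFloat, kInt8, kInt16, kFloat16 };

int DequantizeVersion(ArrayDataType input_type) {
  // Version 3 supports int16 and float16 inputs.
  if (input_type == ArrayDataType::kInt16 ||
      input_type == ArrayDataType::kFloat16) {
    return 3;
  }
  // Version 2 supports int8 inputs.
  if (input_type == ArrayDataType::kInt8) {
    return 2;
  }
  // Everything else remains the original version 1 op.
  return 1;
}

int main() {
  assert(DequantizeVersion(ArrayDataType::kInt16) == 3);
  assert(DequantizeVersion(ArrayDataType::kFloat16) == 3);
  assert(DequantizeVersion(ArrayDataType::kInt8) == 2);
  assert(DequantizeVersion(ArrayDataType::kFloat) == 1);
  return 0;
}
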
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -974,6 +974,42 @@ TEST_F(OperatorTest, VersioningFullyConnectedTest) {
   EXPECT_EQ(op->GetVersion(int8_signature), 4);
 }
 
+TEST_F(OperatorTest, VersioningDequantizeTest) {
+  DequantizeOperator dequant_op;
+  dequant_op.inputs = {"input"};
+  dequant_op.outputs = {"output"};
+  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
+  const BaseOperator* op = operator_by_type_map.at(dequant_op.type).get();
+
+  Model int16_model;
+  Array& input_int16_array = int16_model.GetOrCreateArray(dequant_op.inputs[0]);
+  input_int16_array.data_type = ArrayDataType::kInt16;
+  OperatorSignature int16_signature = {.op = &dequant_op,
+                                       .model = &int16_model};
+  EXPECT_EQ(op->GetVersion(int16_signature), 3);
+
+  Model float16_model;
+  Array& input_float16_array =
+      float16_model.GetOrCreateArray(dequant_op.inputs[0]);
+  input_float16_array.data_type = ArrayDataType::kFloat16;
+  OperatorSignature float16_signature = {.op = &dequant_op,
+                                         .model = &float16_model};
+  EXPECT_EQ(op->GetVersion(float16_signature), 3);
+
+  Model int8_model;
+  Array& input_int8_array = int8_model.GetOrCreateArray(dequant_op.inputs[0]);
+  input_int8_array.data_type = ArrayDataType::kInt8;
+  OperatorSignature int8_signature = {.op = &dequant_op, .model = &int8_model};
+  EXPECT_EQ(op->GetVersion(int8_signature), 2);
+
+  Model float_model;
+  Array& input_float_array = float_model.GetOrCreateArray(dequant_op.inputs[0]);
+  input_float_array.data_type = ArrayDataType::kFloat;
+  OperatorSignature float_signature = {.op = &dequant_op,
+                                       .model = &float_model};
+  EXPECT_EQ(op->GetVersion(float_signature), 1);
+}
+
 TEST_F(OperatorTest, VersioningConv2DTest) {
   ConvOperator conv_op;
   conv_op.inputs = {"input", "filter"};