From 8fe40fa1ea2990bd35940106b443fbf88716202f Mon Sep 17 00:00:00 2001
From: Jian Li <jianlijianli@google.com>
Date: Fri, 19 Jul 2019 00:17:43 -0700
Subject: [PATCH] Add int16 support to Dequant.

PiperOrigin-RevId: 258917873
---
 tensorflow/lite/kernels/dequantize.cc        | 11 +++++-
 tensorflow/lite/kernels/dequantize_test.cc   | 31 +++++++++++++---
 .../reference/integer_ops/dequantize.h       |  5 +--
 tensorflow/lite/kernels/register.cc          |  2 +-
 tensorflow/lite/kernels/test_util.h          |  5 +--
 tensorflow/lite/toco/tflite/op_version.cc    |  1 +
 tensorflow/lite/toco/tflite/operator.cc      |  5 +++
 tensorflow/lite/toco/tflite/operator_test.cc | 36 +++++++++++++++++++
 8 files changed, 86 insertions(+), 10 deletions(-)

diff --git a/tensorflow/lite/kernels/dequantize.cc b/tensorflow/lite/kernels/dequantize.cc
index 7c17cae7607..db7e23e6fa0 100644
--- a/tensorflow/lite/kernels/dequantize.cc
+++ b/tensorflow/lite/kernels/dequantize.cc
@@ -16,6 +16,7 @@ limitations under the License.

 #include <string.h>

+#include <cstdint>
 #include "third_party/eigen3/Eigen/Core"
@@ -64,6 +65,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context,
                  op_context.input->type == kTfLiteUInt8 ||
                      op_context.input->type == kTfLiteInt8 ||
+                     op_context.input->type == kTfLiteInt16 ||
                      op_context.input->type == kTfLiteFloat16);

   op_context.output->type = kTfLiteFloat32;
@@ -95,12 +97,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           GetTensorData<float>(op_context.output));
       break;
     case kTfLiteInt8:
-      reference_integer_ops::Dequantize(
+      reference_integer_ops::Dequantize<int8_t>(
           op_params, GetTensorShape(op_context.input),
           GetTensorData<int8_t>(op_context.input),
           GetTensorShape(op_context.output),
           GetTensorData<float>(op_context.output));
       break;
+    case kTfLiteInt16:
+      reference_integer_ops::Dequantize<int16_t>(
+          op_params, GetTensorShape(op_context.input),
+          GetTensorData<int16_t>(op_context.input),
+          GetTensorShape(op_context.output),
+          GetTensorData<float>(op_context.output));
+      break;
     case kTfLiteFloat16: {
       const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
           GetTensorData<TfLiteFloat16>(op_context.input));
diff --git a/tensorflow/lite/kernels/dequantize_test.cc b/tensorflow/lite/kernels/dequantize_test.cc
index df76e27bbaa..f55a23e138d 100644
--- a/tensorflow/lite/kernels/dequantize_test.cc
+++ b/tensorflow/lite/kernels/dequantize_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include <cstdint>

 #include <gtest/gtest.h>
+#include "absl/memory/memory.h"
 #include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/internal/types.h"
@@ -23,6 +24,15 @@ limitations under the License.
 #include "tensorflow/lite/model.h"

 namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_DEQUANTIZE();
+
+}  // namespace builtin
+}  // namespace ops
+
 namespace {

 using ::testing::ElementsAreArray;
@@ -30,13 +40,17 @@ using ::testing::ElementsAreArray;
 class DequantizeOpModel : public SingleOpModel {
  public:
   DequantizeOpModel(TensorType type, std::initializer_list<int> shape,
-                    float scale, int32_t zero_point) {
+                    float scale, int32_t zero_point, int version) {
     const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
     input_ = AddInput(input_tensor_data);
     output_ = AddOutput({TensorType_FLOAT32, shape});
     SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
                  CreateDequantizeOptions(builder_).Union());
+    resolver_ = absl::make_unique<SingleOpResolver>(
+        BuiltinOperator_DEQUANTIZE, ops::builtin::Register_DEQUANTIZE(),
+        version);
+
     BuildInterpreter({GetShape(input_)});
   }
@@ -54,7 +68,7 @@ class DequantizeOpModel : public SingleOpModel {

 TEST(DequantizeOpTest, Uint8) {
   // [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
-  DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127);
+  DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127, 1);
   m.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
   m.Invoke();
@@ -65,7 +79,7 @@ TEST(DequantizeOpTest, Uint8) {

 TEST(DequantizeOpTest, Int8) {
   // [-63.5, 64] -> scale=0.5, zero_point=1 for INT8
-  DequantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1);
+  DequantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2);
   m.SetInput<int8_t>({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
   m.Invoke();
@@ -75,7 +89,7 @@ TEST(DequantizeOpTest, Int8) {
 }

 TEST(DequantizeOpTest, Float16) {
-  DequantizeOpModel m(TensorType_FLOAT16, {2, 3}, 1.0f, 0);
+  DequantizeOpModel m(TensorType_FLOAT16, {2, 3}, 1.0f, 0, 3);
   std::vector<Eigen::half> half{Eigen::half{-535.54f}, Eigen::half{-100.0f},
                                 Eigen::half{-1.0f},    Eigen::half{0.f},
@@ -88,5 +102,14 @@ TEST(DequantizeOpTest, Float16) {
                                      /*max_abs_error=*/0.1f)));
 }

+TEST(DequantizeOpTest, Int16) {
+  DequantizeOpModel m(TensorType_INT16, {2, 5}, 0.5, -1, 4);
+  m.SetInput<int16_t>({-130, -127, -126, -125, -124, 123, 124, 125, 126, 130});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {-64.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 65.5})));
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h b/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h
index 03dcb6c220d..ae846faf251 100644
--- a/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h
@@ -22,15 +22,16 @@ limitations under the License.
 namespace tflite {
 namespace reference_integer_ops {

+template <typename T>
 inline void Dequantize(const tflite::DequantizationParams& op_params,
-                       const RuntimeShape& input_shape, const int8* input_data,
+                       const RuntimeShape& input_shape, const T* input_data,
                        const RuntimeShape& output_shape, float* output_data) {
   const int32 zero_point = op_params.zero_point;
   const double scale = op_params.scale;
   const int flat_size = MatchingFlatSize(input_shape, output_shape);

   for (int i = 0; i < flat_size; i++) {
-    const int32 val = input_data[i];
+    const int32 val = static_cast<int32>(input_data[i]);
     const float result = static_cast<float>(scale * (val - zero_point));
     output_data[i] = result;
   }
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index f17c445a063..bd2643aaa64 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -280,7 +280,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version */ 1,
-             /* max_version */ 2);
+             /* max_version */ 4);
   AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
   AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
              /* min_version */ 1,
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index eb021c68197..1faae708340 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -117,10 +117,11 @@ struct TensorData {
 class SingleOpResolver : public OpResolver {
  public:
-  SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
+  SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration,
+                   int version = 1)
       : op_(op), registration_(*registration) {
     registration_.builtin_code = static_cast<int32_t>(op);
-    registration_.version = 1;
+    registration_.version = version;
   }
   const TfLiteRegistration* FindOp(BuiltinOperator op,
                                    int version) const override {
diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc
index ddd5f598eec..1937f3efeb8 100644
--- a/tensorflow/lite/toco/tflite/op_version.cc
+++ b/tensorflow/lite/toco/tflite/op_version.cc
@@ -155,6 +155,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
           {{OperatorType::kWhere, 1}, "1.14.0"},
           {{OperatorType::kDequantize, 1}, "1.13.1"},
           {{OperatorType::kDequantize, 2}, "1.14.0"},
+          {{OperatorType::kDequantize, 3}, kPendingReleaseOpVersion},
           {{OperatorType::kReverseSequence, 1}, "1.14.0"},
           {{OperatorType::kEqual, 1}, "1.14.0"},
           {{OperatorType::kEqual, 2}, "1.14.0"},
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index 03313aaca91..b064ea396e1 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -2237,6 +2237,11 @@ class Dequantize
   int GetVersion(const OperatorSignature& op_signature) const override {
     const string& input_name = op_signature.op->inputs[0];
     const Array& input_array = op_signature.model->GetArray(input_name);
+    // Version 3 supports signed int16 input types.
+    if (input_array.data_type == ArrayDataType::kInt16 ||
+        input_array.data_type == ArrayDataType::kFloat16) {
+      return 3;
+    }
     // Version 2 supports signed int8 input types.
     if (input_array.data_type == ArrayDataType::kInt8) {
       return 2;
diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc
index ffa68738d4b..3b007cb2514 100644
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -974,6 +974,42 @@ TEST_F(OperatorTest, VersioningFullyConnectedTest) {
   EXPECT_EQ(op->GetVersion(int8_signature), 4);
 }

+TEST_F(OperatorTest, VersioningDequantizeTest) {
+  DequantizeOperator dequant_op;
+  dequant_op.inputs = {"input"};
+  dequant_op.outputs = {"output"};
+  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
+  const BaseOperator* op = operator_by_type_map.at(dequant_op.type).get();
+
+  Model int16_model;
+  Array& input_int16_array = int16_model.GetOrCreateArray(dequant_op.inputs[0]);
+  input_int16_array.data_type = ArrayDataType::kInt16;
+  OperatorSignature int16_signature = {.op = &dequant_op,
+                                       .model = &int16_model};
+  EXPECT_EQ(op->GetVersion(int16_signature), 3);
+
+  Model float16_model;
+  Array& input_float16_array =
+      float16_model.GetOrCreateArray(dequant_op.inputs[0]);
+  input_float16_array.data_type = ArrayDataType::kFloat16;
+  OperatorSignature float16_signature = {.op = &dequant_op,
+                                         .model = &float16_model};
+  EXPECT_EQ(op->GetVersion(float16_signature), 3);
+
+  Model int8_model;
+  Array& input_int8_array = int8_model.GetOrCreateArray(dequant_op.inputs[0]);
+  input_int8_array.data_type = ArrayDataType::kInt8;
+  OperatorSignature int8_signature = {.op = &dequant_op, .model = &int8_model};
+  EXPECT_EQ(op->GetVersion(int8_signature), 2);
+
+  Model float_model;
+  Array& input_float_array = float_model.GetOrCreateArray(dequant_op.inputs[0]);
+  input_float_array.data_type = ArrayDataType::kFloat;
+  OperatorSignature float_signature = {.op = &dequant_op,
+                                       .model = &float_model};
+  EXPECT_EQ(op->GetVersion(float_signature), 1);
+}
+
 TEST_F(OperatorTest, VersioningConv2DTest) {
   ConvOperator conv_op;
   conv_op.inputs = {"input", "filter"};