Add int16 support to Dequant.

PiperOrigin-RevId: 258917873
This commit is contained in:
Jian Li 2019-07-19 00:17:43 -07:00 committed by TensorFlower Gardener
parent 2ecc2fffad
commit 8fe40fa1ea
8 changed files with 86 additions and 10 deletions

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <string.h>
#include <cstdint>
#include <vector>
#include "third_party/eigen3/Eigen/Core"
@ -64,6 +65,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
op_context.input->type == kTfLiteInt8 ||
op_context.input->type == kTfLiteInt16 ||
op_context.input->type == kTfLiteFloat16);
op_context.output->type = kTfLiteFloat32;
@ -95,12 +97,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<float>(op_context.output));
break;
case kTfLiteInt8:
reference_integer_ops::Dequantize(
reference_integer_ops::Dequantize<int8_t>(
op_params, GetTensorShape(op_context.input),
GetTensorData<int8_t>(op_context.input),
GetTensorShape(op_context.output),
GetTensorData<float>(op_context.output));
break;
case kTfLiteInt16:
reference_integer_ops::Dequantize<int16_t>(
op_params, GetTensorShape(op_context.input),
GetTensorData<int16_t>(op_context.input),
GetTensorShape(op_context.output),
GetTensorData<float>(op_context.output));
break;
case kTfLiteFloat16: {
const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
GetTensorData<TfLiteFloat16>(op_context.input));

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <cstdint>
#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/types.h"
@ -23,6 +24,15 @@ limitations under the License.
#include "tensorflow/lite/model.h"
namespace tflite {
namespace ops {
namespace builtin {
TfLiteRegistration* Register_DEQUANTIZE();
} // namespace builtin
} // namespace ops
namespace {
using ::testing::ElementsAreArray;
@ -30,13 +40,17 @@ using ::testing::ElementsAreArray;
class DequantizeOpModel : public SingleOpModel {
public:
DequantizeOpModel(TensorType type, std::initializer_list<int> shape,
float scale, int32_t zero_point) {
float scale, int32_t zero_point, int version) {
const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
input_ = AddInput(input_tensor_data);
output_ = AddOutput({TensorType_FLOAT32, shape});
SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
CreateDequantizeOptions(builder_).Union());
resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_DEQUANTIZE, ops::builtin::Register_DEQUANTIZE(),
version);
BuildInterpreter({GetShape(input_)});
}
@ -54,7 +68,7 @@ class DequantizeOpModel : public SingleOpModel {
TEST(DequantizeOpTest, Uint8) {
// [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127);
DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127, 1);
m.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
m.Invoke();
@ -65,7 +79,7 @@ TEST(DequantizeOpTest, Uint8) {
TEST(DequantizeOpTest, Int8) {
// [-63.5, 64] -> scale=0.5, zero_point=1 for INT8
DequantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1);
DequantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2);
m.SetInput<int8_t>({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
m.Invoke();
@ -75,7 +89,7 @@ TEST(DequantizeOpTest, Int8) {
}
TEST(DequantizeOpTest, Float16) {
DequantizeOpModel m(TensorType_FLOAT16, {2, 3}, 1.0f, 0);
DequantizeOpModel m(TensorType_FLOAT16, {2, 3}, 1.0f, 0, 3);
std::vector<Eigen::half> half{Eigen::half{-535.54f}, Eigen::half{-100.0f},
Eigen::half{-1.0f}, Eigen::half{0.f},
@ -88,5 +102,14 @@ TEST(DequantizeOpTest, Float16) {
/*max_abs_error=*/0.1f)));
}
TEST(DequantizeOpTest, Int16) {
DequantizeOpModel m(TensorType_INT16, {2, 5}, 0.5, -1, 4);
m.SetInput<int16_t>({-130, -127, -126, -125, -124, 123, 124, 125, 126, 130});
m.Invoke();
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{-64.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 65.5})));
}
} // namespace
} // namespace tflite

View File

@ -22,15 +22,16 @@ limitations under the License.
namespace tflite {
namespace reference_integer_ops {
template <typename T>
inline void Dequantize(const tflite::DequantizationParams& op_params,
const RuntimeShape& input_shape, const int8* input_data,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int32 zero_point = op_params.zero_point;
const double scale = op_params.scale;
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
const int32 val = input_data[i];
const int32 val = static_cast<int32>(input_data[i]);
const float result = static_cast<float>(scale * (val - zero_point));
output_data[i] = result;
}

View File

@ -280,7 +280,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
/* min_version */ 1,
/* max_version */ 2);
/* max_version */ 4);
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
/* min_version */ 1,

View File

@ -117,10 +117,11 @@ struct TensorData {
class SingleOpResolver : public OpResolver {
public:
SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration,
int version = 1)
: op_(op), registration_(*registration) {
registration_.builtin_code = static_cast<int32_t>(op);
registration_.version = 1;
registration_.version = version;
}
const TfLiteRegistration* FindOp(BuiltinOperator op,
int version) const override {

View File

@ -155,6 +155,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kWhere, 1}, "1.14.0"},
{{OperatorType::kDequantize, 1}, "1.13.1"},
{{OperatorType::kDequantize, 2}, "1.14.0"},
{{OperatorType::kDequantize, 3}, kPendingReleaseOpVersion},
{{OperatorType::kReverseSequence, 1}, "1.14.0"},
{{OperatorType::kEqual, 1}, "1.14.0"},
{{OperatorType::kEqual, 2}, "1.14.0"},

View File

@ -2237,6 +2237,11 @@ class Dequantize
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// Version 3 supports signed int16 input types.
if (input_array.data_type == ArrayDataType::kInt16 ||
input_array.data_type == ArrayDataType::kFloat16) {
return 3;
}
// Version 2 supports signed int8 input types.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;

View File

@ -974,6 +974,42 @@ TEST_F(OperatorTest, VersioningFullyConnectedTest) {
EXPECT_EQ(op->GetVersion(int8_signature), 4);
}
TEST_F(OperatorTest, VersioningDequantizeTest) {
DequantizeOperator dequant_op;
dequant_op.inputs = {"input"};
dequant_op.outputs = {"output"};
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
const BaseOperator* op = operator_by_type_map.at(dequant_op.type).get();
Model int16_model;
Array& input_int16_array = int16_model.GetOrCreateArray(dequant_op.inputs[0]);
input_int16_array.data_type = ArrayDataType::kInt16;
OperatorSignature int16_signature = {.op = &dequant_op,
.model = &int16_model};
EXPECT_EQ(op->GetVersion(int16_signature), 3);
Model float16_model;
Array& input_float16_array =
float16_model.GetOrCreateArray(dequant_op.inputs[0]);
input_float16_array.data_type = ArrayDataType::kFloat16;
OperatorSignature float16_signature = {.op = &dequant_op,
.model = &float16_model};
EXPECT_EQ(op->GetVersion(float16_signature), 3);
Model int8_model;
Array& input_int8_array = int8_model.GetOrCreateArray(dequant_op.inputs[0]);
input_int8_array.data_type = ArrayDataType::kInt8;
OperatorSignature int8_signature = {.op = &dequant_op, .model = &int8_model};
EXPECT_EQ(op->GetVersion(int8_signature), 2);
Model float_model;
Array& input_float_array = float_model.GetOrCreateArray(dequant_op.inputs[0]);
input_float_array.data_type = ArrayDataType::kFloat;
OperatorSignature float_signature = {.op = &dequant_op,
.model = &float_model};
EXPECT_EQ(op->GetVersion(float_signature), 1);
}
TEST_F(OperatorTest, VersioningConv2DTest) {
ConvOperator conv_op;
conv_op.inputs = {"input", "filter"};