Add int16 support to Dequant.
PiperOrigin-RevId: 258917873
This commit is contained in:
parent
2ecc2fffad
commit
8fe40fa1ea
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
@ -64,6 +65,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
|
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
|
||||||
op_context.input->type == kTfLiteInt8 ||
|
op_context.input->type == kTfLiteInt8 ||
|
||||||
|
op_context.input->type == kTfLiteInt16 ||
|
||||||
op_context.input->type == kTfLiteFloat16);
|
op_context.input->type == kTfLiteFloat16);
|
||||||
|
|
||||||
op_context.output->type = kTfLiteFloat32;
|
op_context.output->type = kTfLiteFloat32;
|
||||||
@ -95,12 +97,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetTensorData<float>(op_context.output));
|
GetTensorData<float>(op_context.output));
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt8:
|
case kTfLiteInt8:
|
||||||
reference_integer_ops::Dequantize(
|
reference_integer_ops::Dequantize<int8_t>(
|
||||||
op_params, GetTensorShape(op_context.input),
|
op_params, GetTensorShape(op_context.input),
|
||||||
GetTensorData<int8_t>(op_context.input),
|
GetTensorData<int8_t>(op_context.input),
|
||||||
GetTensorShape(op_context.output),
|
GetTensorShape(op_context.output),
|
||||||
GetTensorData<float>(op_context.output));
|
GetTensorData<float>(op_context.output));
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteInt16:
|
||||||
|
reference_integer_ops::Dequantize<int16_t>(
|
||||||
|
op_params, GetTensorShape(op_context.input),
|
||||||
|
GetTensorData<int16_t>(op_context.input),
|
||||||
|
GetTensorShape(op_context.output),
|
||||||
|
GetTensorData<float>(op_context.output));
|
||||||
|
break;
|
||||||
case kTfLiteFloat16: {
|
case kTfLiteFloat16: {
|
||||||
const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
|
const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
|
||||||
GetTensorData<TfLiteFloat16>(op_context.input));
|
GetTensorData<TfLiteFloat16>(op_context.input));
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
@ -23,6 +24,15 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
||||||
|
namespace ops {
|
||||||
|
namespace builtin {
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_DEQUANTIZE();
|
||||||
|
|
||||||
|
} // namespace builtin
|
||||||
|
} // namespace ops
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::testing::ElementsAreArray;
|
using ::testing::ElementsAreArray;
|
||||||
@ -30,13 +40,17 @@ using ::testing::ElementsAreArray;
|
|||||||
class DequantizeOpModel : public SingleOpModel {
|
class DequantizeOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
DequantizeOpModel(TensorType type, std::initializer_list<int> shape,
|
DequantizeOpModel(TensorType type, std::initializer_list<int> shape,
|
||||||
float scale, int32_t zero_point) {
|
float scale, int32_t zero_point, int version) {
|
||||||
const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
|
const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
|
||||||
input_ = AddInput(input_tensor_data);
|
input_ = AddInput(input_tensor_data);
|
||||||
output_ = AddOutput({TensorType_FLOAT32, shape});
|
output_ = AddOutput({TensorType_FLOAT32, shape});
|
||||||
SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
|
SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
|
||||||
CreateDequantizeOptions(builder_).Union());
|
CreateDequantizeOptions(builder_).Union());
|
||||||
|
|
||||||
|
resolver_ = absl::make_unique<SingleOpResolver>(
|
||||||
|
BuiltinOperator_DEQUANTIZE, ops::builtin::Register_DEQUANTIZE(),
|
||||||
|
version);
|
||||||
|
|
||||||
BuildInterpreter({GetShape(input_)});
|
BuildInterpreter({GetShape(input_)});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,7 +68,7 @@ class DequantizeOpModel : public SingleOpModel {
|
|||||||
|
|
||||||
TEST(DequantizeOpTest, Uint8) {
|
TEST(DequantizeOpTest, Uint8) {
|
||||||
// [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
|
// [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
|
||||||
DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127);
|
DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127, 1);
|
||||||
|
|
||||||
m.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
|
m.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
@ -65,7 +79,7 @@ TEST(DequantizeOpTest, Uint8) {
|
|||||||
|
|
||||||
TEST(DequantizeOpTest, Int8) {
|
TEST(DequantizeOpTest, Int8) {
|
||||||
// [-63.5, 64] -> scale=0.5, zero_point=1 for INT8
|
// [-63.5, 64] -> scale=0.5, zero_point=1 for INT8
|
||||||
DequantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1);
|
DequantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2);
|
||||||
|
|
||||||
m.SetInput<int8_t>({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
|
m.SetInput<int8_t>({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
@ -75,7 +89,7 @@ TEST(DequantizeOpTest, Int8) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(DequantizeOpTest, Float16) {
|
TEST(DequantizeOpTest, Float16) {
|
||||||
DequantizeOpModel m(TensorType_FLOAT16, {2, 3}, 1.0f, 0);
|
DequantizeOpModel m(TensorType_FLOAT16, {2, 3}, 1.0f, 0, 3);
|
||||||
|
|
||||||
std::vector<Eigen::half> half{Eigen::half{-535.54f}, Eigen::half{-100.0f},
|
std::vector<Eigen::half> half{Eigen::half{-535.54f}, Eigen::half{-100.0f},
|
||||||
Eigen::half{-1.0f}, Eigen::half{0.f},
|
Eigen::half{-1.0f}, Eigen::half{0.f},
|
||||||
@ -88,5 +102,14 @@ TEST(DequantizeOpTest, Float16) {
|
|||||||
/*max_abs_error=*/0.1f)));
|
/*max_abs_error=*/0.1f)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(DequantizeOpTest, Int16) {
|
||||||
|
DequantizeOpModel m(TensorType_INT16, {2, 5}, 0.5, -1, 4);
|
||||||
|
m.SetInput<int16_t>({-130, -127, -126, -125, -124, 123, 124, 125, 126, 130});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
{-64.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 65.5})));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -22,15 +22,16 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace reference_integer_ops {
|
namespace reference_integer_ops {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
inline void Dequantize(const tflite::DequantizationParams& op_params,
|
inline void Dequantize(const tflite::DequantizationParams& op_params,
|
||||||
const RuntimeShape& input_shape, const int8* input_data,
|
const RuntimeShape& input_shape, const T* input_data,
|
||||||
const RuntimeShape& output_shape, float* output_data) {
|
const RuntimeShape& output_shape, float* output_data) {
|
||||||
const int32 zero_point = op_params.zero_point;
|
const int32 zero_point = op_params.zero_point;
|
||||||
const double scale = op_params.scale;
|
const double scale = op_params.scale;
|
||||||
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
||||||
|
|
||||||
for (int i = 0; i < flat_size; i++) {
|
for (int i = 0; i < flat_size; i++) {
|
||||||
const int32 val = input_data[i];
|
const int32 val = static_cast<int32>(input_data[i]);
|
||||||
const float result = static_cast<float>(scale * (val - zero_point));
|
const float result = static_cast<float>(scale * (val - zero_point));
|
||||||
output_data[i] = result;
|
output_data[i] = result;
|
||||||
}
|
}
|
||||||
|
@ -280,7 +280,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
|
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
|
||||||
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
|
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 4);
|
||||||
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
|
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
|
||||||
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
|
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
|
@ -117,10 +117,11 @@ struct TensorData {
|
|||||||
|
|
||||||
class SingleOpResolver : public OpResolver {
|
class SingleOpResolver : public OpResolver {
|
||||||
public:
|
public:
|
||||||
SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
|
SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration,
|
||||||
|
int version = 1)
|
||||||
: op_(op), registration_(*registration) {
|
: op_(op), registration_(*registration) {
|
||||||
registration_.builtin_code = static_cast<int32_t>(op);
|
registration_.builtin_code = static_cast<int32_t>(op);
|
||||||
registration_.version = 1;
|
registration_.version = version;
|
||||||
}
|
}
|
||||||
const TfLiteRegistration* FindOp(BuiltinOperator op,
|
const TfLiteRegistration* FindOp(BuiltinOperator op,
|
||||||
int version) const override {
|
int version) const override {
|
||||||
|
@ -155,6 +155,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
|||||||
{{OperatorType::kWhere, 1}, "1.14.0"},
|
{{OperatorType::kWhere, 1}, "1.14.0"},
|
||||||
{{OperatorType::kDequantize, 1}, "1.13.1"},
|
{{OperatorType::kDequantize, 1}, "1.13.1"},
|
||||||
{{OperatorType::kDequantize, 2}, "1.14.0"},
|
{{OperatorType::kDequantize, 2}, "1.14.0"},
|
||||||
|
{{OperatorType::kDequantize, 3}, kPendingReleaseOpVersion},
|
||||||
{{OperatorType::kReverseSequence, 1}, "1.14.0"},
|
{{OperatorType::kReverseSequence, 1}, "1.14.0"},
|
||||||
{{OperatorType::kEqual, 1}, "1.14.0"},
|
{{OperatorType::kEqual, 1}, "1.14.0"},
|
||||||
{{OperatorType::kEqual, 2}, "1.14.0"},
|
{{OperatorType::kEqual, 2}, "1.14.0"},
|
||||||
|
@ -2237,6 +2237,11 @@ class Dequantize
|
|||||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
const string& input_name = op_signature.op->inputs[0];
|
const string& input_name = op_signature.op->inputs[0];
|
||||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||||
|
// Version 3 supports signed int16 input types.
|
||||||
|
if (input_array.data_type == ArrayDataType::kInt16 ||
|
||||||
|
input_array.data_type == ArrayDataType::kFloat16) {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
// Version 2 supports signed int8 input types.
|
// Version 2 supports signed int8 input types.
|
||||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -974,6 +974,42 @@ TEST_F(OperatorTest, VersioningFullyConnectedTest) {
|
|||||||
EXPECT_EQ(op->GetVersion(int8_signature), 4);
|
EXPECT_EQ(op->GetVersion(int8_signature), 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, VersioningDequantizeTest) {
|
||||||
|
DequantizeOperator dequant_op;
|
||||||
|
dequant_op.inputs = {"input"};
|
||||||
|
dequant_op.outputs = {"output"};
|
||||||
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||||
|
const BaseOperator* op = operator_by_type_map.at(dequant_op.type).get();
|
||||||
|
|
||||||
|
Model int16_model;
|
||||||
|
Array& input_int16_array = int16_model.GetOrCreateArray(dequant_op.inputs[0]);
|
||||||
|
input_int16_array.data_type = ArrayDataType::kInt16;
|
||||||
|
OperatorSignature int16_signature = {.op = &dequant_op,
|
||||||
|
.model = &int16_model};
|
||||||
|
EXPECT_EQ(op->GetVersion(int16_signature), 3);
|
||||||
|
|
||||||
|
Model float16_model;
|
||||||
|
Array& input_float16_array =
|
||||||
|
float16_model.GetOrCreateArray(dequant_op.inputs[0]);
|
||||||
|
input_float16_array.data_type = ArrayDataType::kFloat16;
|
||||||
|
OperatorSignature float16_signature = {.op = &dequant_op,
|
||||||
|
.model = &float16_model};
|
||||||
|
EXPECT_EQ(op->GetVersion(float16_signature), 3);
|
||||||
|
|
||||||
|
Model int8_model;
|
||||||
|
Array& input_int8_array = int8_model.GetOrCreateArray(dequant_op.inputs[0]);
|
||||||
|
input_int8_array.data_type = ArrayDataType::kInt8;
|
||||||
|
OperatorSignature int8_signature = {.op = &dequant_op, .model = &int8_model};
|
||||||
|
EXPECT_EQ(op->GetVersion(int8_signature), 2);
|
||||||
|
|
||||||
|
Model float_model;
|
||||||
|
Array& input_float_array = float_model.GetOrCreateArray(dequant_op.inputs[0]);
|
||||||
|
input_float_array.data_type = ArrayDataType::kFloat;
|
||||||
|
OperatorSignature float_signature = {.op = &dequant_op,
|
||||||
|
.model = &float_model};
|
||||||
|
EXPECT_EQ(op->GetVersion(float_signature), 1);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, VersioningConv2DTest) {
|
TEST_F(OperatorTest, VersioningConv2DTest) {
|
||||||
ConvOperator conv_op;
|
ConvOperator conv_op;
|
||||||
conv_op.inputs = {"input", "filter"};
|
conv_op.inputs = {"input", "filter"};
|
||||||
|
Loading…
Reference in New Issue
Block a user