Implement TFLite Quantize operation.

PiperOrigin-RevId: 239312773
Suharsh Sivakumar 2019-03-19 18:19:01 -07:00 committed by TensorFlower Gardener
parent 02f54c9a54
commit 6ee35631af
11 changed files with 349 additions and 7 deletions


@ -139,6 +139,7 @@ typedef enum {
kTfLiteBuiltinElu = 111,
kTfLiteBuiltinReverseSequence = 112,
kTfLiteBuiltinMatrixDiag = 113,
kTfLiteBuiltinQuantize = 114,
} TfLiteBuiltinOperator;
#ifdef __cplusplus


@ -744,6 +744,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_GATHER_ND:
case BuiltinOperator_WHERE:
case BuiltinOperator_RANK:
case BuiltinOperator_QUANTIZE:
break;
}
return kTfLiteOk;


@ -199,6 +199,7 @@ cc_library(
"pad.cc",
"pooling.cc",
"pow.cc",
"quantize.cc",
"range.cc",
"rank.cc",
"reduce.cc",
@ -240,6 +241,7 @@ cc_library(
deps = [
":activation_functor",
":eigen_support",
":gemm_support",
":kernel_util",
":lstm_eval",
":op_macros",
@ -247,7 +249,6 @@ cc_library(
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:gemm_support",
"//tensorflow/lite/kernels/internal:audio_utils",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:optimized",
@ -1401,3 +1402,16 @@ cc_test(
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "quantize_test",
size = "small",
srcs = ["quantize_test.cc"],
deps = [
":builtin_ops",
":test_util",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels/internal:types",
"@com_google_googletest//:gtest",
],
)


@ -47,6 +47,7 @@ namespace tflite {
namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::AffineQuantize;
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
using reference_ops::Broadcast4DSlowGreater;


@ -2878,6 +2878,26 @@ inline void Dequantize(const tflite::DequantizationParams& op_params,
}
}
template <typename T>
inline void AffineQuantize(const tflite::QuantizationParams& op_params,
const RuntimeShape& input_shape,
const float* input_data,
const RuntimeShape& output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Quantize");
const int32 zero_point = op_params.zero_point;
const double scale = static_cast<double>(op_params.scale);
const int flat_size = MatchingFlatSize(input_shape, output_shape);
static constexpr int32 min_val = std::numeric_limits<T>::min();
static constexpr int32 max_val = std::numeric_limits<T>::max();
for (int i = 0; i < flat_size; i++) {
const float val = input_data[i];
int32 unclamped = static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
int32 clamped = std::min(std::max(unclamped, min_val), max_val);
output_data[i] = clamped;
}
}
inline void FakeQuant(const tflite::FakeQuantParams& op_params,
const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
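
For context, the reference AffineQuantize added above applies the usual affine mapping q = round(x / scale) + zero_point and then clamps to the target type's range. A minimal standalone sketch of the same arithmetic (not part of this change), using the illustrative uint8 parameters scale = 0.5 and zero_point = 127 that the new quantize_test exercises:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative per-layer parameters, matching the uint8 test added below.
  const float scale = 0.5f;
  const int32_t zero_point = 127;
  for (float x : {-63.5f, 0.0f, 64.0f}) {
    // Same arithmetic as AffineQuantize: scale, round, shift, clamp to [0, 255].
    int32_t q = static_cast<int32_t>(std::round(x / scale)) + zero_point;
    q = std::min<int32_t>(std::max<int32_t>(q, 0), 255);
    std::printf("%g -> %d\n", x, static_cast<int>(q));  // -63.5 -> 0, 0 -> 127, 64 -> 255
  }
  return 0;
}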


@ -0,0 +1,100 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace quantize {
struct OpContext {
OpContext(TfLiteContext* context, TfLiteNode* node) {
input = GetInput(context, node, 0);
output = GetOutput(context, node, 0);
}
const TfLiteTensor* input;
TfLiteTensor* output;
};
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
OpContext op_context(context, node);
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteFloat32);
TF_LITE_ENSURE(context, op_context.output->type == kTfLiteUInt8 ||
op_context.output->type == kTfLiteInt8);
// TODO(b/128934713): Add support for fixed-point per-channel quantization.
// Currently this only supports affine per-layer quantization.
TF_LITE_ENSURE_EQ(context, op_context.output->quantization.type,
kTfLiteAffineQuantization);
const auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
op_context.output->quantization.params);
TF_LITE_ENSURE(context, affine_quantization);
TF_LITE_ENSURE(context, affine_quantization->scale);
TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
return context->ResizeTensor(context, op_context.output,
TfLiteIntArrayCopy(op_context.input->dims));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
tflite::QuantizationParams op_params;
op_params.zero_point = op_context.output->params.zero_point;
op_params.scale = op_context.output->params.scale;
switch (op_context.output->type) {
case kTfLiteUInt8:
optimized_ops::AffineQuantize(op_params, GetTensorShape(op_context.input),
GetTensorData<float>(op_context.input),
GetTensorShape(op_context.output),
GetTensorData<uint8_t>(op_context.output));
break;
case kTfLiteInt8:
optimized_ops::AffineQuantize(op_params, GetTensorShape(op_context.input),
GetTensorData<float>(op_context.input),
GetTensorShape(op_context.output),
GetTensorData<int8_t>(op_context.output));
break;
default:
context->ReportError(context, "Type %d not supported.",
op_context.output->type);
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace quantize
TfLiteRegistration* Register_QUANTIZE_OPT() {
static TfLiteRegistration r = {nullptr, nullptr, quantize::Prepare,
quantize::Eval};
return &r;
}
TfLiteRegistration* Register_QUANTIZE() { return Register_QUANTIZE_OPT(); }
} // namespace builtin
} // namespace ops
} // namespace tflite
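
Since this op is the inverse of the Dequantize reference op above (x ≈ scale * (q - zero_point)), a quantize/dequantize round trip reproduces any in-range input to within half a quantization step. A short standalone sketch of that bound (illustrative only, using the int8 parameters from the tests below):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  // Illustrative per-layer parameters; the int8 case from the tests below.
  const float scale = 0.5f;
  const int32_t zero_point = -1;
  for (float x = -63.5f; x <= 64.0f; x += 0.25f) {
    // Quantize: round, shift by zero_point, clamp to the int8 range.
    int32_t q = static_cast<int32_t>(std::round(x / scale)) + zero_point;
    q = std::min<int32_t>(std::max<int32_t>(q, -128), 127);
    // Dequantize (the inverse mapping) recovers x to within half a step.
    const float x_hat = scale * static_cast<float>(q - zero_point);
    assert(std::fabs(x_hat - x) <= scale / 2.0f + 1e-6f);
  }
  return 0;
}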


@ -0,0 +1,85 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
class QuantizeOpModel : public SingleOpModel {
public:
QuantizeOpModel(TensorType type, std::initializer_list<int> shape,
float scale, int32_t zero_point) {
const TensorData output_tensor_data = {type, shape, /*min=*/0,
/*max=*/0, scale, zero_point};
input_ = AddInput({TensorType_FLOAT32, shape});
output_ = AddOutput(output_tensor_data);
SetBuiltinOp(BuiltinOperator_QUANTIZE, BuiltinOptions_QuantizeOptions,
CreateQuantizeOptions(builder_).Union());
BuildInterpreter({GetShape(input_)});
}
void SetInput(std::initializer_list<float> data) {
PopulateTensor(input_, data);
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
private:
int input_;
int output_;
};
TEST(QuantizeOpTest, UINT8) {
// [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
QuantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127);
m.SetInput({-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8_t>(),
ElementsAreArray({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}));
}
TEST(QuantizeOpTest, INT8) {
// [-63.5, 64] -> scale=0.5, zero_point=-1 for INT8
QuantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1);
m.SetInput({-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray(
{-128, -127, -126, -125, -124, 123, 124, 125, 126, 127}));
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
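
The constants in these tests follow from the commented range [-63.5, 64]: with scale = (rmax - rmin) / (qmax - qmin) and zero_point = qmin - round(rmin / scale), one gets scale = 0.5 with zero_point = 127 for uint8 and -1 for int8. A small sketch of that arithmetic (the helper name is illustrative, not a TFLite API):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: derive per-layer affine parameters from a real-valued
// range [rmin, rmax] and a target integer range [qmin, qmax].
void DeriveParams(float rmin, float rmax, int32_t qmin, int32_t qmax) {
  const float scale = (rmax - rmin) / static_cast<float>(qmax - qmin);
  const int32_t zero_point =
      qmin - static_cast<int32_t>(std::round(rmin / scale));
  std::printf("scale=%g zero_point=%d\n", scale, static_cast<int>(zero_point));
}

int main() {
  DeriveParams(-63.5f, 64.0f, 0, 255);     // uint8: scale=0.5 zero_point=127
  DeriveParams(-63.5f, 64.0f, -128, 127);  // int8:  scale=0.5 zero_point=-1
  return 0;
}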


@ -140,6 +140,7 @@ TfLiteRegistration* Register_WHERE();
TfLiteRegistration* Register_ELU();
TfLiteRegistration* Register_REVERSE_SEQUENCE();
TfLiteRegistration* Register_MATRIX_DIAG();
TfLiteRegistration* Register_QUANTIZE();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@ -376,6 +377,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_ELU, Register_ELU());
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.


@ -672,6 +672,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_ELU:
case tflite::BuiltinOperator_REVERSE_SEQUENCE:
case tflite::BuiltinOperator_MATRIX_DIAG:
case tflite::BuiltinOperator_QUANTIZE:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;


@ -227,6 +227,7 @@ enum BuiltinOperator : byte {
ELU = 111,
REVERSE_SEQUENCE = 112,
MATRIX_DIAG = 113,
QUANTIZE = 114,
}
// Options for the builtin operators.
@ -319,6 +320,7 @@ union BuiltinOptions {
RankOptions,
ReverseSequenceOptions,
MatrixDiagOptions,
QuantizeOptions,
}
enum Padding : byte { SAME, VALID }
@ -762,6 +764,9 @@ table ReverseSequenceOptions {
table MatrixDiagOptions {
}
table QuantizeOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {


@ -295,6 +295,9 @@ struct ReverseSequenceOptionsT;
struct MatrixDiagOptions;
struct MatrixDiagOptionsT;
struct QuantizeOptions;
struct QuantizeOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -558,11 +561,12 @@ enum BuiltinOperator {
BuiltinOperator_ELU = 111,
BuiltinOperator_REVERSE_SEQUENCE = 112,
BuiltinOperator_MATRIX_DIAG = 113,
BuiltinOperator_QUANTIZE = 114,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_MATRIX_DIAG
BuiltinOperator_MAX = BuiltinOperator_QUANTIZE
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[113] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -676,7 +680,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[113] {
BuiltinOperator_RANK,
BuiltinOperator_ELU,
BuiltinOperator_REVERSE_SEQUENCE,
BuiltinOperator_MATRIX_DIAG
BuiltinOperator_MATRIX_DIAG,
BuiltinOperator_QUANTIZE
};
return values;
}
@ -797,6 +802,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"ELU",
"REVERSE_SEQUENCE",
"MATRIX_DIAG",
"QUANTIZE",
nullptr
};
return names;
@ -897,11 +903,12 @@ enum BuiltinOptions {
BuiltinOptions_RankOptions = 86,
BuiltinOptions_ReverseSequenceOptions = 87,
BuiltinOptions_MatrixDiagOptions = 88,
BuiltinOptions_QuantizeOptions = 89,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_MatrixDiagOptions
BuiltinOptions_MAX = BuiltinOptions_QuantizeOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[89] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -991,7 +998,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[89] {
BuiltinOptions_WhereOptions,
BuiltinOptions_RankOptions,
BuiltinOptions_ReverseSequenceOptions,
BuiltinOptions_MatrixDiagOptions
BuiltinOptions_MatrixDiagOptions,
BuiltinOptions_QuantizeOptions
};
return values;
}
@ -1087,6 +1095,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"RankOptions",
"ReverseSequenceOptions",
"MatrixDiagOptions",
"QuantizeOptions",
nullptr
};
return names;
@ -1453,6 +1462,10 @@ template<> struct BuiltinOptionsTraits<MatrixDiagOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions;
};
template<> struct BuiltinOptionsTraits<QuantizeOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2188,6 +2201,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_MatrixDiagOptions ?
reinterpret_cast<const MatrixDiagOptionsT *>(value) : nullptr;
}
QuantizeOptionsT *AsQuantizeOptions() {
return type == BuiltinOptions_QuantizeOptions ?
reinterpret_cast<QuantizeOptionsT *>(value) : nullptr;
}
const QuantizeOptionsT *AsQuantizeOptions() const {
return type == BuiltinOptions_QuantizeOptions ?
reinterpret_cast<const QuantizeOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -7630,6 +7651,46 @@ inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(
flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct QuantizeOptionsT : public flatbuffers::NativeTable {
typedef QuantizeOptions TableType;
QuantizeOptionsT() {
}
};
struct QuantizeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef QuantizeOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
QuantizeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(QuantizeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<QuantizeOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct QuantizeOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit QuantizeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
QuantizeOptionsBuilder &operator=(const QuantizeOptionsBuilder &);
flatbuffers::Offset<QuantizeOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<QuantizeOptions>(end);
return o;
}
};
inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
QuantizeOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@ -8027,6 +8088,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const {
return builtin_options_type() == BuiltinOptions_MatrixDiagOptions ? static_cast<const MatrixDiagOptions *>(builtin_options()) : nullptr;
}
const QuantizeOptions *builtin_options_as_QuantizeOptions() const {
return builtin_options_type() == BuiltinOptions_QuantizeOptions ? static_cast<const QuantizeOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -8410,6 +8474,10 @@ template<> inline const MatrixDiagOptions *Operator::builtin_options_as<MatrixDi
return builtin_options_as_MatrixDiagOptions();
}
template<> inline const QuantizeOptions *Operator::builtin_options_as<QuantizeOptions>() const {
return builtin_options_as_QuantizeOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -11247,6 +11315,29 @@ inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffer
_fbb);
}
inline QuantizeOptionsT *QuantizeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new QuantizeOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void QuantizeOptions::UnPackTo(QuantizeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<QuantizeOptions> QuantizeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateQuantizeOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const QuantizeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateQuantizeOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -11857,6 +11948,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_QuantizeOptions: {
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -12227,6 +12322,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_QuantizeOptions: {
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -12585,6 +12684,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const MatrixDiagOptionsT *>(value);
return CreateMatrixDiagOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_QuantizeOptions: {
auto ptr = reinterpret_cast<const QuantizeOptionsT *>(value);
return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -12943,6 +13046,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new MatrixDiagOptionsT(*reinterpret_cast<MatrixDiagOptionsT *>(u.value));
break;
}
case BuiltinOptions_QuantizeOptions: {
value = new QuantizeOptionsT(*reinterpret_cast<QuantizeOptionsT *>(u.value));
break;
}
default:
break;
}
@ -13390,6 +13497,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_QuantizeOptions: {
auto ptr = reinterpret_cast<QuantizeOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;