Create int8 L2Norm.

PiperOrigin-RevId: 235623180
This commit is contained in:
Jian Li 2019-02-25 16:24:00 -08:00 committed by TensorFlower Gardener
parent 391bee7364
commit 8d4cdf8444
10 changed files with 219 additions and 112 deletions

View File

@ -311,6 +311,7 @@ cc_library(
"reference/integer_ops/depthwise_conv.h",
"reference/integer_ops/dequantize.h",
"reference/integer_ops/fully_connected.h",
"reference/integer_ops/l2normalization.h",
"reference/integer_ops/log_softmax.h",
"reference/integer_ops/logistic.h",
"reference/integer_ops/mul.h",

View File

@ -363,6 +363,55 @@ inline int32 GetReciprocal(int32 x, int x_integer_digits,
return shifted_scale.raw();
}
// Computes a fixed-point approximation of 1 / sqrt(input), returned as a raw
// multiplier together with a shift.
//
//   input           - value whose inverse square root is wanted; expected to
//                     be non-negative.
//   reverse_shift   - factor the computed shift is multiplied by before it is
//                     returned. The shift is computed as a right shift (right
//                     is positive); a caller that wants a left shift passes a
//                     factor that flips the sign.
//   output_inv_sqrt - receives the raw fixed-point multiplier.
//   output_shift    - receives the shift, already scaled by reverse_shift.
//
// Inputs 0 and 1 are saturated to the largest int32 multiplier with a shift of
// 0 instead of going through the general computation.
inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int reverse_shift,
                                             int32* output_inv_sqrt,
                                             int* output_shift) {
  TFLITE_DCHECK_GE(input, 0);
  if (input <= 1) {
    // 1 / sqrt(0) is undefined, and input == 1 overflows in the general path
    // below: it ends with *output_shift == -3 and a multiplier of 2^28, so the
    // final `<<= 3` would produce 2^31, which does not fit in int32. Saturate
    // to the largest representable multiplier with no shift instead.
    *output_inv_sqrt = 0x7FFFFFFF;
    *output_shift = 0;
    return;
  }
  *output_shift = 11;
  // Bring large inputs below 2^29; every division by 4 is compensated by one
  // extra bit of shift.
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  TFLITE_DCHECK_GT(input, 0);
  // Scale small inputs up by an even number of bits so that the value lands in
  // [2^27, 2^29), and take the matching number of bit pairs off the shift.
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32, 3>;
  using F0 = FixedPoint<int32, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration: x <- 1.5 * x - 0.5 * input * x^3.
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  // A negative right shift is folded into the multiplier so that the reported
  // shift is never negative before reverse_shift is applied.
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= reverse_shift;
}
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//

View File

@ -2357,55 +2357,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
}
}
// Computes a fixed-point approximation of 1 / sqrt(input), returned as a raw
// multiplier together with a shift.
//
//   input           - value whose inverse square root is wanted; expected to
//                     be non-negative.
//   output_inv_sqrt - receives the raw fixed-point multiplier.
//   output_shift    - receives the shift, multiplied by kReverseShift (a
//                     constant defined outside this function).
//
// Inputs 0 and 1 are saturated to the largest int32 multiplier with a shift of
// 0 instead of going through the general computation.
inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
                                             int32* output_inv_sqrt,
                                             int* output_shift) {
  TFLITE_DCHECK_GE(input, 0);
  if (input <= 1) {
    // 1 / sqrt(0) is undefined, and input == 1 overflows in the general path
    // below: it ends with *output_shift == -3 and a multiplier of 2^28, so the
    // final `<<= 3` would produce 2^31, which does not fit in int32. Saturate
    // to the largest representable multiplier with no shift instead.
    *output_inv_sqrt = 0x7FFFFFFF;
    *output_shift = 0;
    return;
  }
  *output_shift = 11;
  // Bring large inputs below 2^29; every division by 4 is compensated by one
  // extra bit of shift.
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  TFLITE_DCHECK_GT(input, 0);
  // Scale small inputs up by an even number of bits so that the value lands in
  // [2^27, 2^29), and take the matching number of bit pairs off the shift.
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32, 3>;
  using F0 = FixedPoint<int32, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration: x <- 1.5 * x - 0.5 * input * x^3.
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  // A negative right shift is folded into the multiplier so that the reported
  // shift is never negative before kReverseShift is applied.
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= kReverseShift;
}
inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
const RuntimeShape& input_shape,
const uint8* input_data,
@ -2427,8 +2378,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
}
int32 inv_l2norm_multiplier;
int inv_l2norm_shift;
GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier,
&inv_l2norm_shift);
GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
&inv_l2norm_multiplier, &inv_l2norm_shift);
for (int c = 0; c < depth; c++) {
int32 diff = *input_data - input_zero_point;

View File

@ -0,0 +1,65 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace reference_integer_ops {
inline void L2Normalization(int32_t input_zero_point, int32_t outer_size,
int32_t depth, const int8* input_data,
int8* output_data) {
static constexpr int8_t kMinInt8 = std::numeric_limits<int8_t>::min();
static constexpr int8_t kMaxInt8 = std::numeric_limits<int8_t>::max();
// The output scale must be in sync with Prepare().
// Output is in 1/128 scale so the actual output range is nudged from [-1, 1]
// to [-1, 127/128].
static constexpr int32_t kOutputScale = 7;
for (int outer_index = 0; outer_index < outer_size; ++outer_index) {
// int32 = (int8 - int8) ^ 2.
// ([-128, 127] - [-128, 127]) ^ 2 = [0, (2^8 - 1)^2] so the accumulator is
// safe from overflowing in at least 2^16 steps.
int32_t acc = 0;
for (int inner_index = 0; inner_index < depth; ++inner_index) {
int32_t input =
input_data[depth * outer_index + inner_index] - input_zero_point;
acc += input * input;
}
int32_t inv_l2norm_multiplier;
int inv_l2norm_shift;
GetInvSqrtQuantizedMultiplierExp(acc, /*reverse_shift*/ -1,
&inv_l2norm_multiplier, &inv_l2norm_shift);
for (int inner_index = 0; inner_index < depth; ++inner_index) {
int32_t input =
input_data[depth * outer_index + inner_index] - input_zero_point;
// Rescale and downcast. Rescale is folded into the division.
int32_t output_in_q24 = MultiplyByQuantizedMultiplier(
input, inv_l2norm_multiplier, inv_l2norm_shift + kOutputScale);
output_in_q24 =
std::min(static_cast<int32_t>(kMaxInt8),
std::max(static_cast<int32_t>(kMinInt8), output_in_q24));
output_data[depth * outer_index + inner_index] =
static_cast<int8>(output_in_q24);
}
}
}
} // namespace reference_integer_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_

View File

@ -489,55 +489,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
}
}
// Computes a fixed-point approximation of 1 / sqrt(input), returned as a raw
// multiplier together with a shift.
//
//   input           - value whose inverse square root is wanted; expected to
//                     be non-negative.
//   output_inv_sqrt - receives the raw fixed-point multiplier.
//   output_shift    - receives the shift, multiplied by kReverseShift (a
//                     constant defined outside this function).
//
// Inputs 0 and 1 are saturated to the largest int32 multiplier with a shift of
// 0 instead of going through the general computation.
inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
                                             int32* output_inv_sqrt,
                                             int* output_shift) {
  TFLITE_DCHECK_GE(input, 0);
  if (input <= 1) {
    // 1 / sqrt(0) is undefined, and input == 1 overflows in the general path
    // below: it ends with *output_shift == -3 and a multiplier of 2^28, so the
    // final `<<= 3` would produce 2^31, which does not fit in int32. Saturate
    // to the largest representable multiplier with no shift instead.
    *output_inv_sqrt = 0x7FFFFFFF;
    *output_shift = 0;
    return;
  }
  *output_shift = 11;
  // Bring large inputs below 2^29; every division by 4 is compensated by one
  // extra bit of shift.
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  TFLITE_DCHECK_GT(input, 0);
  // Scale small inputs up by an even number of bits so that the value lands in
  // [2^27, 2^29), and take the matching number of bit pairs off the shift.
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32, 3>;
  using F0 = FixedPoint<int32, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration: x <- 1.5 * x - 0.5 * input * x^3.
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  // A negative right shift is folded into the multiplier so that the reported
  // shift is never negative before kReverseShift is applied.
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= kReverseShift;
}
inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
const RuntimeShape& input_shape,
const uint8* input_data,
@ -557,9 +508,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
}
int32 inv_l2norm_multiplier;
int inv_l2norm_shift;
GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier,
&inv_l2norm_shift);
GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
&inv_l2norm_multiplier, &inv_l2norm_shift);
for (int c = 0; c < depth; c++) {
int32 diff = input_data[depth * i + c] - input_zero_point;
int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -45,13 +46,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
TF_LITE_ENSURE(
context, output->type == kTfLiteFloat32 || output->type == kTfLiteUInt8);
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
output->type == kTfLiteUInt8 ||
output->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
if (output->type == kTfLiteUInt8) {
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
if (output->type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
}
if (output->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
}
}
// TODO(ahentz): For some reason our implementations don't support
@ -97,6 +104,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_L2NORM(optimized_ops);
}
#undef TF_LITE_L2NORM
} else if (output->type == kTfLiteInt8) {
const auto input_shape = GetTensorShape(input);
const auto output_shape = GetTensorShape(output);
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
reference_integer_ops::L2Normalization(input->params.zero_point, outer_size,
depth, GetTensorData<int8>(input),
GetTensorData<int8>(output));
} else {
context->ReportError(context, "Output type is %d, requires float.",
output->type);

View File

@ -55,9 +55,10 @@ class L2NormOpModel : public SingleOpModel {
return ExtractVector<T>(output_);
}
template <typename T>
std::vector<float> GetDequantizedOutput() {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
GetScale(output_), GetZeroPoint(output_));
return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
GetZeroPoint(output_));
}
int input() const { return input_; }
@ -100,7 +101,20 @@ TEST(L2NormOpTest, SimpleUint8Test) {
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8_t>(),
ElementsAreArray({58, 166, 173, 205, 83, 134}));
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}
// Normalizes a single 6-element int8 vector and checks both the raw quantized
// output and its dequantized values.
TEST(L2NormOpTest, SimpleInt8Test) {
  L2NormOpModel model({1, 1, 1, 6}, TensorType_INT8,
                      ActivationFunctionType_NONE);
  model.QuantizeAndPopulate<int8_t>(model.input(),
                                    {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();

  // Raw int8 values produced by the kernel.
  EXPECT_THAT(model.GetOutput<int8_t>(),
              ElementsAreArray({-70, 38, 45, 77, -45, 6}));

  // The same output mapped back to real numbers, compared within 0.1.
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}
@ -121,7 +135,32 @@ TEST(L2NormOpTest, MultipleBatchUint8Test) {
58, 166, 173, 205, 83, 134, // batch 2
58, 166, 173, 205, 83, 134, // batch 3
}));
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3
},
0.1)));
}
TEST(L2NormOpTest, MultipleBatchInt8Test) {
L2NormOpModel m({3, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);
m.QuantizeAndPopulate<int8_t>(m.input(),
{
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3
});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({
-70, 38, 45, 77, -45, 6, // batch 1
-70, 38, 45, 77, -45, 6, // batch 2
-70, 38, 45, 77, -45, 6, // batch 3
}));
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear(
{
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1

View File

@ -229,7 +229,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_MUL, Register_MUL());
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
Register_LOCAL_RESPONSE_NORMALIZATION());
AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,

View File

@ -618,6 +618,12 @@ class L2Normalization
}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& output_name = op_signature.op->outputs[0];
const Array& output_array = op_signature.model->GetArray(output_name);
// Version 2 supports signed int8 input types.
if (output_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};

View File

@ -765,6 +765,28 @@ void SimpleVersioningTest() {
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
}
// Test version for a simple Op with 2 versions and the output type controls the
// version.
template <typename Op>
void SimpleOutputVersioningTest() {
Op op;
op.outputs = {"output1"};
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
Model uint8_model;
Array& uint8_array = uint8_model.GetOrCreateArray(op.outputs[0]);
uint8_array.data_type = ArrayDataType::kUint8;
OperatorSignature uint8_signature = {.model = &uint8_model, .op = &op};
EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
Model int8_model;
Array& int8_array = int8_model.GetOrCreateArray(op.outputs[0]);
int8_array.data_type = ArrayDataType::kInt8;
OperatorSignature int8_signature = {.model = &int8_model, .op = &op};
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
}
// Runs the SimpleVersioningTest helper for TensorFlowEqualOperator.
TEST_F(OperatorTest, VersioningEqualTest) {
SimpleVersioningTest<TensorFlowEqualOperator>();
}
@ -825,6 +847,10 @@ TEST_F(OperatorTest, VersioningLogisticTest) {
SimpleVersioningTest<LogisticOperator>();
}
// Runs the SimpleOutputVersioningTest helper for L2NormalizationOperator, whose
// version is selected by the output array's data type.
TEST_F(OperatorTest, VersioningL2NormTest) {
SimpleOutputVersioningTest<L2NormalizationOperator>();
}
// Runs the SimpleVersioningTest helper for AddOperator.
TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); }
// Runs the SimpleVersioningTest helper for SubOperator.
TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); }