Create int8 L2Norm.
PiperOrigin-RevId: 235623180
This commit is contained in:
parent
391bee7364
commit
8d4cdf8444
@ -311,6 +311,7 @@ cc_library(
|
|||||||
"reference/integer_ops/depthwise_conv.h",
|
"reference/integer_ops/depthwise_conv.h",
|
||||||
"reference/integer_ops/dequantize.h",
|
"reference/integer_ops/dequantize.h",
|
||||||
"reference/integer_ops/fully_connected.h",
|
"reference/integer_ops/fully_connected.h",
|
||||||
|
"reference/integer_ops/l2normalization.h",
|
||||||
"reference/integer_ops/log_softmax.h",
|
"reference/integer_ops/log_softmax.h",
|
||||||
"reference/integer_ops/logistic.h",
|
"reference/integer_ops/logistic.h",
|
||||||
"reference/integer_ops/mul.h",
|
"reference/integer_ops/mul.h",
|
||||||
|
@ -363,6 +363,55 @@ inline int32 GetReciprocal(int32 x, int x_integer_digits,
|
|||||||
return shifted_scale.raw();
|
return shifted_scale.raw();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int reverse_shift,
|
||||||
|
int32* output_inv_sqrt,
|
||||||
|
int* output_shift) {
|
||||||
|
*output_shift = 11;
|
||||||
|
while (input >= (1 << 29)) {
|
||||||
|
input /= 4;
|
||||||
|
++*output_shift;
|
||||||
|
}
|
||||||
|
TFLITE_DCHECK_GT(input, 0);
|
||||||
|
const unsigned max_left_shift_bits =
|
||||||
|
CountLeadingZeros(static_cast<uint32>(input)) - 1;
|
||||||
|
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
|
||||||
|
const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
|
||||||
|
*output_shift -= left_shift_bit_pairs;
|
||||||
|
input <<= 2 * left_shift_bit_pairs;
|
||||||
|
TFLITE_DCHECK_GE(input, (1 << 27));
|
||||||
|
TFLITE_DCHECK_LT(input, (1 << 29));
|
||||||
|
using gemmlowp::FixedPoint;
|
||||||
|
using gemmlowp::Rescale;
|
||||||
|
using gemmlowp::SaturatingRoundingMultiplyByPOT;
|
||||||
|
// Using 3 integer bits gives us enough room for the internal arithmetic in
|
||||||
|
// this Newton-Raphson iteration.
|
||||||
|
using F3 = FixedPoint<int32, 3>;
|
||||||
|
using F0 = FixedPoint<int32, 0>;
|
||||||
|
const F3 fixedpoint_input = F3::FromRaw(input >> 1);
|
||||||
|
const F3 fixedpoint_half_input =
|
||||||
|
SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
|
||||||
|
const F3 fixedpoint_half_three =
|
||||||
|
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
|
||||||
|
// Newton-Raphson iteration
|
||||||
|
// Naive unoptimized starting guess: x = 1
|
||||||
|
F3 x = F3::One();
|
||||||
|
// Naive unoptimized number of iterations: 5
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
const F3 x3 = Rescale<3>(x * x * x);
|
||||||
|
x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
|
||||||
|
}
|
||||||
|
const F0 fixedpoint_half_sqrt_2 =
|
||||||
|
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
|
||||||
|
x = x * fixedpoint_half_sqrt_2;
|
||||||
|
*output_inv_sqrt = x.raw();
|
||||||
|
if (*output_shift < 0) {
|
||||||
|
*output_inv_sqrt <<= -*output_shift;
|
||||||
|
*output_shift = 0;
|
||||||
|
}
|
||||||
|
// Apply the caller's sign convention: reverse_shift = -1 converts the right shift (right is positive) into a left shift.
|
||||||
|
*output_shift *= reverse_shift;
|
||||||
|
}
|
||||||
|
|
||||||
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
|
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
|
||||||
// BROADCASTING.
|
// BROADCASTING.
|
||||||
//
|
//
|
||||||
|
@ -2357,55 +2357,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
|
|
||||||
int32* output_inv_sqrt,
|
|
||||||
int* output_shift) {
|
|
||||||
*output_shift = 11;
|
|
||||||
while (input >= (1 << 29)) {
|
|
||||||
input /= 4;
|
|
||||||
++*output_shift;
|
|
||||||
}
|
|
||||||
TFLITE_DCHECK_GT(input, 0);
|
|
||||||
const unsigned max_left_shift_bits =
|
|
||||||
CountLeadingZeros(static_cast<uint32>(input)) - 1;
|
|
||||||
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
|
|
||||||
const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
|
|
||||||
*output_shift -= left_shift_bit_pairs;
|
|
||||||
input <<= 2 * left_shift_bit_pairs;
|
|
||||||
TFLITE_DCHECK_GE(input, (1 << 27));
|
|
||||||
TFLITE_DCHECK_LT(input, (1 << 29));
|
|
||||||
using gemmlowp::FixedPoint;
|
|
||||||
using gemmlowp::Rescale;
|
|
||||||
using gemmlowp::SaturatingRoundingMultiplyByPOT;
|
|
||||||
// Using 3 integer bits gives us enough room for the internal arithmetic in
|
|
||||||
// this Newton-Raphson iteration.
|
|
||||||
using F3 = FixedPoint<int32, 3>;
|
|
||||||
using F0 = FixedPoint<int32, 0>;
|
|
||||||
const F3 fixedpoint_input = F3::FromRaw(input >> 1);
|
|
||||||
const F3 fixedpoint_half_input =
|
|
||||||
SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
|
|
||||||
const F3 fixedpoint_half_three =
|
|
||||||
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
|
|
||||||
// Newton-Raphson iteration
|
|
||||||
// Naive unoptimized starting guess: x = 1
|
|
||||||
F3 x = F3::One();
|
|
||||||
// Naive unoptimized number of iterations: 5
|
|
||||||
for (int i = 0; i < 5; i++) {
|
|
||||||
const F3 x3 = Rescale<3>(x * x * x);
|
|
||||||
x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
|
|
||||||
}
|
|
||||||
const F0 fixedpoint_half_sqrt_2 =
|
|
||||||
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
|
|
||||||
x = x * fixedpoint_half_sqrt_2;
|
|
||||||
*output_inv_sqrt = x.raw();
|
|
||||||
if (*output_shift < 0) {
|
|
||||||
*output_inv_sqrt <<= -*output_shift;
|
|
||||||
*output_shift = 0;
|
|
||||||
}
|
|
||||||
// Convert right shift (right is positive) to left shift.
|
|
||||||
*output_shift *= kReverseShift;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
||||||
const RuntimeShape& input_shape,
|
const RuntimeShape& input_shape,
|
||||||
const uint8* input_data,
|
const uint8* input_data,
|
||||||
@ -2427,8 +2378,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
|||||||
}
|
}
|
||||||
int32 inv_l2norm_multiplier;
|
int32 inv_l2norm_multiplier;
|
||||||
int inv_l2norm_shift;
|
int inv_l2norm_shift;
|
||||||
GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier,
|
GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
|
||||||
&inv_l2norm_shift);
|
&inv_l2norm_multiplier, &inv_l2norm_shift);
|
||||||
|
|
||||||
for (int c = 0; c < depth; c++) {
|
for (int c = 0; c < depth; c++) {
|
||||||
int32 diff = *input_data - input_zero_point;
|
int32 diff = *input_data - input_zero_point;
|
||||||
|
@ -0,0 +1,65 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
|
||||||
|
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace reference_integer_ops {
|
||||||
|
|
||||||
|
inline void L2Normalization(int32_t input_zero_point, int32_t outer_size,
|
||||||
|
int32_t depth, const int8* input_data,
|
||||||
|
int8* output_data) {
|
||||||
|
static constexpr int8_t kMinInt8 = std::numeric_limits<int8_t>::min();
|
||||||
|
static constexpr int8_t kMaxInt8 = std::numeric_limits<int8_t>::max();
|
||||||
|
// The output scale must be in sync with Prepare().
|
||||||
|
// Output is in 1/128 scale so the actual output range is nudged from [-1, 1]
|
||||||
|
// to [-1, 127/128].
|
||||||
|
static constexpr int32_t kOutputScale = 7;
|
||||||
|
for (int outer_index = 0; outer_index < outer_size; ++outer_index) {
|
||||||
|
// int32 = (int8 - int8) ^ 2.
|
||||||
|
// ([-128, 127] - [-128, 127]) ^ 2 = [0, (2^8 - 1)^2] so the accumulator is
|
||||||
|
// safe from overflowing for at least 2^15 steps (255^2 * 2^15 < 2^31).
|
||||||
|
int32_t acc = 0;
|
||||||
|
for (int inner_index = 0; inner_index < depth; ++inner_index) {
|
||||||
|
int32_t input =
|
||||||
|
input_data[depth * outer_index + inner_index] - input_zero_point;
|
||||||
|
acc += input * input;
|
||||||
|
}
|
||||||
|
int32_t inv_l2norm_multiplier;
|
||||||
|
int inv_l2norm_shift;
|
||||||
|
GetInvSqrtQuantizedMultiplierExp(acc, /*reverse_shift*/ -1,
|
||||||
|
&inv_l2norm_multiplier, &inv_l2norm_shift);
|
||||||
|
|
||||||
|
for (int inner_index = 0; inner_index < depth; ++inner_index) {
|
||||||
|
int32_t input =
|
||||||
|
input_data[depth * outer_index + inner_index] - input_zero_point;
|
||||||
|
|
||||||
|
// Rescale and downcast. Rescale is folded into the division.
|
||||||
|
int32_t output_in_q24 = MultiplyByQuantizedMultiplier(
|
||||||
|
input, inv_l2norm_multiplier, inv_l2norm_shift + kOutputScale);
|
||||||
|
output_in_q24 =
|
||||||
|
std::min(static_cast<int32_t>(kMaxInt8),
|
||||||
|
std::max(static_cast<int32_t>(kMinInt8), output_in_q24));
|
||||||
|
output_data[depth * outer_index + inner_index] =
|
||||||
|
static_cast<int8>(output_in_q24);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace reference_integer_ops
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_
|
@ -489,55 +489,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
|
|
||||||
int32* output_inv_sqrt,
|
|
||||||
int* output_shift) {
|
|
||||||
*output_shift = 11;
|
|
||||||
while (input >= (1 << 29)) {
|
|
||||||
input /= 4;
|
|
||||||
++*output_shift;
|
|
||||||
}
|
|
||||||
TFLITE_DCHECK_GT(input, 0);
|
|
||||||
const unsigned max_left_shift_bits =
|
|
||||||
CountLeadingZeros(static_cast<uint32>(input)) - 1;
|
|
||||||
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
|
|
||||||
const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
|
|
||||||
*output_shift -= left_shift_bit_pairs;
|
|
||||||
input <<= 2 * left_shift_bit_pairs;
|
|
||||||
TFLITE_DCHECK_GE(input, (1 << 27));
|
|
||||||
TFLITE_DCHECK_LT(input, (1 << 29));
|
|
||||||
using gemmlowp::FixedPoint;
|
|
||||||
using gemmlowp::Rescale;
|
|
||||||
using gemmlowp::SaturatingRoundingMultiplyByPOT;
|
|
||||||
// Using 3 integer bits gives us enough room for the internal arithmetic in
|
|
||||||
// this Newton-Raphson iteration.
|
|
||||||
using F3 = FixedPoint<int32, 3>;
|
|
||||||
using F0 = FixedPoint<int32, 0>;
|
|
||||||
const F3 fixedpoint_input = F3::FromRaw(input >> 1);
|
|
||||||
const F3 fixedpoint_half_input =
|
|
||||||
SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
|
|
||||||
const F3 fixedpoint_half_three =
|
|
||||||
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
|
|
||||||
// Newton-Raphson iteration
|
|
||||||
// Naive unoptimized starting guess: x = 1
|
|
||||||
F3 x = F3::One();
|
|
||||||
// Naive unoptimized number of iterations: 5
|
|
||||||
for (int i = 0; i < 5; i++) {
|
|
||||||
const F3 x3 = Rescale<3>(x * x * x);
|
|
||||||
x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
|
|
||||||
}
|
|
||||||
const F0 fixedpoint_half_sqrt_2 =
|
|
||||||
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
|
|
||||||
x = x * fixedpoint_half_sqrt_2;
|
|
||||||
*output_inv_sqrt = x.raw();
|
|
||||||
if (*output_shift < 0) {
|
|
||||||
*output_inv_sqrt <<= -*output_shift;
|
|
||||||
*output_shift = 0;
|
|
||||||
}
|
|
||||||
// Convert right shift (right is positive) to left shift.
|
|
||||||
*output_shift *= kReverseShift;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
||||||
const RuntimeShape& input_shape,
|
const RuntimeShape& input_shape,
|
||||||
const uint8* input_data,
|
const uint8* input_data,
|
||||||
@ -557,9 +508,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
|||||||
}
|
}
|
||||||
int32 inv_l2norm_multiplier;
|
int32 inv_l2norm_multiplier;
|
||||||
int inv_l2norm_shift;
|
int inv_l2norm_shift;
|
||||||
GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier,
|
GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
|
||||||
&inv_l2norm_shift);
|
&inv_l2norm_multiplier, &inv_l2norm_shift);
|
||||||
|
|
||||||
for (int c = 0; c < depth; c++) {
|
for (int c = 0; c < depth; c++) {
|
||||||
int32 diff = input_data[depth * i + c] - input_zero_point;
|
int32 diff = input_data[depth * i + c] - input_zero_point;
|
||||||
int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
@ -45,14 +46,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
|
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
|
||||||
|
|
||||||
TF_LITE_ENSURE(
|
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
|
||||||
context, output->type == kTfLiteFloat32 || output->type == kTfLiteUInt8);
|
output->type == kTfLiteUInt8 ||
|
||||||
|
output->type == kTfLiteInt8);
|
||||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||||
|
|
||||||
if (output->type == kTfLiteUInt8) {
|
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
|
||||||
TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
|
TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
|
||||||
|
if (output->type == kTfLiteUInt8) {
|
||||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
|
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
|
||||||
}
|
}
|
||||||
|
if (output->type == kTfLiteInt8) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(ahentz): For some reason our implementations don't support
|
// TODO(ahentz): For some reason our implementations don't support
|
||||||
// activations.
|
// activations.
|
||||||
@ -97,6 +104,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_L2NORM(optimized_ops);
|
TF_LITE_L2NORM(optimized_ops);
|
||||||
}
|
}
|
||||||
#undef TF_LITE_L2NORM
|
#undef TF_LITE_L2NORM
|
||||||
|
} else if (output->type == kTfLiteInt8) {
|
||||||
|
const auto input_shape = GetTensorShape(input);
|
||||||
|
const auto output_shape = GetTensorShape(output);
|
||||||
|
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||||
|
const int depth =
|
||||||
|
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
|
||||||
|
const int outer_size =
|
||||||
|
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
|
||||||
|
reference_integer_ops::L2Normalization(input->params.zero_point, outer_size,
|
||||||
|
depth, GetTensorData<int8>(input),
|
||||||
|
GetTensorData<int8>(output));
|
||||||
} else {
|
} else {
|
||||||
context->ReportError(context, "Output type is %d, requires float.",
|
context->ReportError(context, "Output type is %d, requires float.",
|
||||||
output->type);
|
output->type);
|
||||||
|
@ -55,9 +55,10 @@ class L2NormOpModel : public SingleOpModel {
|
|||||||
return ExtractVector<T>(output_);
|
return ExtractVector<T>(output_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
std::vector<float> GetDequantizedOutput() {
|
std::vector<float> GetDequantizedOutput() {
|
||||||
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
|
return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
|
||||||
GetScale(output_), GetZeroPoint(output_));
|
GetZeroPoint(output_));
|
||||||
}
|
}
|
||||||
|
|
||||||
int input() const { return input_; }
|
int input() const { return input_; }
|
||||||
@ -100,7 +101,20 @@ TEST(L2NormOpTest, SimpleUint8Test) {
|
|||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8_t>(),
|
EXPECT_THAT(m.GetOutput<uint8_t>(),
|
||||||
ElementsAreArray({58, 166, 173, 205, 83, 134}));
|
ElementsAreArray({58, 166, 173, 205, 83, 134}));
|
||||||
EXPECT_THAT(m.GetDequantizedOutput(),
|
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
|
||||||
|
ElementsAreArray(
|
||||||
|
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(L2NormOpTest, SimpleInt8Test) {
|
||||||
|
L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);
|
||||||
|
|
||||||
|
m.QuantizeAndPopulate<int8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||||
|
ElementsAreArray({-70, 38, 45, 77, -45, 6}));
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||||
ElementsAreArray(
|
ElementsAreArray(
|
||||||
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
|
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
|
||||||
}
|
}
|
||||||
@ -121,7 +135,32 @@ TEST(L2NormOpTest, MultipleBatchUint8Test) {
|
|||||||
58, 166, 173, 205, 83, 134, // batch 2
|
58, 166, 173, 205, 83, 134, // batch 2
|
||||||
58, 166, 173, 205, 83, 134, // batch 3
|
58, 166, 173, 205, 83, 134, // batch 3
|
||||||
}));
|
}));
|
||||||
EXPECT_THAT(m.GetDequantizedOutput(),
|
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
{
|
||||||
|
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
|
||||||
|
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2
|
||||||
|
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3
|
||||||
|
},
|
||||||
|
0.1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(L2NormOpTest, MultipleBatchInt8Test) {
|
||||||
|
L2NormOpModel m({3, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);
|
||||||
|
|
||||||
|
m.QuantizeAndPopulate<int8_t>(m.input(),
|
||||||
|
{
|
||||||
|
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1
|
||||||
|
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2
|
||||||
|
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3
|
||||||
|
});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({
|
||||||
|
-70, 38, 45, 77, -45, 6, // batch 1
|
||||||
|
-70, 38, 45, 77, -45, 6, // batch 2
|
||||||
|
-70, 38, 45, 77, -45, 6, // batch 3
|
||||||
|
}));
|
||||||
|
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||||
ElementsAreArray(ArrayFloatNear(
|
ElementsAreArray(ArrayFloatNear(
|
||||||
{
|
{
|
||||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
|
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
|
||||||
|
@ -229,7 +229,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_MUL, Register_MUL());
|
AddBuiltin(BuiltinOperator_MUL, Register_MUL());
|
||||||
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
|
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION(),
|
||||||
|
/* min_version */ 1,
|
||||||
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
|
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
|
||||||
Register_LOCAL_RESPONSE_NORMALIZATION());
|
Register_LOCAL_RESPONSE_NORMALIZATION());
|
||||||
AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
|
AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
|
||||||
|
@ -618,6 +618,12 @@ class L2Normalization
|
|||||||
}
|
}
|
||||||
|
|
||||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
|
const string& output_name = op_signature.op->outputs[0];
|
||||||
|
const Array& output_array = op_signature.model->GetArray(output_name);
|
||||||
|
// Version 2 adds signed int8 support, detected here from the output type.
|
||||||
|
if (output_array.data_type == ArrayDataType::kInt8) {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -765,6 +765,28 @@ void SimpleVersioningTest() {
|
|||||||
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests versioning for a simple Op that has 2 versions, where the output type
|
||||||
|
// controls the version.
|
||||||
|
template <typename Op>
|
||||||
|
void SimpleOutputVersioningTest() {
|
||||||
|
Op op;
|
||||||
|
op.outputs = {"output1"};
|
||||||
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||||
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
||||||
|
|
||||||
|
Model uint8_model;
|
||||||
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.outputs[0]);
|
||||||
|
uint8_array.data_type = ArrayDataType::kUint8;
|
||||||
|
OperatorSignature uint8_signature = {.model = &uint8_model, .op = &op};
|
||||||
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
|
||||||
|
|
||||||
|
Model int8_model;
|
||||||
|
Array& int8_array = int8_model.GetOrCreateArray(op.outputs[0]);
|
||||||
|
int8_array.data_type = ArrayDataType::kInt8;
|
||||||
|
OperatorSignature int8_signature = {.model = &int8_model, .op = &op};
|
||||||
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, VersioningEqualTest) {
|
TEST_F(OperatorTest, VersioningEqualTest) {
|
||||||
SimpleVersioningTest<TensorFlowEqualOperator>();
|
SimpleVersioningTest<TensorFlowEqualOperator>();
|
||||||
}
|
}
|
||||||
@ -825,6 +847,10 @@ TEST_F(OperatorTest, VersioningLogisticTest) {
|
|||||||
SimpleVersioningTest<LogisticOperator>();
|
SimpleVersioningTest<LogisticOperator>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, VersioningL2NormTest) {
|
||||||
|
SimpleOutputVersioningTest<L2NormalizationOperator>();
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); }
|
TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); }
|
||||||
|
|
||||||
TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); }
|
TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); }
|
||||||
|
Loading…
x
Reference in New Issue
Block a user