TFL: Port HARD_SWISH operator from TFLite to TFLu
Change-Id: I18092fa30dee33df4577a4abaf5321e7ba96a04c
This commit is contained in:
parent
e5023a1738
commit
daa3c52aa1
tensorflow/lite
kernels
micro
@ -298,22 +298,6 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {
|
||||
delete static_cast<HardSwishData*>(buffer);
|
||||
}
|
||||
|
||||
// Narrows a non-negative int32 fixed-point multiplier to int16 by adding a
// rounding offset of 1 << 15 and keeping the upper 16 bits. Inputs so large
// that adding the offset would overflow are mapped to the int16 maximum.
void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32,
                                     int16_t* multiplier_int16) {
  TFLITE_DCHECK_GE(multiplier_int32, 0);
  static constexpr int32_t kRoundingOffset = 1 << 15;
  const bool offset_would_overflow =
      multiplier_int32 >=
      std::numeric_limits<int32_t>::max() - kRoundingOffset;
  if (!offset_would_overflow) {
    const int32_t narrowed = (multiplier_int32 + kRoundingOffset) >> 16;
    // The narrowed value, scaled back up, must lie within half a step of the
    // original multiplier.
    TFLITE_DCHECK_LE(narrowed << 16, multiplier_int32 + kRoundingOffset);
    TFLITE_DCHECK_GT(narrowed << 16, multiplier_int32 - kRoundingOffset);
    *multiplier_int16 = narrowed;
    TFLITE_DCHECK_EQ(*multiplier_int16, narrowed);
    return;
  }
  *multiplier_int16 = std::numeric_limits<int16_t>::max();
}
|
||||
|
||||
TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
|
@ -932,6 +932,53 @@ void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns the high 16 bits of 2 * a * b, truncating toward zero and
// saturating the single overflowing case (both operands equal to the int16
// minimum) to the int16 maximum.
// Comparable to the ARM SQDMULH instruction, and to
// gemmlowp::SaturatingRoundingDoublingHighMul (SQRDMULH) except that the
// result is rounded toward zero rather than to nearest.
inline std::int16_t SaturatingDoublingHighMul(std::int16_t a, std::int16_t b) {
  using Limits = std::numeric_limits<std::int16_t>;
  if (a == Limits::min() && b == Limits::min()) {
    return Limits::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  return static_cast<std::int16_t>(product / (1 << 15));
}
|
||||
|
||||
// int16 counterpart of gemmlowp::SaturatingRoundingDoublingHighMul: returns
// the high 16 bits of 2 * a * b rounded to nearest, saturating the single
// overflowing case (both operands equal to the int16 minimum) to the int16
// maximum.
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, std::int16_t b) {
  using Limits = std::numeric_limits<std::int16_t>;
  if (a == Limits::min() && b == Limits::min()) {
    return Limits::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  // Bias by (almost) half a step in the direction of the product's sign so
  // that the truncating division below rounds to nearest.
  const std::int32_t rounding_bias =
      product >= 0 ? (1 << 14) : (1 - (1 << 14));
  return static_cast<std::int16_t>((product + rounding_bias) / (1 << 15));
}
|
||||
|
||||
// Multiplies `value` by 2^`amount` in 32-bit arithmetic and clamps the
// product to the representable int16 range.
inline int16_t SaturatingLeftShift(int16_t value, int amount) {
  const int32_t upper_bound = std::numeric_limits<int16_t>::max();
  const int32_t lower_bound = std::numeric_limits<int16_t>::min();
  const int32_t scaled = static_cast<int32_t>(value) * (1 << amount);
  if (scaled > upper_bound) {
    return upper_bound;
  }
  if (scaled < lower_bound) {
    return lower_bound;
  }
  return scaled;
}
|
||||
|
||||
// Narrows a non-negative int32 fixed-point multiplier to int16 by adding a
// rounding offset of 1 << 15 and keeping the upper 16 bits. Inputs so large
// that adding the offset would overflow are mapped to the int16 maximum.
inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32,
                                            int16_t* multiplier_int16) {
  TFLITE_DCHECK_GE(multiplier_int32, 0);
  static constexpr int32_t kRoundingOffset = 1 << 15;
  const bool offset_would_overflow =
      multiplier_int32 >=
      std::numeric_limits<int32_t>::max() - kRoundingOffset;
  if (!offset_would_overflow) {
    const int32_t narrowed = (multiplier_int32 + kRoundingOffset) >> 16;
    // The narrowed value, scaled back up, must lie within half a step of the
    // original multiplier.
    TFLITE_DCHECK_LE(narrowed << 16, multiplier_int32 + kRoundingOffset);
    TFLITE_DCHECK_GT(narrowed << 16, multiplier_int32 - kRoundingOffset);
    *multiplier_int16 = narrowed;
    TFLITE_DCHECK_EQ(*multiplier_int16, narrowed);
    return;
  }
  *multiplier_int16 = std::numeric_limits<int16_t>::max();
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
|
||||
|
@ -2615,25 +2615,6 @@ inline void HardSwish(const RuntimeShape& input_shape, const T* input_data,
|
||||
}
|
||||
}
|
||||
|
||||
// Multiplies `value` by 2^`amount` in 32-bit arithmetic and clamps the
// product to the representable int16 range.
inline int16_t SaturatingLeftShift(int16_t value, int amount) {
  const int32_t upper_bound = std::numeric_limits<int16_t>::max();
  const int32_t lower_bound = std::numeric_limits<int16_t>::min();
  const int32_t scaled = static_cast<int32_t>(value) * (1 << amount);
  if (scaled > upper_bound) {
    return upper_bound;
  }
  if (scaled < lower_bound) {
    return lower_bound;
  }
  return scaled;
}
|
||||
|
||||
// Returns the high 16 bits of 2 * a * b, truncating toward zero and
// saturating the single overflowing case (both operands equal to the int16
// minimum) to the int16 maximum.
// Comparable to the ARM SQDMULH instruction, and to
// gemmlowp::SaturatingRoundingDoublingHighMul (SQRDMULH) except that the
// result is rounded toward zero rather than to nearest.
inline std::int16_t SaturatingDoublingHighMul(std::int16_t a, std::int16_t b) {
  using Limits = std::numeric_limits<std::int16_t>;
  if (a == Limits::min() && b == Limits::min()) {
    return Limits::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  return static_cast<std::int16_t>(product / (1 << 15));
}
|
||||
|
||||
template <typename T>
|
||||
inline void HardSwish(const HardSwishParams& params,
|
||||
const RuntimeShape& input_shape, const T* input_data,
|
||||
|
@ -42,6 +42,7 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddFullyConnected();
|
||||
AddGreater();
|
||||
AddGreaterEqual();
|
||||
AddHardSwish();
|
||||
AddL2Normalization();
|
||||
AddLess();
|
||||
AddLessEqual();
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
@ -77,6 +78,157 @@ inline void Relu6Quantized(Q lower, Q upper, const RuntimeShape& input_shape,
|
||||
}
|
||||
}
|
||||
|
||||
// Divides `numerator` by 2^`exponent`, rounding to the nearest integer with
// ties rounded away from zero. `exponent` must be non-negative.
inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) {
  // Take the magnitude in 64 bits: negating the most negative int32 value is
  // not representable in 32 bits.
  const std::int64_t wide_numerator = numerator;
  const std::int64_t magnitude =
      wide_numerator < 0 ? -wide_numerator : wide_numerator;
  const std::int64_t mask = (std::int64_t{1} << exponent) - 1;
  const std::int64_t remainder = magnitude & mask;
  // Anything strictly above half of the divisor (rounded down) rounds up.
  const std::int64_t threshold = mask >> 1;
  const std::int64_t rounded_magnitude =
      (magnitude >> exponent) + (remainder > threshold ? 1 : 0);
  return static_cast<std::int32_t>(numerator < 0 ? -rounded_magnitude
                                                 : rounded_magnitude);
}
|
||||
|
||||
inline void HardSwishFloatOp(const RuntimeShape& input_shape, const float* input_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
auto matching_size = MatchingFlatSize(input_shape, output_shape);
|
||||
const float* in_end = input_data + matching_size;
|
||||
for (; input_data < in_end; input_data++, output_data++) {
|
||||
const float in = *input_data;
|
||||
*output_data =
|
||||
in * std::min(static_cast<float>(6), std::max(static_cast<float>(0), in + 3)) /
|
||||
6;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void HardSwishOp(HardSwishParams& params,
|
||||
const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
||||
|
||||
for (int i = 0; i < flat_size; i++) {
|
||||
const int16_t input_value = input_data[i] - params.input_zero_point;
|
||||
// Left-shift as much as we can without overflow/saturation to put
|
||||
// significant bits in the high bits of our 16-bit fixedpoint values, so
|
||||
// that fixed-point approximate computations below are as accurate as
|
||||
// possible.
|
||||
const int16_t input_value_on_hires_input_scale = input_value << 7;
|
||||
// Compute the input value on essentially the output scale, just not
|
||||
// right-shifted yet. This is the value that we'll use in the (x >= +3)
|
||||
// case, and that in the general case we'll multiply against the "relu-ish"
|
||||
// fixed-point multiplier in [0, 1].
|
||||
const int16_t input_value_on_preshift_output_scale =
|
||||
SaturatingRoundingDoublingHighMul(input_value_on_hires_input_scale,
|
||||
params.output_multiplier_fixedpoint_int16);
|
||||
// Now compute the "relu-ish multiplier". In the (-3 <= x <= +3) case, that
|
||||
// is just an affine rescaling of x from [-3, 3] to [0, 1]. In the general
|
||||
// case, it is just that plus saturation at the boundaries of [-3, 3].
|
||||
// First, we rescale from [-3, 3] to [-1, 1], saturating.
|
||||
// That is done by rescaling the input value with a fixed-point multiplier
|
||||
// (reluish_multiplier_fixedpoint) and bit-shift such that we represent
|
||||
// that input value on the scale where the real value 3.0f is represented
|
||||
// by the quantized value 32768. (+32768 is actually not representable as
|
||||
// int16, so this saturates at +32767, and that is seen empirically to be
|
||||
// a negligible contribution to numerical error/bias).
|
||||
//
|
||||
// This code is careful to correctly implement any magnitude of multiplier,
|
||||
// involving either a right shift or a left shift, with correct saturation
|
||||
// behavior in the left-shift case. This forces this code to be more
|
||||
// complicated, but is necessary for real applications: a partially
|
||||
// trained quantized MobileNet v3-small model that motivated this code
|
||||
// exhibits some large [min, max] range boundaries, of the order of
|
||||
// magnitude of 10 or 100 depending on layers.
|
||||
//
|
||||
// The next few lines are basically just an ordinary
|
||||
// MultiplyByQuantizedMultiplier, except that we are more careful here
|
||||
// about the fine details of saturation when left-shifting, because here
|
||||
// overflow in left-shift is a common case, not an anomaly as
|
||||
// MultiplyByQuantizedMultiplier assumes.
|
||||
int16_t reluish_value = input_value_on_hires_input_scale;
|
||||
// Shift left, saturating, as much as we can while ensuring that this
|
||||
// saturation will not contribute to the result. That is, left shift amount
|
||||
// reduced by 1.
|
||||
if (params.reluish_multiplier_exponent > 0) {
|
||||
reluish_value = SaturatingLeftShift(
|
||||
reluish_value, params.reluish_multiplier_exponent - 1);
|
||||
}
|
||||
// Apply the fixed-point multiplier, dividing the value by a divisor
|
||||
// ranging in [1, 2].
|
||||
reluish_value = SaturatingRoundingDoublingHighMul(reluish_value, params.reluish_multiplier_fixedpoint_int16);
|
||||
// Apply the last bit of left-shift. Thus, in the left-shifting case, if
|
||||
// any saturation affects the result, it is happening here --- any
|
||||
// saturation having occurred above is overwritten here, not affecting the
|
||||
// result.
|
||||
if (params.reluish_multiplier_exponent > 0) {
|
||||
reluish_value = SaturatingLeftShift(reluish_value, 1);
|
||||
}
|
||||
// Shift right, in the right-shifting case.
|
||||
if (params.reluish_multiplier_exponent < 0) {
|
||||
reluish_value = RoundingDivideByPOT(
|
||||
reluish_value, -params.reluish_multiplier_exponent);
|
||||
}
|
||||
// At this point we have rescaled the value into a 16bit fixedpoint
|
||||
// reluish_value in [-1, 1].
|
||||
// We now convert that to a 16bit fixedpoint value in [0, 1].
|
||||
reluish_value = (reluish_value + (1 << 15)) >> 1;
|
||||
// Use of SaturatingDoublingHighMul here is important to cancel the biases
|
||||
// from the above SaturatingRoundingDoublingHighMul.
|
||||
//
|
||||
const int16_t preshift_output_value = SaturatingDoublingHighMul(
|
||||
reluish_value, input_value_on_preshift_output_scale);
|
||||
// We were so far operating on the pre-shift output scale. Now we finally
|
||||
// apply that output shift, arriving at the final output scale.
|
||||
int16_t output_value = RoundingDivideByPOT(
|
||||
preshift_output_value, -params.output_multiplier_exponent);
|
||||
output_value += params.output_zero_point;
|
||||
output_value =
|
||||
std::min<int16_t>(output_value, std::numeric_limits<T>::max());
|
||||
output_value =
|
||||
std::max<int16_t>(output_value, std::numeric_limits<T>::min());
|
||||
output_data[i] = output_value;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Q>
|
||||
TfLiteStatus HardSwishQuantized(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
HardSwishParams params;
|
||||
|
||||
params.input_zero_point = input->params.zero_point;
|
||||
params.output_zero_point = output->params.zero_point;
|
||||
|
||||
const float input_scale = input->params.scale;
|
||||
const float hires_input_scale = (1.0f / 128.0f) * input_scale;
|
||||
const float reluish_scale = 3.0f / 32768.0f;
|
||||
const float output_scale = output->params.scale;
|
||||
|
||||
const double output_multiplier = static_cast<double>(hires_input_scale / output_scale);
|
||||
int32_t output_multiplier_fixedpoint_int32;
|
||||
QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
|
||||
¶ms.output_multiplier_exponent);
|
||||
DownScaleInt32ToInt16Multiplier(
|
||||
output_multiplier_fixedpoint_int32,
|
||||
¶ms.output_multiplier_fixedpoint_int16);
|
||||
|
||||
TF_LITE_ENSURE(context, params.output_multiplier_exponent <= 0);
|
||||
|
||||
const double reluish_multiplier = static_cast<double>(hires_input_scale / reluish_scale);
|
||||
int32_t reluish_multiplier_fixedpoint_int32;
|
||||
QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
|
||||
¶ms.reluish_multiplier_exponent);
|
||||
DownScaleInt32ToInt16Multiplier(
|
||||
reluish_multiplier_fixedpoint_int32,
|
||||
¶ms.reluish_multiplier_fixedpoint_int16);
|
||||
|
||||
HardSwishOp<Q>(params, GetTensorShape(input),
|
||||
GetTensorData<Q>(input), GetTensorShape(output), GetTensorData<Q>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Prepare step that performs no work and always reports kTfLiteOk.
TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
  // Neither argument is needed.
  (void)context;
  (void)node;
  return kTfLiteOk;
}
|
||||
@ -107,7 +259,7 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
|
||||
TF_LITE_KERNEL_LOG(context, "Only float32/int8/uint8 are supported currently, got %s",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
@ -148,13 +300,43 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
|
||||
TF_LITE_KERNEL_LOG(context, "Only float32/int8/uint8 are supported currently, got %s",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare step for the hard-swish kernel: performs no work and always reports
// kTfLiteOk.
TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
  // Neither argument is needed.
  (void)context;
  (void)node;
  return kTfLiteOk;
}
|
||||
|
||||
// Invoke step for the hard-swish kernel. Dispatches on the input tensor's
// type: float32 runs HardSwishFloatOp, uint8/int8 run HardSwishQuantized.
// Any other type is logged and reported as kTfLiteError.
TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  const auto input_type = input->type;
  if (input_type == kTfLiteFloat32) {
    HardSwishFloatOp(GetTensorShape(input), GetTensorData<float>(input),
                     GetTensorShape(output), GetTensorData<float>(output));
    return kTfLiteOk;
  }
  if (input_type == kTfLiteUInt8) {
    return HardSwishQuantized<uint8>(context, node);
  }
  if (input_type == kTfLiteInt8) {
    return HardSwishQuantized<int8>(context, node);
  }

  TF_LITE_KERNEL_LOG(context, "Only float32/int8/uint8 are supported currently, got %s",
                     TfLiteTypeGetName(input->type));
  return kTfLiteError;
}
|
||||
|
||||
} // namespace activations
|
||||
|
||||
TfLiteRegistration* Register_RELU() {
|
||||
@ -181,6 +363,13 @@ TfLiteRegistration* Register_RELU6() {
|
||||
return &r;
|
||||
}
|
||||
|
||||
// Returns the registration for the HARD_SWISH kernel: a function-local static
// whose prepare and invoke callbacks point at the activations:: hard-swish
// functions. All other fields are left zero-initialized.
TfLiteRegistration* Register_HARD_SWISH() {
  static TfLiteRegistration registration = {};
  registration.invoke = activations::HardSwishEval;
  registration.prepare = activations::HardSwishPrepare;
  return &registration;
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
||||
@ -135,6 +137,318 @@ void TestRelu6Float(const int* input_dims_data, const float* input_data,
|
||||
}
|
||||
}
|
||||
|
||||
// Resizes `*result` to `size` elements and fills it with values spread over
// [min, max], drawing one number from `*random_engine` per element.
void GenerateUniformRandomVector(int size, float min, float max,
                                 std::minstd_rand* random_engine,
                                 std::vector<float>* result) {
  // std::uniform_*_distribution and std::default_random_engine are
  // deliberately avoided in tests: both are implementation-defined, so a
  // toolchain update or a new platform could change the generated values and
  // break tests. std::minstd_rand is a standard instantiation of
  // std::linear_congruential_engine, the cheapest generator in the C++11
  // standard library, and is good enough here.
  result->resize(size);
  const float to_unit_interval =
      1.0f / static_cast<float>(std::minstd_rand::modulus);
  for (float& slot : *result) {
    // Whether `max` itself can ever be produced does not matter. It can
    // happen through rounding, because std::minstd_rand::modulus (2^31 - 1)
    // is greater than the inverse of the float epsilon.
    const float unit_value = (*random_engine)() * to_unit_interval;
    slot = min + (max - min) * unit_value;
  }
}
|
||||
|
||||
// Float reference for hard-swish used by the tests: resizes `*result` to
// `size` and stores x * min(6, max(0, x + 3)) * (1 / 6) for each of the first
// `size` elements x of `input`.
void EvalTestReferenceHardSwish(int size, const std::vector<float>& input,
                                std::vector<float>* result) {
  result->resize(size);
  std::vector<float>& out = *result;
  int idx = 0;
  while (idx < size) {
    const float x = input[idx];
    // The gate is x + 3 clamped to [0, 6].
    const float gate = std::min(6.0f, std::max(0.0f, x + 3));
    out[idx] = x * gate * (1.0f / 6.0f);
    ++idx;
  }
}
|
||||
|
||||
template <typename T>
|
||||
void TestHardSwishQuantized(int size, float input_min,
|
||||
float input_max, float output_min,
|
||||
float output_max, std::minstd_rand* random_engine) {
|
||||
T output_data[size];
|
||||
T input_data_quantized[size];
|
||||
const int input_dims_data[] = {2, 1, size};
|
||||
const int output_dims_data[] = {2, 1, size};
|
||||
const float input_scale = ScaleFromMinMax<T>(input_min, input_max);
|
||||
const int input_zero_point = ZeroPointFromMinMax<T>(input_min, input_max);
|
||||
const float output_scale = ScaleFromMinMax<T>(output_min, output_max);
|
||||
const int output_zero_point = ZeroPointFromMinMax<T>(output_min, output_max);
|
||||
|
||||
// The numerical error for any 8bit quantized function is at least one half
|
||||
// times the quantization step: 0.5 * (kOutMax - kOutMin) / 256.
|
||||
// To that we add again the quantization step (kOutMax - kOutMin) / 256
|
||||
// to allow for an off-by-one rounding error.
|
||||
const float kTolerance = std::max(input_max - input_min, output_max - output_min) * (1.5f / 256.f);
|
||||
|
||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
||||
const int output_elements_count = ElementCount(*output_dims);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(output_elements_count, size);
|
||||
|
||||
float dequantized_output[output_elements_count];
|
||||
|
||||
std::vector<float> float_input_values;
|
||||
std::vector<float> float_ref_output_values;
|
||||
GenerateUniformRandomVector(size, input_min, input_max, random_engine,
|
||||
&float_input_values);
|
||||
EvalTestReferenceHardSwish(size, float_input_values,
|
||||
&float_ref_output_values);
|
||||
for (float& val : float_ref_output_values) {
|
||||
val = std::min(output_max, std::max(output_min, val));
|
||||
}
|
||||
|
||||
constexpr int inputs_size = 1;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateQuantizedTensor(float_input_values.data(), input_data_quantized, input_dims,
|
||||
input_scale, input_zero_point, "input_tensor"),
|
||||
CreateQuantizedTensor(output_data, output_dims, output_scale,
|
||||
output_zero_point, "output_tensor"),
|
||||
};
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
|
||||
::tflite::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration =
|
||||
resolver.FindOp(tflite::BuiltinOperator_HARD_SWISH);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
const char* init_data = nullptr;
|
||||
size_t init_data_size = 0;
|
||||
void* user_data = nullptr;
|
||||
if (registration->init) {
|
||||
user_data = registration->init(&context, init_data, init_data_size);
|
||||
}
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.temporaries = nullptr;
|
||||
node.user_data = user_data;
|
||||
node.builtin_data = nullptr;
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
node.delegate = nullptr;
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
|
||||
if (registration->free) {
|
||||
registration->free(&context, user_data);
|
||||
}
|
||||
|
||||
AsymmetricDequantize<T>(output_data, output_elements_count, output_scale, output_zero_point, dequantized_output);
|
||||
|
||||
for (int i = 0; i < output_elements_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(float_ref_output_values.data()[i], dequantized_output[i], kTolerance);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestHardSwishQuantizedBias(float input_min, float input_max, float output_min,
|
||||
float output_max, float tolerated_bias) {
|
||||
const float quantized_type_range =
|
||||
static_cast<float>(std::numeric_limits<T>::max()) -
|
||||
static_cast<float>(std::numeric_limits<T>::min());
|
||||
|
||||
|
||||
const float input_scale = ScaleFromMinMax<T>(input_min, input_max);
|
||||
const float output_scale = ScaleFromMinMax<T>(output_min, output_max);
|
||||
|
||||
const int input_zero_point = ZeroPointFromMinMax<T>(input_min, input_max);
|
||||
const int output_zero_point = ZeroPointFromMinMax<T>(output_min, output_max);
|
||||
|
||||
const float max_scale = std::max(output_scale, input_scale);
|
||||
|
||||
// In this bias-focused test case, no need for randomly generated input
|
||||
// values.
|
||||
TF_LITE_MICRO_EXPECT_LE(input_min, -3.0f);
|
||||
TF_LITE_MICRO_EXPECT_GE(input_max, 3.0f);
|
||||
const int quantized_input_negative_three =
|
||||
std::round(std::numeric_limits<T>::min() +
|
||||
(-3.0f - input_min) / input_scale);
|
||||
const int quantized_input_positive_three =
|
||||
std::round(std::numeric_limits<T>::min() +
|
||||
(3.0f - input_min) / input_scale);
|
||||
std::vector<float> float_input_values;
|
||||
for (int i = quantized_input_negative_three;
|
||||
i <= quantized_input_positive_three; i++) {
|
||||
float_input_values.push_back(
|
||||
input_min +
|
||||
(i - std::numeric_limits<T>::min()) * input_scale);
|
||||
}
|
||||
const int size = float_input_values.size();
|
||||
std::vector<float> float_ref_output_values;
|
||||
EvalTestReferenceHardSwish(size, float_input_values,
|
||||
&float_ref_output_values);
|
||||
for (float& val : float_ref_output_values) {
|
||||
val = std::min(output_max, std::max(output_min, val));
|
||||
}
|
||||
|
||||
T output_data[size];
|
||||
T input_data_quantized[size];
|
||||
const int input_dims_data[] = {2, 1, size};
|
||||
const int output_dims_data[] = {2, 1, size};
|
||||
|
||||
// The numerical error for any 8bit quantized function is at least one half
|
||||
// times the quantization step: 0.5 * (kOutMax - kOutMin) / 256.
|
||||
// To that we add again the quantization step (kOutMax - kOutMin) / 256
|
||||
// to allow for an off-by-one rounding error.
|
||||
const float kTolerance = std::max(input_max - input_min, output_max - output_min) * (1.5f / 256.f);
|
||||
|
||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
||||
const int output_elements_count = ElementCount(*output_dims);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(output_elements_count, size);
|
||||
|
||||
float dequantized_output[output_elements_count];
|
||||
|
||||
constexpr int inputs_size = 1;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateQuantizedTensor(float_input_values.data(), input_data_quantized, input_dims,
|
||||
input_scale, input_zero_point, "input_tensor"),
|
||||
CreateQuantizedTensor(output_data, output_dims, output_scale,
|
||||
output_zero_point, "output_tensor"),
|
||||
};
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
|
||||
::tflite::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration =
|
||||
resolver.FindOp(tflite::BuiltinOperator_HARD_SWISH);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
const char* init_data = nullptr;
|
||||
size_t init_data_size = 0;
|
||||
void* user_data = nullptr;
|
||||
if (registration->init) {
|
||||
user_data = registration->init(&context, init_data, init_data_size);
|
||||
}
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.temporaries = nullptr;
|
||||
node.user_data = user_data;
|
||||
node.builtin_data = nullptr;
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
node.delegate = nullptr;
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
|
||||
if (registration->free) {
|
||||
registration->free(&context, user_data);
|
||||
}
|
||||
|
||||
AsymmetricDequantize<T>(output_data, output_elements_count, output_scale, output_zero_point, dequantized_output);
|
||||
|
||||
float sum_diff = 0;
|
||||
for (int i = 0; i < size; i++) {
|
||||
sum_diff += dequantized_output[i] - float_ref_output_values[i];
|
||||
}
|
||||
const float bias = sum_diff / (size * max_scale);
|
||||
TF_LITE_MICRO_EXPECT_LE(std::abs(bias), tolerated_bias);
|
||||
}
|
||||
|
||||
void TestHardSwishFloat(int size, std::minstd_rand* random_engine) {
|
||||
std::vector<float> float_input_values;
|
||||
const float kMin = -10.0f;
|
||||
const float kMax = 10.0f;
|
||||
GenerateUniformRandomVector(size, kMin, kMax, random_engine,
|
||||
&float_input_values);
|
||||
std::vector<float> float_ref_output_values;
|
||||
EvalTestReferenceHardSwish(size, float_input_values,
|
||||
&float_ref_output_values);
|
||||
|
||||
float output_data[size];
|
||||
const int input_dims_data[] = {1, size};
|
||||
const int output_dims_data[] = {1, size};
|
||||
|
||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
||||
const int output_elements_count = ElementCount(*output_dims);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(output_elements_count, size);
|
||||
|
||||
constexpr int inputs_size = 1;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateFloatTensor(float_input_values.data(), input_dims, "input_tensor"),
|
||||
CreateFloatTensor(output_data, output_dims, "output_tensor"),
|
||||
};
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
|
||||
::tflite::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration =
|
||||
resolver.FindOp(tflite::BuiltinOperator_HARD_SWISH);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
const char* init_data = nullptr;
|
||||
size_t init_data_size = 0;
|
||||
void* user_data = nullptr;
|
||||
if (registration->init) {
|
||||
user_data = registration->init(&context, init_data, init_data_size);
|
||||
}
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.temporaries = nullptr;
|
||||
node.user_data = user_data;
|
||||
node.builtin_data = nullptr;
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
node.delegate = nullptr;
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
}
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
|
||||
if (registration->free) {
|
||||
registration->free(&context, user_data);
|
||||
}
|
||||
|
||||
for (int i = 0; i < output_elements_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(float_ref_output_values.data()[i], output_data[i], 1e-5f);
|
||||
}
|
||||
}
|
||||
|
||||
void TestReluUint8(const int* input_dims_data, const float* input_data,
|
||||
uint8_t* input_data_quantized, const float input_scale,
|
||||
const int input_zero_point, const float* golden,
|
||||
@ -431,6 +745,46 @@ TF_LITE_MICRO_TEST(SimpleRelu6TestFloat) {
|
||||
output_data);
|
||||
}
|
||||
|
||||
// Float hard-swish across a range of tensor sizes, all drawing from one
// shared random engine.
TF_LITE_MICRO_TEST(SimpleHardSwishTestFloat) {
  std::minstd_rand random_engine;
  const int kSizes[] = {1, 2, 3, 4, 10, 20, 30, 40, 100};
  for (const int size : kSizes) {
    tflite::testing::TestHardSwishFloat(size, &random_engine);
  }
}
|
||||
|
||||
// Quantized hard-swish for every combination of input range, output range and
// tensor size below, for both int8 and uint8, all drawing from one shared
// random engine.
TF_LITE_MICRO_TEST(SimpleHardSwishTestQuantized) {
  std::minstd_rand random_engine;
  const std::vector<std::pair<float, float>> ranges{
      {0.f, 1.f}, {-2.f, 1.f}, {-5.f, 10.f}, {-40.f, 60.f}};
  const int kSizes[] = {1, 3, 10, 100};
  for (const auto& in_range : ranges) {
    for (const auto& out_range : ranges) {
      const float input_min = in_range.first;
      const float input_max = in_range.second;
      const float output_min = out_range.first;
      const float output_max = out_range.second;
      for (const int size : kSizes) {
        tflite::testing::TestHardSwishQuantized<int8_t>(
            size, input_min, input_max, output_min, output_max,
            &random_engine);
        tflite::testing::TestHardSwishQuantized<uint8_t>(
            size, input_min, input_max, output_min, output_max,
            &random_engine);
      }
    }
  }
}
|
||||
|
||||
// See the comment in the reference implementation of quantized HardSwish:
|
||||
// A numerical issue significantly affecting ImageNet classification accuracy
|
||||
// with MobileNet v3 is only observable at the scale of HardSwish unit tests
|
||||
// if we monitor specifically bias. This testcase is extracted from one of the
|
||||
// HardSwish nodes in that MobileNet v3 that exhibited this issue.
|
||||
TF_LITE_MICRO_TEST(SimpleHardSwishTestQuantizedBias) {
|
||||
tflite::testing::TestHardSwishQuantizedBias<uint8_t>(-11.654928f, 25.036512f,
|
||||
-0.3905796f, 24.50887f, 0.035);
|
||||
}
|
||||
|
||||
|
||||
TF_LITE_MICRO_TEST(SimpleReluTestUint8) {
|
||||
const int elements_count = 10;
|
||||
|
||||
|
@ -46,6 +46,7 @@ TfLiteRegistration* Register_FLOOR();
|
||||
TfLiteRegistration* Register_FULLY_CONNECTED();
|
||||
TfLiteRegistration* Register_GREATER();
|
||||
TfLiteRegistration* Register_GREATER_EQUAL();
|
||||
TfLiteRegistration* Register_HARD_SWISH();
|
||||
TfLiteRegistration* Register_LESS();
|
||||
TfLiteRegistration* Register_LESS_EQUAL();
|
||||
TfLiteRegistration* Register_LOG();
|
||||
|
@ -217,6 +217,14 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
ParseOpData);
|
||||
}
|
||||
|
||||
// Adds the HARD_SWISH builtin to this resolver, using the registration
// returned by Register_HARD_SWISH() and the generic ParseOpData parser.
// Returns whatever AddBuiltin reports.
TfLiteStatus AddHardSwish() {
  // TODO(b/149408647): Replace ParseOpData with the operator specific parse
  // function.
  auto& registration = *tflite::ops::micro::Register_HARD_SWISH();
  return AddBuiltin(BuiltinOperator_HARD_SWISH, registration, ParseOpData);
}
|
||||
|
||||
TfLiteStatus AddL2Normalization() {
|
||||
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
|
||||
// function.
|
||||
|
@ -94,6 +94,16 @@ void SymmetricDequantize(const int8_t* values, const int size,
|
||||
const float dequantization_scale,
|
||||
float* dequantized_values);
|
||||
|
||||
// Converts `size` quantized values to float:
//   dequantized = (value - dequantization_zero_point) * dequantization_scale
// and writes the results to `dequantized_values`.
template <typename T>
void AsymmetricDequantize(const T* values, const int size,
                          const float dequantization_scale,
                          int dequantization_zero_point,
                          float* dequantized_values) {
  const T* src = values;
  float* dst = dequantized_values;
  for (int remaining = size; remaining > 0; --remaining, ++src, ++dst) {
    const auto centered = *src - dequantization_zero_point;
    *dst = centered * dequantization_scale;
  }
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_MICRO_UTILS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user