/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace elementwise {
namespace {
constexpr char kAbsName[] = "Abs";
constexpr char kSinName[] = "Sin";
constexpr char kCosName[] = "Cos";
constexpr char kLogName[] = "Log";
constexpr char kSqrtName[] = "Sqrt";
constexpr char kRsqrtName[] = "Rsqrt";
constexpr char kSquareName[] = "Square";
constexpr char kNotName[] = "Not";
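// Quantization parameters computed once in GenericPrepare and reused on every
// Eval call: a fixed-point multiplier/shift pair that rescales values from
// the input scale to the output scale, plus the input/output zero points.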
struct OpData {
  int32_t multiplier;
  int32_t shift;
  int input_offset;
  int output_offset;
};

bool IsNumericSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32;
}

bool IsLogicalSupportedType(const TfLiteType type) {
  return type == kTfLiteBool;
}

bool IsAbsSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
}

bool IsRsqrtSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32 || type == kTfLiteInt8;
}
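// Abs preserves the input's magnitude, so requantization only needs to bridge
// the two scales: real_multiplier = input_scale / output_scale.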
inline void SetAbsOutputMultiplier(const float input_scale,
                                   const float output_scale,
                                   int32_t* multiplier, int32_t* shift) {
  QuantizeMultiplier(input_scale / output_scale, multiplier, shift);
}
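// For Rsqrt, the real-valued result is 1 / sqrt(input_scale * (q_in - zp)) =
// (1 / sqrt(input_scale)) * (1 / sqrt(q_in - zp)). Folding in the output
// scale gives real_multiplier = 1 / (sqrt(input_scale) * output_scale); the
// remaining 1 / sqrt(q_in - zp) factor is computed per element in Eval.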
inline void SetRsqrtOutputMultiplier(const float input_scale,
                                     const float output_scale,
                                     int32_t* multiplier, int32_t* shift) {
  const double scale = 1. / (std::sqrt(input_scale) * output_scale);
  QuantizeMultiplier(scale, multiplier, shift);
}
typedef bool (*IsSupportedType)(TfLiteType);
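// Shared Prepare function: validates the single input/output pair and the
// tensor type against the op's predicate, and for quantized int8/int16
// tensors precomputes the requantization multiplier and zero points into
// OpData.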
template <IsSupportedType is_supported_type, const char* op_name>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  if (!is_supported_type(input->type)) {
    TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
  }
  if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
    auto* op_data = static_cast<OpData*>(node->user_data);
    TF_LITE_ENSURE_EQ(context, input->quantization.type,
                      kTfLiteAffineQuantization);
    TF_LITE_ENSURE_EQ(context, output->quantization.type,
                      kTfLiteAffineQuantization);
    const auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
        input->quantization.params);
    const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
        output->quantization.params);
    TF_LITE_ENSURE(context, input_params != nullptr);
    TF_LITE_ENSURE(context, input_params->scale != nullptr);
    TF_LITE_ENSURE(context, input_params->scale->size > 0);
    TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
    TF_LITE_ENSURE(context, output_params != nullptr);
    TF_LITE_ENSURE(context, output_params->scale != nullptr);
    TF_LITE_ENSURE(context, output_params->scale->size > 0);
    TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
    op_data->input_offset = input_params->zero_point->data[0];
    op_data->output_offset = output_params->zero_point->data[0];
    // int16 uses symmetric quantization, so both zero points must be zero.
    if (input->type == kTfLiteInt16) {
      TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
      TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
    }
    const float input_scale = input_params->scale->data[0];
    const float output_scale = output_params->scale->data[0];
    if (op_name == kAbsName) {
      SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                             &op_data->shift);
    } else if (op_name == kRsqrtName) {
      SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                               &op_data->shift);
    }
  }
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}
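// Applies `func` element-wise over the flattened input, writing results to
// the same flat index of the output. If `validate_input_func` is provided, it
// may reject individual elements (e.g. negative inputs to quantized Rsqrt).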
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
                             std::function<T(T)> func,
                             std::function<TfLiteStatus(T)> validate_input_func,
                             TfLiteType expected_type) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
  const int64_t num_elements = NumElements(input);
  const T* in_data = GetTensorData<T>(input);
  T* out_data = GetTensorData<T>(output);
  for (int64_t i = 0; i < num_elements; ++i) {
    if (validate_input_func) {
      TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
    }
    out_data[i] = func(in_data[i]);
  }
  return kTfLiteOk;
}

template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
                             std::function<T(T)> func,
                             TfLiteType expected_type) {
  return EvalImpl<T>(context, node, func, /*validate_input_func=*/nullptr,
                     expected_type);
}
inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
                                float float_func(float)) {
  return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
}

inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
                                bool bool_func(bool)) {
  return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
}
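// Init/Free pair used by the quantizable ops (Abs, Rsqrt) to own the OpData
// filled in by GenericPrepare; the float- and bool-only ops register nullptr
// for both hooks instead.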
void* ElementWiseQuantizedInit(TfLiteContext* context, const char* buffer,
                               size_t length) {
  return new OpData();
}

void ElementWiseQuantizedFree(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}
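// Quantized Abs: center the value on the input zero point, take the absolute
// value, rescale with the precomputed multiplier/shift, add back the output
// zero point, and clamp to the representable range of T.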
template <typename T>
TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<T>::min();
  const int kMax = std::numeric_limits<T>::max();

  std::function<T(T)> func = [&](T i) {
    const int32_t value = std::abs(i - op_data->input_offset);
    const int32_t output = MultiplyByQuantizedMultiplier(
                               value, op_data->multiplier, op_data->shift) +
                           op_data->output_offset;
    return static_cast<T>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<T>(context, node, func, type);
}

TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteType type = GetInput(context, node, 0)->type;
  switch (type) {
    case kTfLiteFloat32:
      return EvalImpl<float>(context, node, std::abs<float>, type);
    case kTfLiteInt8:
      return AbsEvalQuantized<int8_t>(context, node, type);
    case kTfLiteInt16:
      return AbsEvalQuantized<int16_t>(context, node, type);
    default:
      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
                         TfLiteTypeGetName(type));
      return kTfLiteError;
  }
}
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::sin);
}

TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::cos);
}

TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::log);
}

TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::sqrt);
}
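// Quantized Rsqrt (int8 only): GetInvSqrtQuantizedMultiplierExp yields a
// fixed-point approximation of 1 / sqrt(q_in - zero_point). The intermediate
// result is kept left-shifted by kShift bits so that the final rescale by the
// OpData multiplier (with a matching -kShift) retains the fractional part of
// an otherwise tiny integer value.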
TfLiteStatus RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                                TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<int8_t>::min();
  const int kMax = std::numeric_limits<int8_t>::max();

  std::function<TfLiteStatus(int8_t)> validate_input_func = [&](int8_t i) {
    TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
                       "Rsqrt is only defined for positive values");
    return kTfLiteOk;
  };

  std::function<int8_t(int8_t)> func = [&](int8_t i) {
    const int32_t value = (i - op_data->input_offset);
    const int32_t kShift = 20;  // Shift to keep value integer.
    if (value == 0) {
      // Assume that any value close to 0 represents the max output value.
      return static_cast<int8_t>(kMax);
    }
    int32_t inv_sqrt_multiplier;
    int inv_sqrt_shift;
    GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
                                     &inv_sqrt_shift);
    const int32_t data = MultiplyByQuantizedMultiplier(1, inv_sqrt_multiplier,
                                                       inv_sqrt_shift + kShift);
    const int32_t output =
        MultiplyByQuantizedMultiplier(data, op_data->multiplier,
                                      op_data->shift - kShift) +
        op_data->output_offset;
    return static_cast<int8_t>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<int8_t>(context, node, func, validate_input_func, type);
}
TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteType type = GetInput(context, node, 0)->type;
  switch (type) {
    case kTfLiteFloat32:
      return EvalImpl<float>(
          context, node, [](float f) { return 1.f / std::sqrt(f); }, type);
    case kTfLiteInt8:
      return RsqrtEvalQuantized(context, node, type);
    default:
      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
                         TfLiteTypeGetName(type));
      return kTfLiteError;
  }
}

TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, [](float f) { return f * f; });
}

TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalLogical(context, node, [](bool v) { return !v; });
}
} // namespace
} // namespace elementwise
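// Registration: only Abs and Rsqrt need per-node OpData, so only they install
// the quantized init/free hooks. These factories are typically wired into the
// interpreter by the builtin op resolver, e.g. (a sketch; the exact call site
// lives outside this file):
//   AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());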
TfLiteRegistration* Register_ABS() {
  static TfLiteRegistration r = {
      elementwise::ElementWiseQuantizedInit,
      elementwise::ElementWiseQuantizedFree,
      elementwise::GenericPrepare<elementwise::IsAbsSupportedType,
                                  elementwise::kAbsName>,
      elementwise::AbsEval};
  return &r;
}

TfLiteRegistration* Register_SIN() {
  static TfLiteRegistration r = {
      /*init=*/nullptr, /*free=*/nullptr,
      elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
                                  elementwise::kSinName>,
      elementwise::SinEval};
  return &r;
}

TfLiteRegistration* Register_COS() {
  static TfLiteRegistration r = {
      /*init=*/nullptr, /*free=*/nullptr,
      elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
                                  elementwise::kCosName>,
      elementwise::CosEval};
  return &r;
}

TfLiteRegistration* Register_LOG() {
  static TfLiteRegistration r = {
      /*init=*/nullptr, /*free=*/nullptr,
      elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
                                  elementwise::kLogName>,
      elementwise::LogEval};
  return &r;
}

TfLiteRegistration* Register_SQRT() {
  static TfLiteRegistration r = {
      /*init=*/nullptr, /*free=*/nullptr,
      elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
                                  elementwise::kSqrtName>,
      elementwise::SqrtEval};
  return &r;
}

TfLiteRegistration* Register_RSQRT() {
  static TfLiteRegistration r = {
      elementwise::ElementWiseQuantizedInit,
      elementwise::ElementWiseQuantizedFree,
      elementwise::GenericPrepare<elementwise::IsRsqrtSupportedType,
                                  elementwise::kRsqrtName>,
      elementwise::RsqrtEval};
  return &r;
}

TfLiteRegistration* Register_SQUARE() {
  static TfLiteRegistration r = {
      /*init=*/nullptr, /*free=*/nullptr,
      elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
                                  elementwise::kSquareName>,
      elementwise::SquareEval};
  return &r;
}

TfLiteRegistration* Register_LOGICAL_NOT() {
  static TfLiteRegistration r = {
      /*init=*/nullptr, /*free=*/nullptr,
      elementwise::GenericPrepare<elementwise::IsLogicalSupportedType,
                                  elementwise::kNotName>,
      elementwise::LogicalNotEval};
  return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite