Add quantized int8 evaluation for squared_difference and rsqrt

PiperOrigin-RevId: 344354493
Change-Id: I8d0c73afa19944c783c35cf60518dd8b62884fca
David Rim 2020-11-25 20:03:26 -08:00 committed by TensorFlower Gardener
parent 24960e01b8
commit c7b8cbaf60
13 changed files with 526 additions and 103 deletions

View File

@ -61,6 +61,23 @@ bool IsAbsSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
}
bool IsRsqrtSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32 || type == kTfLiteInt8;
}
inline void SetAbsOutputMultiplier(const float input_scale,
const float output_scale,
int32_t* multiplier, int32_t* shift) {
QuantizeMultiplier(input_scale / output_scale, multiplier, shift);
}
inline void SetRsqrtOutputMultiplier(const float input_scale,
const float output_scale,
int32_t* multiplier, int32_t* shift) {
const double scale = 1. / (std::sqrt(input_scale) * output_scale);
QuantizeMultiplier(scale, multiplier, shift);
}
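For reference, the scale chosen here follows directly from dequantizing the input and requantizing the output. Writing q for the int8 input value, z_in/z_out for the zero points and s_in/s_out for the scales:

  real_out = 1 / sqrt(s_in * (q - z_in))
           = (1 / sqrt(q - z_in)) * (1 / sqrt(s_in))
  q_out    = real_out / s_out + z_out
           = (1 / sqrt(q - z_in)) * (1 / (sqrt(s_in) * s_out)) + z_out

so the integer kernel only needs a fixed-point approximation of 1 / sqrt(q - z_in); the remaining constant factor 1 / (sqrt(s_in) * s_out) is exactly the value passed to QuantizeMultiplier above.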
typedef bool (*IsSupportedType)(TfLiteType);
template <IsSupportedType is_supported_type, const char* op_name>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
@ -74,15 +91,6 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
if (!is_supported_type(input->type)) {
TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
}
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
TfLiteStatus AbsPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(
context, (GenericPrepare<IsAbsSupportedType, kAbsName>(context, node)),
kTfLiteOk);
const TfLiteTensor* input = GetInput(context, node, 0);
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
TfLiteTensor* output = GetOutput(context, node, 0);
auto* op_data = static_cast<OpData*>(node->user_data);
@ -110,15 +118,22 @@ TfLiteStatus AbsPrepare(TfLiteContext* context, TfLiteNode* node) {
}
const float input_scale = input_params->scale->data[0];
const float output_scale = output_params->scale->data[0];
double scale = input_scale / output_scale;
QuantizeMultiplier(scale, &op_data->multiplier, &op_data->shift);
if (op_name == kAbsName) {
SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
&op_data->shift);
} else if (op_name == kRsqrtName) {
SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
&op_data->shift);
}
}
return kTfLiteOk;
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
std::function<T(T)> func,
std::function<TfLiteStatus(T)> validate_input_func,
TfLiteType expected_type) {
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
@ -129,11 +144,22 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
const T* in_data = GetTensorData<T>(input);
T* out_data = GetTensorData<T>(output);
for (int64_t i = 0; i < num_elements; ++i) {
if (validate_input_func) {
TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
}
out_data[i] = func(in_data[i]);
}
return kTfLiteOk;
}
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
std::function<T(T)> func,
TfLiteType expected_type) {
return EvalImpl<T>(context, node, func, /*validate_input_func=*/nullptr,
expected_type);
}
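As a usage sketch (not part of this change; the identity transform and message are illustrative only), the new overload lets a caller reject individual elements before the transform runs:

  std::function<TfLiteStatus(int8_t)> reject_negative = [&](int8_t v) {
    TF_LITE_ENSURE_MSG(context, v >= 0, "negative input");
    return kTfLiteOk;
  };
  std::function<int8_t(int8_t)> identity = [](int8_t v) { return v; };
  return EvalImpl<int8_t>(context, node, identity, reject_negative, kTfLiteInt8);

RsqrtEvalQuantized below uses this hook to reject inputs that would dequantize to non-positive values.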
inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
float float_func(float)) {
return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
@ -144,11 +170,12 @@ inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
}
void* AbsInit(TfLiteContext* context, const char* buffer, size_t length) {
void* ElementWiseQuantizedInit(TfLiteContext* context, const char* buffer,
size_t length) {
return new OpData();
}
void AbsFree(TfLiteContext* context, void* buffer) {
void ElementWiseQuantizedFree(TfLiteContext* context, void* buffer) {
delete static_cast<OpData*>(buffer);
}
@ -203,8 +230,53 @@ TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, std::sqrt);
}
TfLiteStatus RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteType type) {
const auto* op_data = static_cast<const OpData*>(node->user_data);
const int kMin = std::numeric_limits<int8_t>::min();
const int kMax = std::numeric_limits<int8_t>::max();
std::function<TfLiteStatus(int8_t)> validate_input_func = [&](int8_t i) {
TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
"Rsqrt is only defined for positive values");
return kTfLiteOk;
};
std::function<int8_t(int8_t)> func = [&](int8_t i) {
const int32_t value = (i - op_data->input_offset);
const int32_t kShift = 20; // Shift to keep value integer.
if (value == 0) {
// Assume that any value close to 0 represents the max output value.
return static_cast<int8_t>(kMax);
}
int32_t inv_sqrt_multiplier;
int inv_sqrt_shift;
GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
&inv_sqrt_shift);
const int32_t data = MultiplyByQuantizedMultiplier(1, inv_sqrt_multiplier,
inv_sqrt_shift + kShift);
const int32_t output =
MultiplyByQuantizedMultiplier(data, op_data->multiplier,
op_data->shift - kShift) +
op_data->output_offset;
return static_cast<int8_t>(std::min(std::max(output, kMin), kMax));
};
return EvalImpl<int8_t>(context, node, func, validate_input_func, type);
}
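A rough numeric trace of the loop above, using the quantization parameters of the RsqrtInt8 test added below (input scale 142/255 ~= 0.557 with zero point -128, output scale 1/255 with zero point -128; input_offset and output_offset hold the tensor zero points). All values are approximate:

  q = -101                                    (a real input of about 15.0)
  value = q - (-128) = 27
  1 / sqrt(27) ~= 0.192                       (approximated by GetInvSqrtQuantizedMultiplierExp)
  data   ~= 0.192 * 2^20 ~= 201,800           (kShift = 20 keeps the intermediate integral)
  output ~= data * (1 / (sqrt(0.557) * (1/255))) / 2^20 - 128
         ~= 0.192 * 341.7 - 128 ~= -62

Dequantized, (-62 - (-128)) / 255 ~= 0.259, which agrees with the float reference 1 / sqrt(15.04) to within about one output step.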
TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
const TfLiteType type = GetInput(context, node, 0)->type;
switch (type) {
case kTfLiteFloat32:
return EvalImpl<float>(
context, node, [](float f) { return 1.f / std::sqrt(f); }, type);
case kTfLiteInt8:
return RsqrtEvalQuantized(context, node, type);
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
}
}
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
@ -219,8 +291,12 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
} // namespace elementwise
TfLiteRegistration* Register_ABS() {
static TfLiteRegistration r = {elementwise::AbsInit, elementwise::AbsFree,
elementwise::AbsPrepare, elementwise::AbsEval};
static TfLiteRegistration r = {
elementwise::ElementWiseQuantizedInit,
elementwise::ElementWiseQuantizedFree,
elementwise::GenericPrepare<elementwise::IsAbsSupportedType,
elementwise::kAbsName>,
elementwise::AbsEval};
return &r;
}
@ -262,8 +338,9 @@ TfLiteRegistration* Register_SQRT() {
TfLiteRegistration* Register_RSQRT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
elementwise::ElementWiseQuantizedInit,
elementwise::ElementWiseQuantizedFree,
elementwise::GenericPrepare<elementwise::IsRsqrtSupportedType,
elementwise::kRsqrtName>,
elementwise::RsqrtEval};
return &r;

View File

@ -225,6 +225,107 @@ TEST(ElementWise, Rsqrt) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
TEST(ElementWise, RsqrtInt8) {
std::vector<float> data = {15., 46., 78., 142., 1., 17., 49., 113.};
std::vector<float> rsqrt_data(data.size());
for (int i = 0; i < rsqrt_data.size(); i++) {
rsqrt_data[i] = 1.f / std::sqrt(data[i]);
}
float kInputScale = 142.0 / 255.0;
float kOutputScale = 1.0 / 255.0;
int32_t zero_point = -128;
ElementWiseOpQuantizedModel m(BuiltinOperator_RSQRT,
{TensorType_INT8,
{1, 8},
0,
142.0,
kInputScale,
zero_point,
true,
{kInputScale},
{zero_point}},
{TensorType_INT8,
{1, 8},
0,
1.0,
kOutputScale,
zero_point,
true,
{kOutputScale},
{zero_point}});
m.QuantizeAndPopulate<int8_t>(m.input(), data);
m.Invoke();
EXPECT_THAT(m.ExtractDequantVector<int8_t>(m.output()),
ElementsAreArray(ArrayFloatNear(rsqrt_data, kInputScale)));
}
TEST(ElementWise, RsqrtCloseTo0Int8) {
std::vector<float> data = {15., 46., 78., 142., 0.1, 1., 49., 113.};
std::vector<float> rsqrt_data(data.size());
for (int i = 0; i < rsqrt_data.size(); i++) {
rsqrt_data[i] = 1.f / std::sqrt(data[i]);
}
float kInputScale = 142.0 / 255.0;
float kOutputScale = 3.16 / 255.0;
int32_t zero_point = -128;
ElementWiseOpQuantizedModel m(BuiltinOperator_RSQRT,
{TensorType_INT8,
{1, 8},
0,
142.0,
kInputScale,
zero_point,
true,
{kInputScale},
{zero_point}},
{TensorType_INT8,
{1, 8},
0,
3.16,
kOutputScale,
zero_point,
true,
{kOutputScale},
{zero_point}});
m.QuantizeAndPopulate<int8_t>(m.input(), data);
m.Invoke();
EXPECT_THAT(m.ExtractDequantVector<int8_t>(m.output()),
ElementsAreArray(ArrayFloatNear(rsqrt_data, kInputScale)));
}
TEST(ElementWise, RsqrtNanInt8) {
std::vector<float> data = {15., 46., 78., 142., 1., 17., -49., 113.};
std::vector<float> rsqrt_data(data.size());
for (int i = 0; i < rsqrt_data.size(); i++) {
rsqrt_data[i] = 1.f / std::sqrt(data[i]);
}
float kInputScale = 142.0 / 127.0;
float kOutputScale = 1.0 / 255.0;
int32_t input_zero_point = 0;
int32_t output_zero_point = -128;
ElementWiseOpQuantizedModel m(BuiltinOperator_RSQRT,
{TensorType_INT8,
{1, 8},
0,
142.0,
kInputScale,
input_zero_point,
true,
{kInputScale},
{input_zero_point}},
{TensorType_INT8,
{1, 8},
0,
1.0,
kOutputScale,
output_zero_point,
true,
{kOutputScale},
{output_zero_point}});
m.QuantizeAndPopulate<int8_t>(m.input(), data);
EXPECT_THAT(m.InvokeUnchecked(), kTfLiteError);
}
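For orientation: in this test the input scale is 142/127 ~= 1.118 with zero point 0, so the negative entry -49. quantizes to round(-49 / 1.118) ~= -44, which lies below the zero point. The validate_input_func installed by RsqrtEvalQuantized then fails its TF_LITE_ENSURE_MSG check, and InvokeUnchecked returns kTfLiteError instead of silently producing a NaN.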
TEST(ElementWise, Square) {
ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});

View File

@ -34,55 +34,24 @@ inline void CheckArithmeticParams(const ArithmeticParams& params) {
TFLITE_DCHECK_LE(-params.input2_offset, std::numeric_limits<int8_t>::max());
}
// Element-wise add that can often be used for inner loop of broadcast add as
// well as the non-broadcast add.
inline void AddElementwise(int size, const ArithmeticParams& params,
const int8_t* input1_data, const int8_t* input2_data,
int8_t* output_data) {
inline void ElementWise(
int size, const ArithmeticParams& params, const int8_t* input1_data,
const int8_t* input2_data, int8_t* output_data,
void (*check_arithmetic_params)(const ArithmeticParams&),
int8_t (*binary_func)(int8_t, int8_t, const ArithmeticParams&)) {
CheckArithmeticParams(params);
for (int i = 0; i < size; ++i) {
const int32_t input1_val = params.input1_offset + input1_data[i];
const int32_t input2_val = params.input2_offset + input2_data[i];
const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
const int32_t scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, params.input1_multiplier, params.input1_shift);
const int32_t scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, params.input2_multiplier, params.input2_shift);
const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
const int32_t raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
raw_sum, params.output_multiplier, params.output_shift) +
params.output_offset;
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, raw_output));
output_data[i] = static_cast<int8_t>(clamped_output);
output_data[i] = binary_func(input1_data[i], input2_data[i], params);
}
}
inline void Add(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const int8_t* input1_data,
const RuntimeShape& input2_shape, const int8_t* input2_data,
const RuntimeShape& output_shape, int8_t* output_data) {
CheckArithmeticParams(params);
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
AddElementwise(flat_size, params, input1_data, input2_data, output_data);
}
inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int8_t* input1_data,
const RuntimeShape& input2_shape,
const int8_t* input2_data,
const RuntimeShape& output_shape,
int8_t* output_data) {
inline void BroadcastBinaryFunction4DSlow(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const int8_t* input1_data, const RuntimeShape& input2_shape,
const int8_t* input2_data, const RuntimeShape& output_shape,
int8_t* output_data,
void (*check_arithmetic_params)(const ArithmeticParams&),
int8_t (*binary_func)(int8_t, int8_t, const ArithmeticParams&)) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@ -105,40 +74,70 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
const int32_t input1_val =
params.input1_offset +
input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32_t input2_val =
params.input2_offset +
input2_data[SubscriptToIndex(desc2, b, y, x, c)];
const int32_t shifted_input1_val =
input1_val * (1 << params.left_shift);
const int32_t shifted_input2_val =
input2_val * (1 << params.left_shift);
const int32_t scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, params.input1_multiplier,
params.input1_shift);
const int32_t scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, params.input2_multiplier,
params.input2_shift);
const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
const int32_t raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
raw_sum, params.output_multiplier, params.output_shift) +
params.output_offset;
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, raw_output));
output_data[Offset(extended_output_shape, b, y, x, c)] =
static_cast<int8_t>(clamped_output);
output_data[Offset(extended_output_shape, b, y, x, c)] = binary_func(
input1_data[SubscriptToIndex(desc1, b, y, x, c)],
input2_data[SubscriptToIndex(desc2, b, y, x, c)], params);
}
}
}
}
}
inline int8_t AddFunc(int8_t x, int8_t y, const ArithmeticParams& params) {
const int32_t input1_val = params.input1_offset + x;
const int32_t input2_val = params.input2_offset + y;
const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
const int32_t scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, params.input1_multiplier, params.input1_shift);
const int32_t scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, params.input2_multiplier, params.input2_shift);
const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
const int32_t raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
raw_sum, params.output_multiplier, params.output_shift) +
params.output_offset;
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, raw_output));
return static_cast<int8_t>(clamped_output);
}
// Element-wise add that can often be used for inner loop of broadcast add as
// well as the non-broadcast add.
inline void AddElementwise(int size, const ArithmeticParams& params,
const int8_t* input1_data, const int8_t* input2_data,
int8_t* output_data) {
ElementWise(size, params, input1_data, input2_data, output_data,
CheckArithmeticParams, AddFunc);
}
inline void Add(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const int8_t* input1_data,
const RuntimeShape& input2_shape, const int8_t* input2_data,
const RuntimeShape& output_shape, int8_t* output_data) {
CheckArithmeticParams(params);
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
AddElementwise(flat_size, params, input1_data, input2_data, output_data);
}
inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int8_t* input1_data,
const RuntimeShape& input2_shape,
const int8_t* input2_data,
const RuntimeShape& output_shape,
int8_t* output_data) {
BroadcastBinaryFunction4DSlow(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data,
CheckArithmeticParams, AddFunc);
}
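The effect of the refactor above is that the int8 add inner loop is now a pluggable kernel: any per-element op that takes the usual ArithmeticParams can reuse ElementWise and BroadcastBinaryFunction4DSlow instead of duplicating the broadcast bookkeeping. A toy sketch of the call shape (MaxFunc is hypothetical and skips the rescaling a real op would do with params):

  inline int8_t MaxFunc(int8_t x, int8_t y, const ArithmeticParams& params) {
    return std::max(x, y);  // a real op would dequantize/requantize via params
  }
  // ...
  ElementWise(flat_size, params, input1_data, input2_data, output_data,
              CheckArithmeticParams, MaxFunc);

The squared_difference kernel later in this change plugs in its own SquaredDifference binary_func the same way.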
} // namespace reference_integer_ops
} // namespace tflite

View File

@ -236,7 +236,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
AddBuiltin(BuiltinOperator_RANK, Register_RANK());
AddBuiltin(BuiltinOperator_POW, Register_POW());
@ -261,7 +263,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LEAKY_RELU, Register_LEAKY_RELU(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_FILL, Register_FILL(),
/* min_version = */ 1,
/* max_version = */ 2);

View File

@ -397,7 +397,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
AddBuiltin(BuiltinOperator_RANK, Register_RANK());
AddBuiltin(BuiltinOperator_POW, Register_POW());
@ -424,7 +426,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_LEAKY_RELU, Register_LEAKY_RELU(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_FILL, Register_FILL(),
/* min_version = */ 1,
/* max_version = */ 2);

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@ -36,6 +37,7 @@ constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
ArithmeticParams arithmetic_params;
};
template <typename T>
@ -73,6 +75,60 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
// For int8 inputs, validate the zero points and precompute the fixed-point
// rescaling parameters.
if (input1->type == kTfLiteInt8) {
const auto& input1_quantization_params = input1->params;
const auto& input2_quantization_params = input2->params;
const auto& output_quantization_params = output->params;
const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
TF_LITE_ENSURE(context,
input1_quantization_params.zero_point >= integer_type_min);
TF_LITE_ENSURE(context,
input1_quantization_params.zero_point <= integer_type_max);
TF_LITE_ENSURE(context,
input2_quantization_params.zero_point >= integer_type_min);
TF_LITE_ENSURE(context,
input2_quantization_params.zero_point <= integer_type_max);
TF_LITE_ENSURE(context,
output_quantization_params.zero_point >= integer_type_min);
TF_LITE_ENSURE(context,
output_quantization_params.zero_point <= integer_type_max);
data->arithmetic_params.input1_offset =
-input1_quantization_params.zero_point;
data->arithmetic_params.input2_offset =
-input2_quantization_params.zero_point;
data->arithmetic_params.output_offset =
output_quantization_params.zero_point;
// Left shift the inputs by 7 bits so that rescaling by multipliers smaller
// than one keeps enough precision.
data->arithmetic_params.left_shift = 7;
const double twice_max_input_scale =
2 * std::max(input1_quantization_params.scale,
input2_quantization_params.scale);
const double real_input1_multiplier =
input1_quantization_params.scale / twice_max_input_scale;
double real_input2_multiplier =
input2_quantization_params.scale / twice_max_input_scale;
const double real_output_multiplier =
(twice_max_input_scale * twice_max_input_scale) /
((1 << data->arithmetic_params.left_shift * 2) *
output_quantization_params.scale);
tflite::QuantizeMultiplierSmallerThanOneExp(
real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
&data->arithmetic_params.input1_shift);
tflite::QuantizeMultiplierSmallerThanOneExp(
real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
&data->arithmetic_params.input2_shift);
tflite::QuantizeMultiplierSmallerThanOneExp(
real_output_multiplier, &data->arithmetic_params.output_multiplier,
&data->arithmetic_params.output_shift);
data->arithmetic_params.quantized_activation_min =
std::numeric_limits<int8_t>::min();
data->arithmetic_params.quantized_activation_max =
std::numeric_limits<int8_t>::max();
}
data->requires_broadcast = !HaveSameShapes(input1, input2);
TfLiteIntArray* output_size = nullptr;
@ -86,6 +142,55 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, output_size);
}
inline int8_t SquaredDifference(int8_t x, int8_t y,
const ArithmeticParams& params) {
const int32_t input1_val = params.input1_offset + x;
const int32_t input2_val = params.input2_offset + y;
const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
const int32_t scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, params.input1_multiplier, params.input1_shift);
const int32_t scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, params.input2_multiplier, params.input2_shift);
const int32_t raw_diff = scaled_input1_val - scaled_input2_val;
// Max of this is 255^2 * (1 << 14), so won't overflow 32 bits.
const int32_t squared_raw_diff = raw_diff * raw_diff;
const int32_t raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
squared_raw_diff, params.output_multiplier, params.output_shift) +
params.output_offset;
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, raw_output));
return static_cast<int8_t>(clamped_output);
}
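For reference, the scaling above works out as follows (s1, s2, s_out are the tensor scales, s_max = max(s1, s2), and left_shift is the 7 set in Prepare). Each input is first moved onto a common grid:

  scaled_input_i ~= real_i * 2^7 / (2 * s_max)

so

  squared_raw_diff ~= (real_1 - real_2)^2 * 2^14 / (2 * s_max)^2

and multiplying by real_output_multiplier = (2 * s_max)^2 / (2^14 * s_out) recovers (real_1 - real_2)^2 / s_out, the quantized output before the offset and clamp. The overflow comment holds because |raw_diff| <= 255 * 2^7, so squared_raw_diff <= 255^2 * 2^14 < 2^31.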
template <typename T>
void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node,
const OpData* data,
const TfLiteTensor* input1,
const TfLiteTensor* input2,
TfLiteTensor* output) {
const auto* op_data = static_cast<const OpData*>(node->user_data);
if (data->requires_broadcast) {
reference_integer_ops::BroadcastBinaryFunction4DSlow(
op_data->arithmetic_params, GetTensorShape(input1),
GetTensorData<T>(input1), GetTensorShape(input2),
GetTensorData<T>(input2), GetTensorShape(output),
GetTensorData<T>(output), reference_integer_ops::CheckArithmeticParams,
SquaredDifference);
} else {
const int flat_size = GetTensorShape(input1).FlatSize();
reference_integer_ops::ElementWise(
flat_size, op_data->arithmetic_params, GetTensorData<int8_t>(input1),
GetTensorData<int8_t>(input2), GetTensorData<int8_t>(output),
reference_integer_ops::CheckArithmeticParams, SquaredDifference);
}
}
template <typename T>
void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node,
const OpData* data, const TfLiteTensor* input1,
@ -121,6 +226,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
EvalSquaredDifference<float>(context, node, data, input1, input2, output);
} else if (output->type == kTfLiteInt32) {
EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output);
} else if (output->type == kTfLiteInt8) {
EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2,
output);
} else {
context->ReportError(
context,

View File

@ -64,6 +64,22 @@ class IntegerSquaredDifferenceOpModel : public BaseSquaredDifferenceOpModel {
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
};
float GetTolerance(int min, int max) {
float kQuantizedStep = (max - min) / 255.0;
return kQuantizedStep;
}
class QuantizedSquaredDifferenceOpModel : public BaseSquaredDifferenceOpModel {
public:
using BaseSquaredDifferenceOpModel::BaseSquaredDifferenceOpModel;
template <typename integer_dtype>
std::vector<float> GetDequantizedOutput() {
return Dequantize<integer_dtype>(ExtractVector<integer_dtype>(output_),
GetScale(output_), GetZeroPoint(output_));
}
};
TEST(FloatSquaredDifferenceOpTest, FloatType_SameShape) {
FloatSquaredDifferenceOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@ -151,5 +167,57 @@ TEST(IntegerSquaredDifferenceOpTest, IntegerType_WithBroadcast) {
}
}
TEST(QuantizedSquaredDifferenceOpTest, Quantized_SameShape) {
float kQuantizedTolerance = GetTolerance(0, 1);
QuantizedSquaredDifferenceOpModel m(
{TensorType_INT8, {1, 2, 2, 1}, -1.2, 0.8},
{TensorType_INT8, {1, 2, 2, 1}, -1.5, 0.5},
{TensorType_INT8, {}, 0.0, 0.5});
m.QuantizeAndPopulate<int8_t>(m.input1(), {-0.2, 0.2, -1.2, 0.8});
m.QuantizeAndPopulate<int8_t>(m.input2(), {0.5, 0.2, -1.5, 0.5});
m.Invoke();
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({0.49, 0.0, 0.09, 0.09},
kQuantizedTolerance)));
}
TEST(QuantizedSquaredDifferenceOpTest, Quantized_VariousInputShapes) {
float kQuantizedTolerance = GetTolerance(0, 9);
std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
QuantizedSquaredDifferenceOpModel m(
{TensorType_INT8, test_shapes[i], -2.0, 1.7},
{TensorType_INT8, test_shapes[i], -1.0, 1.0},
{TensorType_INT8, {}, 0.0, 9.0});
m.QuantizeAndPopulate<int8_t>(m.input1(), {-2.0, 0.2, 0.3, 0.8, 1.1, -2.0});
m.QuantizeAndPopulate<int8_t>(m.input2(), {1.0, 0.2, 0.6, 0.4, -1.0, -0.0});
m.Invoke();
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear(
{9.0, 0.0, 0.09, 0.16, 4.41, 4.0}, kQuantizedTolerance)))
<< "With shape number " << i;
}
}
TEST(QuantizedSquaredDifferenceOpTest, Quantized_WithBroadcast) {
std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
float kQuantizedTolerance = GetTolerance(0, 1);
for (int i = 0; i < test_shapes.size(); ++i) {
QuantizedSquaredDifferenceOpModel m(
{TensorType_INT8, test_shapes[i], -0.2, 1.1},
{TensorType_INT8, {}, 0.0, 0.1}, {TensorType_INT8, {}, 0.0, 1.0});
m.QuantizeAndPopulate<int8_t>(m.input1(), {-0.2, 0.2, 0.5, 0.8, 0.11, 1.1});
m.QuantizeAndPopulate<int8_t>(m.input2(), {0.1});
m.Invoke();
EXPECT_THAT(
m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({0.09, 0.01, 0.16, 0.49, 0.0001, 1.0},
kQuantizedTolerance)))
<< "With shape number " << i;
}
}
} // namespace
} // namespace tflite

View File

@ -318,4 +318,5 @@ def make_floor_mod_tests(options):
@register_make_test_function()
def make_squared_difference_tests(options):
make_binary_op_tests(options, tf.math.squared_difference)
make_binary_op_tests(options, tf.math.squared_difference,
allow_fully_quantize=True)

View File

@ -23,15 +23,32 @@ from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
def _make_elementwise_tests(op):
def _make_elementwise_tests(op, allow_fully_quantize=False, min_value=-100,
max_value=100):
"""Make a set of tests to do element-wise operations."""
def f(options):
"""Actual function that generates examples."""
test_parameters = [{
"input_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
test_parameters = [
{
"input_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
"fully_quantize": [False],
"input_range": [[min_value, max_value]],
},
{
"input_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
"fully_quantize": [True],
"input_range": [[min_value, max_value]],
},
]
if not allow_fully_quantize:
test_parameters = [
test_parameter for test_parameter in test_parameters
if True not in test_parameter["fully_quantize"]
]
def build_graph(parameters):
"""Build the unary op testing graph."""
@ -44,7 +61,9 @@ def _make_elementwise_tests(op):
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
parameters["input_shape"])
parameters["input_shape"],
min_value=min_value,
max_value=max_value)
return [input_value], sess.run(
outputs, feed_dict={inputs[0]: input_value})
@ -74,7 +93,8 @@ def make_sqrt_tests(options):
@register_make_test_function()
def make_rsqrt_tests(options):
"""Make a set of tests to do 1/sqrt."""
return _make_elementwise_tests(tf.math.rsqrt)(options)
return _make_elementwise_tests(tf.math.rsqrt, allow_fully_quantize=True,
min_value=.1, max_value=1)(options)
@register_make_test_function()

View File

@ -77,6 +77,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
OperatorProperty property;
switch (op_code) {
case BuiltinOperator_ABS:
case BuiltinOperator_RSQRT:
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
property.version = 2;
@ -921,6 +922,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
property.restrict_same_input_output_scale = true;
property.version = 2;
break;
case BuiltinOperator_SQUARED_DIFFERENCE:
case BuiltinOperator_SUB:
property.inputs = {{0, {}}, {1, {}}};
property.outputs = {{0, {}}};

View File

@ -618,6 +618,8 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL:
case BuiltinOperator_SELECT:
case BuiltinOperator_RSQRT:
case BuiltinOperator_SQUARED_DIFFERENCE:
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}

View File

@ -825,4 +825,39 @@ TEST(OpVersionTest, VersioningBatchMatMulTest) {
fake_op_sig.options.input_quantization.asymmetric_quantize_inputs = true;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
}
TEST(OpVersionTest, VersioningSquaredDifferenceTest) {
// Default.
OpSignature fake_op_sig = {
.op = BuiltinOperator_SQUARED_DIFFERENCE,
.input_types =
std::vector<TensorType>{TensorType_FLOAT32, TensorType_FLOAT32},
.output_types = std::vector<TensorType>{TensorType_FLOAT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
// int8 input is version 2.
fake_op_sig = {
.op = BuiltinOperator_SQUARED_DIFFERENCE,
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
.output_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
}
TEST(OpVersionTest, VersioningRsqrtTest) {
// Default.
OpSignature fake_op_sig = {
.op = BuiltinOperator_RSQRT,
.input_types = std::vector<TensorType>{TensorType_FLOAT32},
.output_types = std::vector<TensorType>{TensorType_FLOAT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
// int8 input is version 2.
fake_op_sig = {
.op = BuiltinOperator_RSQRT,
.input_types = std::vector<TensorType>{TensorType_INT8},
.output_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
}
} // namespace tflite

View File

@ -258,6 +258,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_LOG_SOFTMAX, 2}, "1.14.0"},
{{BuiltinOperator_LSH_PROJECTION, 1}, "1.5.0"},
{{BuiltinOperator_SQUARED_DIFFERENCE, 1}, "1.13.1"},
{{BuiltinOperator_SQUARED_DIFFERENCE, 2}, kPendingReleaseVersion},
{{BuiltinOperator_MIRROR_PAD, 1}, "1.13.1"},
{{BuiltinOperator_MIRROR_PAD, 2}, "2.3.0"},
{{BuiltinOperator_UNIQUE, 1}, "1.14.0"},
@ -318,6 +319,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_LOG, 1}, "1.14.0"},
{{BuiltinOperator_SQRT, 1}, "1.10.0"},
{{BuiltinOperator_RSQRT, 1}, "1.10.0"},
{{BuiltinOperator_RSQRT, 2}, kPendingReleaseVersion},
{{BuiltinOperator_SQUARE, 1}, "1.12.0"},
{{BuiltinOperator_ZEROS_LIKE, 1}, "1.12.0"},
{{BuiltinOperator_ABS, 1}, "1.13.0"},