Quantized abs uses the same input and output scale for compatibility with MLIR.
PiperOrigin-RevId: 355556020 Change-Id: I29a2cbc5f3cfb45721907354287b086703a0314b
This commit is contained in:
parent
a9e9b8aae2
commit
57db297c71
@ -531,6 +531,7 @@ def TFL_AbsOp : TFL_Op<"abs", [
|
|||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
SameOperandsAndResultShape,
|
SameOperandsAndResultShape,
|
||||||
SameOperandsAndResultType,
|
SameOperandsAndResultType,
|
||||||
|
SameOperandsAndResultsScale,
|
||||||
NoQuantizableResult]> {
|
NoQuantizableResult]> {
|
||||||
let summary = "Absolute value operator";
|
let summary = "Absolute value operator";
|
||||||
|
|
||||||
|
@ -47,6 +47,7 @@ struct OpData {
|
|||||||
int32_t shift;
|
int32_t shift;
|
||||||
int input_offset;
|
int input_offset;
|
||||||
int output_offset;
|
int output_offset;
|
||||||
|
bool needs_rescale;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool IsNumericSupportedType(const TfLiteType type) {
|
bool IsNumericSupportedType(const TfLiteType type) {
|
||||||
@ -118,7 +119,8 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
const float input_scale = input_params->scale->data[0];
|
const float input_scale = input_params->scale->data[0];
|
||||||
const float output_scale = output_params->scale->data[0];
|
const float output_scale = output_params->scale->data[0];
|
||||||
if (op_name == kAbsName) {
|
op_data->needs_rescale = input_scale != output_scale;
|
||||||
|
if (op_name == kAbsName && op_data->needs_rescale) {
|
||||||
SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
|
SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
|
||||||
&op_data->shift);
|
&op_data->shift);
|
||||||
} else if (op_name == kRsqrtName) {
|
} else if (op_name == kRsqrtName) {
|
||||||
@ -188,10 +190,13 @@ TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
|||||||
|
|
||||||
std::function<T(T)> func = [&](T i) {
|
std::function<T(T)> func = [&](T i) {
|
||||||
const int32_t value = std::abs(i - op_data->input_offset);
|
const int32_t value = std::abs(i - op_data->input_offset);
|
||||||
|
if (!op_data->needs_rescale) {
|
||||||
|
return static_cast<T>(
|
||||||
|
std::min(std::max(value + op_data->output_offset, kMin), kMax));
|
||||||
|
}
|
||||||
const int32_t output = MultiplyByQuantizedMultiplier(
|
const int32_t output = MultiplyByQuantizedMultiplier(
|
||||||
value, op_data->multiplier, op_data->shift) +
|
value, op_data->multiplier, op_data->shift) +
|
||||||
op_data->output_offset;
|
op_data->output_offset;
|
||||||
|
|
||||||
return static_cast<T>(std::min(std::max(output, kMin), kMax));
|
return static_cast<T>(std::min(std::max(output, kMin), kMax));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -191,6 +191,34 @@ TEST(ElementWise, AbsInt8) {
|
|||||||
ElementsAreArray(ArrayFloatNear(abs_data, kInputScale)));
|
ElementsAreArray(ArrayFloatNear(abs_data, kInputScale)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ElementWise, AbsSameScaleInt8) {
|
||||||
|
std::vector<float> data = {15., 46., 78., -142., -1., -17., -49., 113.};
|
||||||
|
std::vector<float> abs_data(data.size());
|
||||||
|
for (int i = 0; i < abs_data.size(); i++) {
|
||||||
|
abs_data[i] = std::abs(data[i]);
|
||||||
|
}
|
||||||
|
const auto minmax = std::minmax_element(data.begin(), data.end());
|
||||||
|
const float abs_max = std::max(std::abs(*minmax.first), *minmax.second);
|
||||||
|
const float kInputScale = (*minmax.second - *minmax.first) / 255.0;
|
||||||
|
const int input_zero_point = 127 - *minmax.second;
|
||||||
|
ElementWiseOpQuantizedModel m(
|
||||||
|
BuiltinOperator_ABS,
|
||||||
|
{TensorType_INT8,
|
||||||
|
{1, 8},
|
||||||
|
*minmax.first,
|
||||||
|
*minmax.second,
|
||||||
|
kInputScale,
|
||||||
|
input_zero_point,
|
||||||
|
true,
|
||||||
|
{kInputScale},
|
||||||
|
{input_zero_point}},
|
||||||
|
{TensorType_INT8, {1, 8}, 0, abs_max, kInputScale, input_zero_point});
|
||||||
|
m.AsymmetricQuantizeAndPopulate<int8_t>(m.input(), data);
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.ExtractDequantVector<int8_t>(m.output()),
|
||||||
|
ElementsAreArray(ArrayFloatNear(abs_data, kInputScale)));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ElementWise, AbsInt16) {
|
TEST(ElementWise, AbsInt16) {
|
||||||
const float kQuantizedTolerance = GetQuantizationStep<int16_t>(-150, 150);
|
const float kQuantizedTolerance = GetQuantizationStep<int16_t>(-150, 150);
|
||||||
std::vector<float> data = {15., 46., 78., -142., -1., -17., -49., 113.};
|
std::vector<float> data = {15., 46., 78., -142., -1., -17., -49., 113.};
|
||||||
|
@ -66,6 +66,11 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
|
|||||||
OperatorProperty property;
|
OperatorProperty property;
|
||||||
switch (op_code) {
|
switch (op_code) {
|
||||||
case BuiltinOperator_ABS:
|
case BuiltinOperator_ABS:
|
||||||
|
property.inputs = {{0, {}}};
|
||||||
|
property.outputs = {{0, {}}};
|
||||||
|
property.version = 2;
|
||||||
|
property.restrict_same_input_output_scale = true;
|
||||||
|
break;
|
||||||
case BuiltinOperator_RSQRT:
|
case BuiltinOperator_RSQRT:
|
||||||
property.inputs = {{0, {}}};
|
property.inputs = {{0, {}}};
|
||||||
property.outputs = {{0, {}}};
|
property.outputs = {{0, {}}};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user