Added 8-bit Quantization support for RELU_N1_TO_1.
Added the quantization support for the operator.
This commit is contained in:
parent
266c5ab878
commit
4c22d494d0
@ -387,6 +387,32 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
void QuantizedRelu1(const TfLiteTensor* input, TfLiteTensor* output) {
|
||||||
|
ActivationParams params;
|
||||||
|
int32 kMin = -1;
|
||||||
|
int32 kMax = 1;
|
||||||
|
params.activation_type = FusedActivationFunctionType::kRelu1;
|
||||||
|
|
||||||
|
// Relu1 has a min range of -1, we need to quantize this
|
||||||
|
params.quantized_activation_min =
|
||||||
|
std::max(static_cast<int32_t>(std::numeric_limits<T>::min()),
|
||||||
|
output->params.zero_point +
|
||||||
|
static_cast<int32>(roundf(kMin / output->params.scale)));
|
||||||
|
|
||||||
|
// Relu1 has a max range of 1, we need to quantize this
|
||||||
|
params.quantized_activation_max =
|
||||||
|
std::min(static_cast<int32_t>(std::numeric_limits<T>::max()),
|
||||||
|
output->params.zero_point +
|
||||||
|
static_cast<int32>(roundf(kMax / output->params.scale)));
|
||||||
|
|
||||||
|
// Reused the optimized function written for ReluX
|
||||||
|
optimized_ops::ReluX(params, GetTensorShape(input), GetTensorData<T>(input),
|
||||||
|
GetTensorShape(output), GetTensorData<T>(output));
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||||
@ -397,9 +423,18 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetTensorData<float>(output));
|
GetTensorData<float>(output));
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
} break;
|
} break;
|
||||||
|
case kTfLiteUInt8: {
|
||||||
|
QuantizedRelu1<uint8_t>(input, output);
|
||||||
|
return kTfLiteOk;
|
||||||
|
} break;
|
||||||
|
case kTfLiteInt8: {
|
||||||
|
QuantizedRelu1<int8_t>(input, output);
|
||||||
|
return kTfLiteOk;
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
context->ReportError(context,
|
context->ReportError(context,
|
||||||
"Only float32 is supported currently, got %s.",
|
"Only float32, uint8, int8 supported "
|
||||||
|
"currently, got %s.",
|
||||||
TfLiteTypeGetName(input->type));
|
TfLiteTypeGetName(input->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
@ -255,6 +255,53 @@ TEST(QuantizedActivationsOpTest, LeakyReluUint8) {
|
|||||||
112,
|
112,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(QuantizedActivationsOpTest, Relu1Int8) {
|
||||||
|
const float kMin = -1;
|
||||||
|
const float kMax = 1;
|
||||||
|
QuantizedActivationsOpModel m(
|
||||||
|
BuiltinOperator_RELU_N1_TO_1,
|
||||||
|
/*input=*/{TensorType_INT8, {1, 2, 4, 1}, 2 * kMin, kMax},
|
||||||
|
/*output=*/{TensorType_INT8, {1, 2, 4, 1}, 2 * kMin, kMax});
|
||||||
|
|
||||||
|
m.SetInput<int8_t>({
|
||||||
|
0.0, -0.6, 0.2, -0.4, //
|
||||||
|
0.3, -2.0, 1.1, -0.1, //
|
||||||
|
});
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
{
|
||||||
|
0.0, -0.6, 0.2, -0.4, //
|
||||||
|
0.3, -1.0, 1.0, -0.1, //
|
||||||
|
},
|
||||||
|
kQuantizedTolerance)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(QuantizedActivationsOpTest, Relu1UInt8) {
|
||||||
|
const float kMin = -1;
|
||||||
|
const float kMax = 1;
|
||||||
|
QuantizedActivationsOpModel m(
|
||||||
|
BuiltinOperator_RELU_N1_TO_1,
|
||||||
|
/*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 2 * kMin, kMax},
|
||||||
|
/*output=*/{TensorType_UINT8, {1, 2, 4, 1}, 2 * kMin, kMax});
|
||||||
|
|
||||||
|
m.SetInput<uint8_t>({
|
||||||
|
0.0, -0.6, 0.2, -0.4, //
|
||||||
|
0.3, -2.0, 1.1, -0.1, //
|
||||||
|
});
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
{
|
||||||
|
0.0, -0.6, 0.2, -0.4, //
|
||||||
|
0.3, -1.0, 1.0, -0.1, //
|
||||||
|
},
|
||||||
|
kQuantizedTolerance)));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(QuantizedActivationsOpTest, Relu6Int8) {
|
TEST(QuantizedActivationsOpTest, Relu6Int8) {
|
||||||
const float kMin = -1;
|
const float kMin = -1;
|
||||||
const float kMax = 127.f / 128.f;
|
const float kMax = 127.f / 128.f;
|
||||||
|
Loading…
Reference in New Issue
Block a user