Merge pull request #27094 from amitsrivastava78:q_range
PiperOrigin-RevId: 246038739
commit 088c147444
@@ -387,6 +387,32 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
   }
 }

+namespace {
+template <typename T>
+void QuantizedRelu1(const TfLiteTensor* input, TfLiteTensor* output) {
+  ActivationParams params;
+  int32 kMin = -1;
+  int32 kMax = 1;
+  params.activation_type = FusedActivationFunctionType::kRelu1;
+
+  // Relu1 has a min range of -1; we need to quantize this bound.
+  params.quantized_activation_min =
+      std::max(static_cast<int32_t>(std::numeric_limits<T>::min()),
+               output->params.zero_point +
+                   static_cast<int32>(roundf(kMin / output->params.scale)));
+
+  // Relu1 has a max range of 1; we need to quantize this bound.
+  params.quantized_activation_max =
+      std::min(static_cast<int32_t>(std::numeric_limits<T>::max()),
+               output->params.zero_point +
+                   static_cast<int32>(roundf(kMax / output->params.scale)));
+
+  // Reuse the optimized function written for ReluX.
+  optimized_ops::ReluX(params, GetTensorShape(input), GetTensorData<T>(input),
+                       GetTensorShape(output), GetTensorData<T>(output));
+}
+}  // namespace
+
 TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
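For concreteness, the two bound computations above can be traced with a hypothetical int8 output tensor quantized at scale = 1/128 with zero_point = 0; these parameters are illustrative assumptions, not values from the diff. A minimal standalone sketch:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <limits>

    // Traces QuantizedRelu1's bound computation with assumed
    // quantization parameters: scale = 1/128, zero_point = 0.
    int main() {
      const float scale = 1.0f / 128.0f;
      const int32_t zero_point = 0;

      // quantize(-1.0): zero_point + round(-1 / scale) = 0 - 128 = -128,
      // then saturate to int8's representable minimum.
      const int32_t qmin = std::max(
          static_cast<int32_t>(std::numeric_limits<int8_t>::min()),
          zero_point + static_cast<int32_t>(std::round(-1.0f / scale)));

      // quantize(+1.0): 0 + 128 = 128, saturated down to int8's maximum,
      // so the upper clamp becomes the largest representable value.
      const int32_t qmax = std::min(
          static_cast<int32_t>(std::numeric_limits<int8_t>::max()),
          zero_point + static_cast<int32_t>(std::round(1.0f / scale)));

      // Prints qmin=-128 qmax=127.
      printf("qmin=%d qmax=%d\n", static_cast<int>(qmin),
             static_cast<int>(qmax));
      return 0;
    }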
@@ -397,9 +423,18 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
                            GetTensorData<float>(output));
       return kTfLiteOk;
     } break;
+    case kTfLiteUInt8: {
+      QuantizedRelu1<uint8_t>(input, output);
+      return kTfLiteOk;
+    } break;
+    case kTfLiteInt8: {
+      QuantizedRelu1<int8_t>(input, output);
+      return kTfLiteOk;
+    } break;
     default:
       context->ReportError(context,
-                           "Only float32 is supported currently, got %s.",
+                           "Only float32, uint8, int8 supported "
+                           "currently, got %s.",
                            TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
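Semantically, the optimized_ops::ReluX call that both new cases reach amounts to an elementwise clamp in the quantized domain. The scalar loop below is only a reference sketch of that behavior, not TFLite's optimized kernel; ReferenceReluX and its signature are invented here for illustration.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Reference sketch: clamp every quantized element to
    // [quantized_activation_min, quantized_activation_max].
    template <typename T>
    void ReferenceReluX(int32_t qmin, int32_t qmax,
                        const std::vector<T>& input, std::vector<T>* output) {
      output->resize(input.size());
      for (std::size_t i = 0; i < input.size(); ++i) {
        const int32_t v = static_cast<int32_t>(input[i]);
        (*output)[i] = static_cast<T>(std::min(qmax, std::max(qmin, v)));
      }
    }

Note that a pure clamp matches float Relu1 only when input and output share the same scale and zero point; the tests below arrange exactly that by declaring identical ranges for both tensors.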
@@ -255,6 +255,53 @@ TEST(QuantizedActivationsOpTest, LeakyReluUint8) {
                      112,
                  }));
 }

+TEST(QuantizedActivationsOpTest, Relu1Int8) {
+  const float kMin = -1;
+  const float kMax = 1;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_RELU_N1_TO_1,
+      /*input=*/{TensorType_INT8, {1, 2, 4, 1}, 2 * kMin, kMax},
+      /*output=*/{TensorType_INT8, {1, 2, 4, 1}, 2 * kMin, kMax});
+
+  m.SetInput<int8_t>({
+      0.0, -0.6, 0.2, -0.4,  //
+      0.3, -2.0, 1.1, -0.1,  //
+  });
+  m.Invoke();
+
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0.0, -0.6, 0.2, -0.4,  //
+                      0.3, -1.0, 1.0, -0.1,  //
+                  },
+                  kQuantizedTolerance)));
+}
+
+TEST(QuantizedActivationsOpTest, Relu1UInt8) {
+  const float kMin = -1;
+  const float kMax = 1;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_RELU_N1_TO_1,
+      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 2 * kMin, kMax},
+      /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, 2 * kMin, kMax});
+
+  m.SetInput<uint8_t>({
+      0.0, -0.6, 0.2, -0.4,  //
+      0.3, -2.0, 1.1, -0.1,  //
+  });
+  m.Invoke();
+
+  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0.0, -0.6, 0.2, -0.4,  //
+                      0.3, -1.0, 1.0, -0.1,  //
+                  },
+                  kQuantizedTolerance)));
+}
+
 TEST(QuantizedActivationsOpTest, Relu6Int8) {
   const float kMin = -1;
   const float kMax = 127.f / 128.f;
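As a sanity check on the expected rows (the out-of-range inputs -2.0 and 1.1 clamp to -1.0 and 1.0): with min = 2 * kMin = -2 and max = kMax = 1, the usual affine rule scale = (max - min) / (qmax - qmin) gives 3/255 for int8 and a zero point of 42. That derivation assumes the standard quantization used by TFLite's test helpers; the diff itself does not show it. Dequantizing Relu1's quantized bounds then recovers the float clamp [-1, 1] exactly:

    #include <cmath>
    #include <cstdio>

    // Assumed affine quantization for the int8 test range [-2, 1]:
    // scale = (max - min) / (qmax - qmin), zero_point = qmin - min / scale.
    int main() {
      const float scale = (1.0f - (-2.0f)) / (127 - (-128));  // 3/255
      const int zero_point =
          static_cast<int>(std::round(-128 - (-2.0f) / scale));  // 42
      // Relu1's clamp bounds in the quantized domain (both already
      // within int8 range here, so the saturation step is a no-op):
      const int qmin =
          zero_point + static_cast<int>(std::round(-1.0f / scale));  // -43
      const int qmax =
          zero_point + static_cast<int>(std::round(1.0f / scale));   // 127
      // Dequantize the bounds: (q - zero_point) * scale.
      printf("%.3f %.3f\n", (qmin - zero_point) * scale,   // -1.000
             (qmax - zero_point) * scale);                 //  1.000
      return 0;
    }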