Merge pull request #43617 from Tessil:toupstream/16x8_relu_relu_6

PiperOrigin-RevId: 340486126
Change-Id: I9ef5fcdfedfb76e24ecb7447ecf5a0cb1cd49371

commit 14b2d17031
@@ -46,7 +46,6 @@ const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
 #test-id,min-android-sdk-version
 
 # activations_test
-QuantizedActivationsOpTest/Relu6Uint8
 FloatActivationsOpTest/Softmax[13]D,29
 QuantizedActivationsOpTest/Softmax[13]D.+nt8,29
 FloatActivationsOpTest/Softmax\dD
@@ -60,7 +59,7 @@ FloatActivationsOpTest/Elu,30
 FloatActivationsOpTest/HardSwish
 QuantizedActivationsOpTest/HardSwish
 QuantizedActivationsOpTest/HardSwishBias
-QuantizedActivationsOpTest/Relu*
+QuantizedActivationsOpTest/Relu.+nt8
 QuantizedActivationsOpTest/PRelu,29
 QuantizedActivationsOpTest/PReluSameShapes,29
 QuantizedActivationsOpTest/PReluInt8.+,30
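Assuming the acceleration-list entries are matched as regular expressions against test names (as the config's wildcard entries suggest), the narrowed pattern keeps the uint8/int8 Relu tests on NNAPI while leaving the new int16 tests to the CPU path, presumably because NNAPI does not accelerate 16-bit relu. A standalone illustration of what the new pattern matches:

#include <iostream>
#include <regex>

int main() {
  const std::regex pattern("Relu.+nt8");
  for (const char* name :
       {"QuantizedActivationsOpTest/ReluUint8",
        "QuantizedActivationsOpTest/ReluInt8",
        "QuantizedActivationsOpTest/Relu6Uint8",
        "QuantizedActivationsOpTest/ReluInt16"}) {
    std::cout << name << " -> "
              << (std::regex_search(name, pattern) ? "accelerated" : "skipped")
              << "\n";  // only the ReluInt16 variants are skipped
  }
}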
@@ -280,12 +280,18 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
 
-  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
+  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8 ||
+      input->type == kTfLiteInt16) {
     double real_multiplier = input->params.scale / output->params.scale;
     QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                        &data->output_shift);
   }
+
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+  }
+
   return context->ResizeTensor(context, output,
                                TfLiteIntArrayCopy(input->dims));
 }
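For context on the requantization step above: QuantizeMultiplier decomposes the real-valued scale ratio into a fixed-point multiplier and a power-of-two shift, so the eval kernels can rescale with integer-only arithmetic. A minimal sketch of that decomposition, assuming the conventional frexp-based approach rather than quoting TFLite's exact implementation:

#include <cmath>
#include <cstdint>

// Sketch (not TFLite's exact code): write real_multiplier as q * 2^shift,
// with q stored as a Q31 fixed-point integer. For ReluPrepare the ratio
// input_scale / output_scale is always positive.
void QuantizeMultiplierSketch(double real_multiplier, int32_t* quantized,
                              int* shift) {
  if (real_multiplier == 0.0) {
    *quantized = 0;
    *shift = 0;
    return;
  }
  const double q = std::frexp(real_multiplier, shift);  // q in [0.5, 1)
  int64_t q31 = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q31 == (1LL << 31)) {  // q rounded up to exactly 1.0
    q31 /= 2;
    ++*shift;
  }
  *quantized = static_cast<int32_t>(q31);
}

The int16 zero-point checks reflect that TFLite's 16-bit activation quantization is symmetric, which is why the same output_multiplier/output_shift pair can serve the new path unchanged.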
@@ -740,10 +746,15 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
       QuantizedReluX<int8_t>(0.0f, std::numeric_limits<float>::infinity(),
                              input, output, data);
     } break;
+    case kTfLiteInt16: {
+      QuantizedReluX<int16_t>(0.0f, std::numeric_limits<float>::infinity(),
+                              input, output, data);
+    } break;
     default:
-      TF_LITE_KERNEL_LOG(
-          context, "Only float32 & int8/uint8 is supported currently, got %s.",
-          TfLiteTypeGetName(input->type));
+      TF_LITE_KERNEL_LOG(context,
+                         "Only float32, uint8, int8 and int16 are supported "
+                         "currently, got %s.",
+                         TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
   return kTfLiteOk;
@@ -857,11 +868,15 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
       QuantizedReluX<int8_t>(0.0f, 6.0f, input, output, data);
       return kTfLiteOk;
     } break;
+    case kTfLiteInt16: {
+      QuantizedReluX<int16_t>(0.0f, 6.0f, input, output, data);
+      return kTfLiteOk;
+    } break;
     default:
-      TF_LITE_KERNEL_LOG(
-          context,
-          "Only float32, uint8 and int8 are supported currently, got %s.",
-          TfLiteTypeGetName(input->type));
+      TF_LITE_KERNEL_LOG(context,
+                         "Only float32, uint8, int8 and int16 are supported "
+                         "currently, got %s.",
+                         TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }
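Both eval changes funnel into QuantizedReluX<T>, now also instantiated for int16_t. A conceptual sketch of what that helper computes, under assumed names (the real kernel in activations.cc folds the scale ratio into the integer output_multiplier/output_shift prepared in ReluPrepare instead of using floats):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Sketch only: convert the activation bounds into the output's quantized
// domain, rescale each input by input_scale / output_scale, then clamp.
template <typename T>
std::vector<T> QuantizedReluXSketch(float act_min, float act_max,
                                    const std::vector<T>& input,
                                    float input_scale, int32_t input_zp,
                                    float output_scale, int32_t output_zp) {
  const int32_t t_min = std::numeric_limits<T>::min();
  const int32_t t_max = std::numeric_limits<T>::max();
  const int32_t q_min = std::max<int32_t>(
      t_min, output_zp + std::lround(act_min / output_scale));
  const int32_t q_max =
      std::isinf(act_max)
          ? t_max  // plain Relu: no upper activation bound
          : std::min<int32_t>(
                t_max, output_zp + std::lround(act_max / output_scale));
  std::vector<T> out;
  out.reserve(input.size());
  for (const T v : input) {
    const float real = input_scale * static_cast<float>(v - input_zp);
    const int32_t q = output_zp + std::lround(real / output_scale);
    out.push_back(static_cast<T>(std::clamp(q, q_min, q_max)));
  }
  return out;
}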
@@ -664,6 +664,29 @@ TEST(QuantizedActivationsOpTest, Relu6Int8) {
               ElementsAreArray({0, 0, 32, 64, 48, 0, 96, 16}));
 }
 
+TEST(QuantizedActivationsOpTest, Relu6Int16) {
+  const float kMin = -1;
+  const float kMax = 32767.f / 32768.f;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_RELU6,
+      /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+      /*output=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax});
+  m.SetInput<int16_t>({
+      0, -6, 2, 4,   //
+      3, -2, 10, 1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0, 0, 2, 4,  //
+                      3, 0, 6, 1,  //
+                  },
+                  kQuantizedToleranceInt16)));
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({0, 0, 8192, 16384, 12288, 0, 24576, 4096}));
+}
+
 TEST(QuantizedActivationsOpTest, ReluUint8) {
   const float kMin = -1;
   const float kMax = 127.f / 128.f;
@@ -709,6 +732,29 @@ TEST(QuantizedActivationsOpTest, ReluInt8) {
               ElementsAreArray({0, 0, 32, 64, 48, 0, 112, 16}));
 }
 
+TEST(QuantizedActivationsOpTest, ReluInt16) {
+  const float kMin = -1;
+  const float kMax = 32767.f / 32768.f;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_RELU,
+      /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+      /*output=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax});
+  m.SetInput<int16_t>({
+      0, -6, 2, 4,  //
+      3, -2, 7, 1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0, 0, 2, 4,  //
+                      3, 0, 7, 1,  //
+                  },
+                  kQuantizedToleranceInt16)));
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({0, 0, 8192, 16384, 12288, 0, 28672, 4096}));
+}
+
 TEST_P(TanhOpTest, TanhUint8) {
   const float kMin = -1;
   const float kMax = 127.f / 128.f;
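The raw expected values in both int16 tests follow directly from the quantization parameters: assuming the OpModel derives scale = 8/32768 = 1/4096 from the [8*kMin, 8*kMax] range with zero_point = 0 (which the new ReluPrepare check requires), each float output maps to value * 4096. A quick standalone check:

#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 8.0f / 32768.0f;  // 1/4096, zero_point = 0
  const float relu6_expected[] = {0, 0, 2, 4, 3, 0, 6, 1};
  const float relu_expected[] = {0, 0, 2, 4, 3, 0, 7, 1};
  for (float v : relu6_expected)
    std::printf("%d ", static_cast<int16_t>(v / scale + 0.5f));
  std::printf("\n");  // 0 0 8192 16384 12288 0 24576 4096
  for (float v : relu_expected)
    std::printf("%d ", static_cast<int16_t>(v / scale + 0.5f));
  std::printf("\n");  // 0 0 8192 16384 12288 0 28672 4096
}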
@@ -38,10 +38,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
   AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
   AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_TANH, Register_TANH(), /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC(),
@@ -193,10 +193,10 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH_REF());
   AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
   AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_TANH, Register_TANH_REF(), /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC_REF(),
@@ -861,7 +861,6 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
       property.version = 2;
-      property.quantizable_int16 = false;
       break;
     case BuiltinOperator_RELU_N1_TO_1:
       property.inputs = {{0, {}}};
@@ -375,8 +375,17 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
 
-    case BuiltinOperator_ABS:
     case BuiltinOperator_RELU:
+      if (op_sig.input_types.at(0) == TensorType_INT16) {
+        return 3;
+      }
+      if (op_sig.input_types.at(0) == TensorType_INT8 ||
+          op_sig.input_types.at(0) == TensorType_UINT8) {
+        return 2;
+      }
+      return 1;
+
+    case BuiltinOperator_ABS:
       if (op_sig.input_types.at(0) == TensorType_INT8 ||
           op_sig.input_types.at(0) == TensorType_UINT8) {
         return 2;
@@ -554,6 +563,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_MEAN:
     case BuiltinOperator_PAD:
     case BuiltinOperator_PADV2:
+    case BuiltinOperator_RELU6:
       // In case of int16 inputs, the version is 3.
       if (op_sig.input_types.at(0) == TensorType_INT16) {
         return 3;
@@ -582,7 +592,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_SUM:
     case BuiltinOperator_REDUCE_MAX:
     case BuiltinOperator_REDUCE_MIN:
-    case BuiltinOperator_RELU6:
     case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_TOPK_V2:
     case BuiltinOperator_ARG_MAX:
@@ -164,6 +164,12 @@ TEST(OpVersionTest, VersioningUnpackTest) {
 
 TEST(OpVersionTest, VersioningReluTest) {
   OpSignature fake_op_sig = {
       .op = BuiltinOperator_RELU,
+      .input_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_RELU,
       .input_types = std::vector<TensorType>{TensorType_INT8},
   };
@@ -356,7 +362,7 @@ TEST(OpVersionTest, VersioningSelectTest) {
 }
 
 TEST(OpVersionTest, VersioningRelu6Test) {
-  SimpleVersioningTest(BuiltinOperator_RELU6);
+  SimpleVersioningTestExtended(BuiltinOperator_RELU6);
 }
 
 TEST(OpVersionTest, VersioningFullyConnectedTest) {
@@ -188,6 +188,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
       {{BuiltinOperator_REDUCE_ANY, 1}, "1.11.0"},
       {{BuiltinOperator_RELU6, 1}, "1.5.0"},
       {{BuiltinOperator_RELU6, 2}, "1.14.0"},
+      {{BuiltinOperator_RELU6, 3}, kPendingReleaseVersion},
       {{BuiltinOperator_RESIZE_BILINEAR, 1}, "1.7.0"},
       {{BuiltinOperator_RESIZE_BILINEAR, 2}, "1.14.0"},
      {{BuiltinOperator_RESIZE_BILINEAR, 3}, "2.2.0"},
@@ -293,6 +294,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
       {{BuiltinOperator_ROUND, 1}, "1.14.0"},
       {{BuiltinOperator_RELU, 1}, "1.5.0"},
       {{BuiltinOperator_RELU, 2}, "2.1.0"},
+      {{BuiltinOperator_RELU, 3}, kPendingReleaseVersion},
      {{BuiltinOperator_RELU_N1_TO_1, 1}, "1.5.0"},
      {{BuiltinOperator_PRELU, 1}, "1.8.0"},
      {{BuiltinOperator_EXP, 1}, "1.7.0"},
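These two tables key the minimum TF runtime on the (op, version) pair, so version 3 of RELU/RELU6 resolves to kPendingReleaseVersion until a release containing this change is cut. An illustrative stand-in for the table mechanics, with placeholder names (not the real runtime_version.cc code):

#include <iostream>
#include <map>
#include <string>
#include <utility>

int main() {
  // Placeholder: the real kPendingReleaseVersion constant is resolved when
  // the next release is stamped.
  const std::string kPendingReleaseVersion = "pending";
  const std::map<std::pair<std::string, int>, std::string> min_runtime = {
      {{"RELU", 1}, "1.5.0"},
      {{"RELU", 2}, "2.1.0"},
      {{"RELU", 3}, kPendingReleaseVersion},
  };
  std::cout << min_runtime.at({"RELU", 3}) << "\n";  // prints: pending
}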