Merge pull request #43617 from Tessil:toupstream/16x8_relu_relu_6

PiperOrigin-RevId: 340486126
Change-Id: I9ef5fcdfedfb76e24ecb7447ecf5a0cb1cd49371
TensorFlower Gardener 2020-11-03 11:41:51 -08:00
commit 14b2d17031
9 changed files with 94 additions and 18 deletions

View File

@@ -46,7 +46,6 @@ const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
 #test-id,min-android-sdk-version
 # activations_test
-QuantizedActivationsOpTest/Relu6Uint8
 FloatActivationsOpTest/Softmax[13]D,29
 QuantizedActivationsOpTest/Softmax[13]D.+nt8,29
 FloatActivationsOpTest/Softmax\dD
@@ -60,7 +59,7 @@ FloatActivationsOpTest/Elu,30
 FloatActivationsOpTest/HardSwish
 QuantizedActivationsOpTest/HardSwish
 QuantizedActivationsOpTest/HardSwishBias
-QuantizedActivationsOpTest/Relu*
+QuantizedActivationsOpTest/Relu.+nt8
 QuantizedActivationsOpTest/PRelu,29
 QuantizedActivationsOpTest/PReluSameShapes,29
 QuantizedActivationsOpTest/PReluInt8.+,30
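Note: the tightened pattern `QuantizedActivationsOpTest/Relu.+nt8` still covers the Uint8 and Int8 ReLU tests (including Relu6Uint8, whose dedicated entry is dropped above) while excluding the new Int16 variants added in this change, presumably because the NNAPI delegate does not support int16 activations.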

View File

@@ -280,12 +280,18 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
-  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
+  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8 ||
+      input->type == kTfLiteInt16) {
     double real_multiplier = input->params.scale / output->params.scale;
     QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                        &data->output_shift);
   }
+
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+  }
   return context->ResizeTensor(context, output,
                                TfLiteIntArrayCopy(input->dims));
 }
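Note: a minimal, self-contained sketch of what the prepared multiplier/shift pair is for. QuantizeMultiplier and QuantizedReluX are the real TFLite helpers; everything below is a simplified, hypothetical stand-in. With zero points pinned to 0, as the int16 path above requires, the kernel only has to rescale each element by input_scale / output_scale and clamp to the activation range.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose a positive real multiplier into a Q31 fixed-point value and a
// power-of-two shift, so that real ~= (quantized / 2^31) * 2^shift.
void QuantizeMultiplierSketch(double real, int32_t* quantized, int* shift) {
  double fraction = std::frexp(real, shift);  // fraction in [0.5, 1)
  auto q = static_cast<int64_t>(std::round(fraction * (1ll << 31)));
  if (q == (1ll << 31)) {  // rounding pushed it out of the Q31 range
    q /= 2;
    ++*shift;
  }
  *quantized = static_cast<int32_t>(q);
}

// out = clamp(round(in * real_multiplier), act_min_q, act_max_q), entirely in
// integer arithmetic. Assumes shift <= 30, which holds for the near-1.0
// scale ratios a ReLU sees in practice.
int16_t RescaleAndClamp(int16_t in, int32_t multiplier, int shift,
                        int32_t act_min_q, int32_t act_max_q) {
  int64_t v = static_cast<int64_t>(in) * multiplier;  // value * fraction * 2^31
  const int total_shift = 31 - shift;
  v = (v + (1ll << (total_shift - 1))) >> total_shift;  // rounding right shift
  return static_cast<int16_t>(
      std::min<int64_t>(act_max_q, std::max<int64_t>(act_min_q, v)));
}

int main() {
  int32_t multiplier;
  int shift;
  QuantizeMultiplierSketch(1.0, &multiplier, &shift);  // equal in/out scales
  // ReLU activation range in the quantized domain: [0, int16 max].
  std::printf("%d\n", RescaleAndClamp(-4096, multiplier, shift, 0, 32767));  // 0
  std::printf("%d\n", RescaleAndClamp(8192, multiplier, shift, 0, 32767));   // 8192
}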
@@ -740,10 +746,15 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
       QuantizedReluX<int8_t>(0.0f, std::numeric_limits<float>::infinity(),
                              input, output, data);
     } break;
+    case kTfLiteInt16: {
+      QuantizedReluX<int16_t>(0.0f, std::numeric_limits<float>::infinity(),
+                              input, output, data);
+    } break;
     default:
-      TF_LITE_KERNEL_LOG(
-          context, "Only float32 & int8/uint8 is supported currently, got %s.",
-          TfLiteTypeGetName(input->type));
+      TF_LITE_KERNEL_LOG(context,
+                         "Only float32, uint8, int8 and int16 are supported "
+                         "currently, got %s.",
+                         TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
   return kTfLiteOk;
@@ -857,11 +868,15 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
       QuantizedReluX<int8_t>(0.0f, 6.0f, input, output, data);
       return kTfLiteOk;
     } break;
+    case kTfLiteInt16: {
+      QuantizedReluX<int16_t>(0.0f, 6.0f, input, output, data);
+      return kTfLiteOk;
+    } break;
     default:
-      TF_LITE_KERNEL_LOG(
-          context,
-          "Only float32, uint8 and int8 are supported currently, got %s.",
-          TfLiteTypeGetName(input->type));
+      TF_LITE_KERNEL_LOG(context,
+                         "Only float32, uint8, int8 and int16 are supported "
+                         "currently, got %s.",
+                         TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }

View File

@@ -664,6 +664,29 @@ TEST(QuantizedActivationsOpTest, Relu6Int8) {
               ElementsAreArray({0, 0, 32, 64, 48, 0, 96, 16}));
 }
+
+TEST(QuantizedActivationsOpTest, Relu6Int16) {
+  const float kMin = -1;
+  const float kMax = 32767.f / 32768.f;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_RELU6,
+      /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+      /*output=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax});
+  m.SetInput<int16_t>({
+      0, -6, 2, 4,   //
+      3, -2, 10, 1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0, 0, 2, 4,  //
+                      3, 0, 6, 1,  //
+                  },
+                  kQuantizedToleranceInt16)));
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({0, 0, 8192, 16384, 12288, 0, 24576, 4096}));
+}
 
 TEST(QuantizedActivationsOpTest, ReluUint8) {
   const float kMin = -1;
   const float kMax = 127.f / 128.f;
@@ -709,6 +732,29 @@ TEST(QuantizedActivationsOpTest, ReluInt8) {
               ElementsAreArray({0, 0, 32, 64, 48, 0, 112, 16}));
 }
+
+TEST(QuantizedActivationsOpTest, ReluInt16) {
+  const float kMin = -1;
+  const float kMax = 32767.f / 32768.f;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_RELU,
+      /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+      /*output=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax});
+  m.SetInput<int16_t>({
+      0, -6, 2, 4,  //
+      3, -2, 7, 1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0, 0, 2, 4,  //
+                      3, 0, 7, 1,  //
+                  },
+                  kQuantizedToleranceInt16)));
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({0, 0, 8192, 16384, 12288, 0, 28672, 4096}));
+}
 
 TEST_P(TanhOpTest, TanhUint8) {
   const float kMin = -1;
   const float kMax = 127.f / 128.f;
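Note: the expected raw int16 values in the two new tests follow directly from the chosen quantization parameters. The range [8 * kMin, 8 * kMax] = [-8, 8 * 32767/32768] maps onto [-32768, 32767], giving scale 8/32768 = 1/4096 with zero point 0. A quick standalone check of the arithmetic, independent of the test harness:

#include <algorithm>
#include <cstdio>

int main() {
  const float kScale = 8.0f / 32768.0f;  // 1/4096, zero point 0
  const float inputs[] = {0, -6, 2, 4, 3, -2, 10, 1};
  for (float x : inputs) {
    const float relu6 = std::min(6.0f, std::max(0.0f, x));  // float-domain ReLU6
    std::printf("%d ", static_cast<int>(relu6 / kScale));
  }
  std::printf("\n");
  // Prints: 0 0 8192 16384 12288 0 24576 4096 -- the Relu6Int16 expectations.
  // For ReluInt16 the only clamp is at 0, so 7.0 maps to 7 * 4096 = 28672.
}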

View File

@@ -38,10 +38,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
   AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
   AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_TANH, Register_TANH(), /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC(),

View File

@@ -193,10 +193,10 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH_REF());
   AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
   AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_TANH, Register_TANH_REF(), /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC_REF(),

View File

@@ -861,7 +861,6 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
       property.version = 2;
-      property.quantizable_int16 = false;
       break;
     case BuiltinOperator_RELU_N1_TO_1:
       property.inputs = {{0, {}}};

View File

@@ -375,8 +375,17 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
-    case BuiltinOperator_ABS:
     case BuiltinOperator_RELU:
+      if (op_sig.input_types.at(0) == TensorType_INT16) {
+        return 3;
+      }
+      if (op_sig.input_types.at(0) == TensorType_INT8 ||
+          op_sig.input_types.at(0) == TensorType_UINT8) {
+        return 2;
+      }
+      return 1;
+    case BuiltinOperator_ABS:
       if (op_sig.input_types.at(0) == TensorType_INT8 ||
           op_sig.input_types.at(0) == TensorType_UINT8) {
         return 2;
@@ -554,6 +563,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_MEAN:
     case BuiltinOperator_PAD:
     case BuiltinOperator_PADV2:
+    case BuiltinOperator_RELU6:
       // In case of int16 inputs, the version is 3.
       if (op_sig.input_types.at(0) == TensorType_INT16) {
         return 3;
@@ -582,7 +592,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_SUM:
     case BuiltinOperator_REDUCE_MAX:
     case BuiltinOperator_REDUCE_MIN:
-    case BuiltinOperator_RELU6:
     case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_TOPK_V2:
     case BuiltinOperator_ARG_MAX:
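Note: taken together, the three hunks above give RELU its own versioning rule and fold RELU6 into the int16-aware group. A standalone restatement of the resulting rule (illustrative enum and function names, not the real GetBuiltinOperatorVersion API):

enum class InputType { kFloat32, kUInt8, kInt8, kInt16 };

// Version selected for RELU and RELU6 after this change.
int ReluOpVersion(InputType input) {
  if (input == InputType::kInt16) return 3;  // new int16 kernels
  if (input == InputType::kInt8 || input == InputType::kUInt8) return 2;
  return 1;  // float32 baseline
}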

View File

@@ -164,6 +164,12 @@ TEST(OpVersionTest, VersioningUnpackTest) {
 TEST(OpVersionTest, VersioningReluTest) {
   OpSignature fake_op_sig = {
       .op = BuiltinOperator_RELU,
+      .input_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_RELU,
       .input_types = std::vector<TensorType>{TensorType_INT8},
   };
@@ -356,7 +362,7 @@ TEST(OpVersionTest, VersioningSelectTest) {
 }
 
 TEST(OpVersionTest, VersioningRelu6Test) {
-  SimpleVersioningTest(BuiltinOperator_RELU6);
+  SimpleVersioningTestExtended(BuiltinOperator_RELU6);
 }
 
 TEST(OpVersionTest, VersioningFullyConnectedTest) {

View File

@@ -188,6 +188,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_REDUCE_ANY, 1}, "1.11.0"},
           {{BuiltinOperator_RELU6, 1}, "1.5.0"},
           {{BuiltinOperator_RELU6, 2}, "1.14.0"},
+          {{BuiltinOperator_RELU6, 3}, kPendingReleaseVersion},
          {{BuiltinOperator_RESIZE_BILINEAR, 1}, "1.7.0"},
          {{BuiltinOperator_RESIZE_BILINEAR, 2}, "1.14.0"},
          {{BuiltinOperator_RESIZE_BILINEAR, 3}, "2.2.0"},
@@ -293,6 +294,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_ROUND, 1}, "1.14.0"},
           {{BuiltinOperator_RELU, 1}, "1.5.0"},
           {{BuiltinOperator_RELU, 2}, "2.1.0"},
+          {{BuiltinOperator_RELU, 3}, kPendingReleaseVersion},
          {{BuiltinOperator_RELU_N1_TO_1, 1}, "1.5.0"},
          {{BuiltinOperator_PRELU, 1}, "1.8.0"},
          {{BuiltinOperator_EXP, 1}, "1.7.0"},
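Note: the two new entries map version 3 to kPendingReleaseVersion rather than a concrete version string, since the TensorFlow release that will first ship the int16 RELU/RELU6 kernels is not yet known at commit time; the constant is presumably replaced with the actual version number once a release is cut.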