diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc
index 681c3fbfe18..5af67a2e122 100644
--- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc
+++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc
@@ -46,7 +46,6 @@ const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
 #test-id,min-android-sdk-version
 
 # activations_test
-QuantizedActivationsOpTest/Relu6Uint8
 FloatActivationsOpTest/Softmax[13]D,29
 QuantizedActivationsOpTest/Softmax[13]D.+nt8,29
 FloatActivationsOpTest/Softmax\dD
@@ -60,7 +59,7 @@ FloatActivationsOpTest/Elu,30
 FloatActivationsOpTest/HardSwish
 QuantizedActivationsOpTest/HardSwish
 QuantizedActivationsOpTest/HardSwishBias
-QuantizedActivationsOpTest/Relu*
+QuantizedActivationsOpTest/Relu.+nt8
 QuantizedActivationsOpTest/PRelu,29
 QuantizedActivationsOpTest/PReluSameShapes,29
 QuantizedActivationsOpTest/PReluInt8.+,30
diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc
index 46fec3981b6..ac7aa3cdd1f 100644
--- a/tensorflow/lite/kernels/activations.cc
+++ b/tensorflow/lite/kernels/activations.cc
@@ -280,12 +280,18 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
 
-  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
+  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8 ||
+      input->type == kTfLiteInt16) {
     double real_multiplier = input->params.scale / output->params.scale;
     QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                        &data->output_shift);
   }
 
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+  }
+
   return context->ResizeTensor(context, output,
                                TfLiteIntArrayCopy(input->dims));
 }
@@ -740,10 +746,15 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
       QuantizedReluX<int8_t>(0.0f, std::numeric_limits<float>::infinity(),
                              input, output, data);
     } break;
+    case kTfLiteInt16: {
+      QuantizedReluX<int16_t>(0.0f, std::numeric_limits<float>::infinity(),
+                              input, output, data);
+    } break;
     default:
-      TF_LITE_KERNEL_LOG(
-          context, "Only float32 & int8/uint8 is supported currently, got %s.",
-          TfLiteTypeGetName(input->type));
+      TF_LITE_KERNEL_LOG(context,
+                         "Only float32, uint8, int8 and int16 are supported "
+                         "currently, got %s.",
+                         TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
   return kTfLiteOk;
@@ -857,11 +868,15 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
       QuantizedReluX<int8_t>(0.0f, 6.0f, input, output, data);
       return kTfLiteOk;
     } break;
+    case kTfLiteInt16: {
+      QuantizedReluX<int16_t>(0.0f, 6.0f, input, output, data);
+      return kTfLiteOk;
+    } break;
     default:
-      TF_LITE_KERNEL_LOG(
-          context,
-          "Only float32, uint8 and int8 are supported currently, got %s.",
-          TfLiteTypeGetName(input->type));
+      TF_LITE_KERNEL_LOG(context,
+                         "Only float32, uint8, int8 and int16 are supported "
+                         "currently, got %s.",
+                         TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }
diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc
index 447245d11fe..19d947e960d 100644
--- a/tensorflow/lite/kernels/activations_test.cc
+++ b/tensorflow/lite/kernels/activations_test.cc
@@ -664,6 +664,29 @@ TEST(QuantizedActivationsOpTest, Relu6Int8) {
               ElementsAreArray({0, 0, 32, 64, 48, 0, 96, 16}));
 }
 
+TEST(QuantizedActivationsOpTest, Relu6Int16) {
+  const float kMin = -1;
+  const float kMax = 32767.f / 32768.f;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_RELU6,
+      /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+      /*output=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax});
+  m.SetInput<int16_t>({
+      0, -6, 2, 4,   //
+      3, -2, 10, 1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0, 0, 2, 4,  //
+                      3, 0, 6, 1,  //
+                  },
+                  kQuantizedToleranceInt16)));
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({0, 0, 8192, 16384, 12288, 0, 24576, 4096}));
+}
+
 TEST(QuantizedActivationsOpTest, ReluUint8) {
   const float kMin = -1;
   const float kMax = 127.f / 128.f;
@@ -709,6 +732,29 @@ TEST(QuantizedActivationsOpTest, ReluInt8) {
               ElementsAreArray({0, 0, 32, 64, 48, 0, 112, 16}));
 }
 
+TEST(QuantizedActivationsOpTest, ReluInt16) {
+  const float kMin = -1;
+  const float kMax = 32767.f / 32768.f;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_RELU,
+      /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+      /*output=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax});
+  m.SetInput<int16_t>({
+      0, -6, 2, 4,  //
+      3, -2, 7, 1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0, 0, 2, 4,  //
+                      3, 0, 7, 1,  //
+                  },
+                  kQuantizedToleranceInt16)));
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({0, 0, 8192, 16384, 12288, 0, 28672, 4096}));
+}
+
 TEST_P(TanhOpTest, TanhUint8) {
   const float kMin = -1;
   const float kMax = 127.f / 128.f;
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index a4e960dea81..997d6d86a4e 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -38,10 +38,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
   AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
   AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_TANH, Register_TANH(), /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC(),
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index b9a5b13b477..1fdfd3a073d 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -193,10 +193,10 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH_REF());
   AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
   AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_TANH, Register_TANH_REF(), /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC_REF(),
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 6ec320c4144..bf50695754e 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -861,7 +861,6 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
       property.version = 2;
-      property.quantizable_int16 = false;
       break;
     case BuiltinOperator_RELU_N1_TO_1:
       property.inputs = {{0, {}}};
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index aff9a3cbde2..49777026fc2 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -375,8 +375,17 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
 
-    case BuiltinOperator_ABS:
     case BuiltinOperator_RELU:
+      if (op_sig.input_types.at(0) == TensorType_INT16) {
+        return 3;
+      }
+      if (op_sig.input_types.at(0) == TensorType_INT8 ||
+          op_sig.input_types.at(0) == TensorType_UINT8) {
+        return 2;
+      }
+      return 1;
+
+    case BuiltinOperator_ABS:
       if (op_sig.input_types.at(0) == TensorType_INT8 ||
           op_sig.input_types.at(0) == TensorType_UINT8) {
         return 2;
@@ -554,6 +563,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_MEAN:
     case BuiltinOperator_PAD:
     case BuiltinOperator_PADV2:
+    case BuiltinOperator_RELU6:
       // In case of int16 inputs, the version is 3.
       if (op_sig.input_types.at(0) == TensorType_INT16) {
         return 3;
@@ -582,7 +592,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_SUM:
     case BuiltinOperator_REDUCE_MAX:
     case BuiltinOperator_REDUCE_MIN:
-    case BuiltinOperator_RELU6:
     case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_TOPK_V2:
     case BuiltinOperator_ARG_MAX:
diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
index f954ea6b6d2..aa12a7efc34 100644
--- a/tensorflow/lite/tools/versioning/op_version_test.cc
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -164,6 +164,12 @@ TEST(OpVersionTest, VersioningUnpackTest) {
 
 TEST(OpVersionTest, VersioningReluTest) {
   OpSignature fake_op_sig = {
+      .op = BuiltinOperator_RELU,
+      .input_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  fake_op_sig = {
       .op = BuiltinOperator_RELU,
       .input_types = std::vector<TensorType>{TensorType_INT8},
   };
@@ -356,7 +362,7 @@ TEST(OpVersionTest, VersioningSelectTest) {
 }
 
 TEST(OpVersionTest, VersioningRelu6Test) {
-  SimpleVersioningTest(BuiltinOperator_RELU6);
+  SimpleVersioningTestExtended(BuiltinOperator_RELU6);
 }
 
 TEST(OpVersionTest, VersioningFullyConnectedTest) {
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index d44289f3b09..10a4958e4af 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -188,6 +188,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_REDUCE_ANY, 1}, "1.11.0"},
           {{BuiltinOperator_RELU6, 1}, "1.5.0"},
           {{BuiltinOperator_RELU6, 2}, "1.14.0"},
+          {{BuiltinOperator_RELU6, 3}, kPendingReleaseVersion},
           {{BuiltinOperator_RESIZE_BILINEAR, 1}, "1.7.0"},
           {{BuiltinOperator_RESIZE_BILINEAR, 2}, "1.14.0"},
           {{BuiltinOperator_RESIZE_BILINEAR, 3}, "2.2.0"},
@@ -293,6 +294,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_ROUND, 1}, "1.14.0"},
           {{BuiltinOperator_RELU, 1}, "1.5.0"},
           {{BuiltinOperator_RELU, 2}, "2.1.0"},
+          {{BuiltinOperator_RELU, 3}, kPendingReleaseVersion},
           {{BuiltinOperator_RELU_N1_TO_1, 1}, "1.5.0"},
          {{BuiltinOperator_PRELU, 1}, "1.8.0"},
          {{BuiltinOperator_EXP, 1}, "1.7.0"},