diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 47b531b71b2..390ae0783c4 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -117,6 +117,32 @@ bool IsFloatOrUint8Operator(const TfLiteContext* context,
   return IsFloatOrUInt8(input_type);
 }
 
+// Check if the operation requires explicit conversion from int8 to uint8 values.
+bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
+                        const TfLiteNode* node) {
+  switch (builtin_code) {
+    case kTfLiteBuiltinConv2d:
+    case kTfLiteBuiltinDepthwiseConv2d:
+    case kTfLiteBuiltinFullyConnected:
+    case kTfLiteBuiltinL2Normalization: {
+      const int input_id = node->inputs->data[0];
+      const TfLiteType input_type = context->tensors[input_id].type;
+      if (input_type == kTfLiteInt8) {
+        const int weights_id = node->inputs->data[1];
+        const auto& weights_tensor = context->tensors[weights_id];
+        if ((weights_tensor.type == kTfLiteInt8 ||
+             weights_tensor.type == kTfLiteUInt8) &&
+            weights_tensor.quantization.type == kTfLiteAffineQuantization) {
+          return true;
+        }
+      }
+      return false;
+    }
+    default:
+      return false;
+  }
+}
+
 bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
                       const TfLiteNode* node) {
   switch (builtin_code) {
@@ -228,6 +254,12 @@ bool HasZeroes(TfLiteIntArrayView array) {
   return false;
 }
 
+// Bit mask for tensor flags.
+enum {
+  NN_TENSOR_FLAG_SCALAR_AS_TENSOR = 1U << 0,
+  NN_TENSOR_FLAG_INT8_CONVERSION = 1U << 1,
+};
+
 }  // namespace
 
 // RAII NN API Model Destructor for use with std::unique_ptr
@@ -428,13 +460,13 @@ class NNAPIOpBuilder {
   }
 
   TfLiteStatus AddTensorInput(int tensor_index, bool hybrid_op,
-                              bool scalar_as_tensor = false) {
-    return AddTensor(tensor_index, hybrid_op, &augmented_inputs_,
-                     scalar_as_tensor);
+                              int tensor_flags = 0) {
+    return AddTensor(tensor_index, hybrid_op, &augmented_inputs_, tensor_flags);
   }
 
-  TfLiteStatus AddTensorOutput(int tensor_index) {
-    return AddTensor(tensor_index, /*hybrid_op=*/false, &augmented_outputs_);
+  TfLiteStatus AddTensorOutput(int tensor_index, int tensor_flags = 0) {
+    return AddTensor(tensor_index, /*hybrid_op=*/false, &augmented_outputs_,
+                     tensor_flags);
   }
 
   TfLiteStatus AddAdditionalFloat32OutputTensor(uint32_t dimension_count) {
@@ -464,7 +496,8 @@ class NNAPIOpBuilder {
     // Dequantize operation is added, yielding a new tensor.
     const TfLiteTensor& tensor = context_->tensors[lite_index];
     ANeuralNetworksOperandType operand_type{
-        dequantized_type, static_cast<uint32_t>(tensor.dims->size),
+        ANEURALNETWORKS_TENSOR_FLOAT32,
+        static_cast<uint32_t>(tensor.dims->size),
         reinterpret_cast<uint32_t*>(tensor.dims->data), 0.f, 0};
     RETURN_TFLITE_ERROR_IF_NN_ERROR(
         context_,
@@ -608,8 +641,11 @@ class NNAPIOpBuilder {
   // If another caller previously created a NN API tensor for `tensor_index`
   // then the existing one is returned.
   TfLiteStatus AddTensor(int tensor_index, bool hybrid_op,
-                         std::vector<uint32_t>* indices,
-                         bool scalar_as_tensor = false) {
+                         std::vector<uint32_t>* indices, int tensor_flags = 0) {
+    const bool scalar_as_tensor =
+        tensor_flags & NN_TENSOR_FLAG_SCALAR_AS_TENSOR;
+    const bool need_int8_conversion =
+        tensor_flags & NN_TENSOR_FLAG_INT8_CONVERSION;
     int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index);
     if (ann_tensor_index != -1) {
       indices->push_back(ann_tensor_index);
@@ -640,11 +676,17 @@ class NNAPIOpBuilder {
         break;
       case kTfLiteUInt8:
      case kTfLiteInt8:
-        nn_type = (tensor_type == kTfLiteUInt8)
+        // If explicit int8 conversion is needed, we still need
+        // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM type.
+        nn_type = (tensor_type == kTfLiteUInt8 || need_int8_conversion)
                       ? ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
                       : ANEURALNETWORKS_TENSOR_QUANT8_SYMM;
         scale = tensor->params.scale;
         zeroPoint = tensor->params.zero_point;
+        if (need_int8_conversion) {
+          zeroPoint += 128;
+          operand_mapping_->add_type_conversion(tensor_index, kTfLiteUInt8);
+        }
         if (scale == 0) {
           // TENSOR_QUANT8_ASYMM and ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
           // with zero scale are not valid in NNAPI.
@@ -672,6 +714,26 @@ class NNAPIOpBuilder {
       tensor_rank = 1;
       tensor_dims = &tensor_rank;
     }
+    ANeuralNetworksSymmPerChannelQuantParams ann_perchannel_params;
+    if (tensor_type == kTfLiteInt8 || tensor_type == kTfLiteUInt8) {
+      if (tensor->quantization.type == kTfLiteAffineQuantization) {
+        TfLiteAffineQuantization* quantization_params =
+            static_cast<TfLiteAffineQuantization*>(tensor->quantization.params);
+        if (quantization_params->scale->size > 1) {
+          // Set up per-channel quantization.
+          ann_perchannel_params = {
+              .channelDim = static_cast<uint32_t>(
+                  quantization_params->quantized_dimension),
+              .scaleCount =
+                  static_cast<uint32_t>(quantization_params->scale->size),
+              .scales = quantization_params->scale->data,
+          };
+          nn_type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL;
+          scale = 0.0f;
+          zeroPoint = 0;
+        }
+      }
+    }
     ANeuralNetworksOperandType operand_type{nn_type, tensor_rank,
                                             tensor_dims, scale, zeroPoint};
@@ -679,6 +741,12 @@ class NNAPIOpBuilder {
         context_,
         nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+    if (nn_type == ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL) {
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context_,
+          nnapi_->ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
+              nn_model_, ann_tensor_index, &ann_perchannel_params));
+    }
     if (tensor->allocation_type == kTfLiteMmapRo) {
       // TODO(b/80630405): Use NNAPIAllocation.
       RETURN_TFLITE_ERROR_IF_NN_ERROR(
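
For context: the two hunks above are where per-channel quantization gets wired into operand creation. `AddTensor()` fills an `ANeuralNetworksSymmPerChannelQuantParams` whenever an 8-bit tensor carries more than one scale, switches the operand type to `ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL`, and zeroes the tensor-level scale and zero point. The standalone sketch below (not part of the patch) shows the same NNAPI 1.2 call sequence in isolation; the NDK entry points and struct fields are real, while the wrapper function and variable names are illustrative.

```cpp
// Sketch only; assumes Android NDK API level 29+ for the NNAPI 1.2 symbols.
#include <android/NeuralNetworks.h>  // declares the types and calls used below

// Registers a conv filter quantized per output channel (channel dim 0).
// Error handling elided; the real delegate wraps each call in
// RETURN_TFLITE_ERROR_IF_NN_ERROR.
void AddPerChannelFilterOperand(ANeuralNetworksModel* model,
                                const uint32_t filter_dims[4],
                                const float* per_channel_scales,
                                uint32_t num_output_channels,
                                int32_t filter_operand_index) {
  ANeuralNetworksOperandType filter_type;
  filter_type.type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL;
  filter_type.dimensionCount = 4;
  filter_type.dimensions = filter_dims;
  filter_type.scale = 0.0f;   // must be 0 for the per-channel type
  filter_type.zeroPoint = 0;  // likewise; real scales live in the struct below
  ANeuralNetworksModel_addOperand(model, &filter_type);

  ANeuralNetworksSymmPerChannelQuantParams channel_quant;
  channel_quant.channelDim = 0;  // quantized along the output-channel dim
  channel_quant.scaleCount = num_output_channels;
  channel_quant.scales = per_channel_scales;
  ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
      model, filter_operand_index, &channel_quant);
}
```
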
@@ -816,13 +884,24 @@ class NNAPIDelegateKernel {
         break;
       case kTfLiteBuiltinConv2d:
         if (version <= 2) {
-          // TODO(b/134571155): Fix hybrid and per-channel support in API 29.
-          if ((android_sdk_version <= kMinSdkVersionForNNAPI12) &&
+          if ((android_sdk_version < kMinSdkVersionForNNAPI12) &&
               (IsHybridOperator(context, builtin_code, node) ||
                !IsFloatOrUint8Operator(context, node))) {
             // Hybrid operators not supported before NNAPI 1.2.
             return nullptr;
           }
+          if (android_sdk_version < kMinSdkVersionForNNAPI12) {
+            // Per-channel quantized convolution not supported before NNAPI 1.2.
+            const auto& filter_tensor = context->tensors[node->inputs->data[1]];
+            if (filter_tensor.quantization.type == kTfLiteAffineQuantization) {
+              TfLiteAffineQuantization* quantization_params =
+                  static_cast<TfLiteAffineQuantization*>(
+                      filter_tensor.quantization.params);
+              if (quantization_params->scale->size > 1) {
+                return nullptr;
+              }
+            }
+          }
           const auto input_type = context->tensors[node->inputs->data[0]].type;
           if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
               input_type == kTfLiteUInt8 &&
@@ -874,8 +953,7 @@ class NNAPIDelegateKernel {
         break;
       case kTfLiteBuiltinDepthwiseConv2d:
         if (version == 1) {
-          // TODO(b/134571155): Fix per-channel support in API 29.
-          if (android_sdk_version <= kMinSdkVersionForNNAPI12 &&
+          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
               !IsFloatOrUint8Operator(context, node)) {
             return nullptr;
           }
@@ -927,8 +1005,7 @@ class NNAPIDelegateKernel {
           if (output_type == kTfLiteInt16) {
             return nullptr;
           }
-          // TODO(b/134571155): Fix hybrid and per-channel support in API 29.
-          if (android_sdk_version <= kMinSdkVersionForNNAPI12 &&
+          if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
              (IsHybridOperator(context, builtin_code, node) ||
               !IsFloatOrUint8Operator(context, node))) {
             // Hybrid operators not supported before NNAPI 1.2.
@@ -1055,9 +1132,6 @@ class NNAPIDelegateKernel {
             !IsFloatOperator(context, node)) {
           return nullptr;
         }
-        if (!IsFloatOrUint8Operator(context, node)) {
-          return nullptr;
-        }
         auto builtin =
             reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
         if (builtin->activation == kTfLiteActNone) {
@@ -1224,9 +1298,9 @@ class NNAPIDelegateKernel {
             (context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 ||
              android_sdk_version >= kMinSdkVersionForNNAPI12)) {
           // NNAPI does not support specifying the padding value.
-          // Before 1.2, NNAPI pads physical zero for quantized tensors,
-          // so only delegate float pad to NNAPI. NNAPI 1.2 onwards pads
-          // with zero-point, so delegate quantized pad as well.
+          // Before 1.2, NNAPI pads physical zero for quantized tensors, so
+          // only delegate float pad to NNAPI. NNAPI 1.2 onwards pads with
+          // zero-point, so delegate quantized pad as well.
           return BasicMappingFn<ANEURALNETWORKS_PAD>;
         } else if (node->inputs->size == 3 &&
                    android_sdk_version >= kMinSdkVersionForNNAPI12) {
@@ -1822,6 +1896,15 @@ class NNAPIDelegateKernel {
                   input_offset)[i] =
                   static_cast<const int32_t>(tensor->data.raw_const[i]);
             }
+          } else if (tensor->type == kTfLiteInt8 &&
+                     ann_type_equivalent == kTfLiteUInt8) {
+            // Explicitly convert int8 values to uint8 values.
+            uint8_t* input_ptr = reinterpret_cast<uint8_t*>(
+                nn_input_memory_->get_data_ptr() + input_offset);
+            for (int i = 0; i < NumElements(tensor); ++i) {
+              input_ptr[i] = static_cast<const uint8_t>(
+                  static_cast<int32_t>(tensor->data.int8[i]) + 128);
+            }
           } else {
             context->ReportError(
                 context,
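
For context: the `+ 128` loop above, the `zeroPoint += 128` in `AddTensor()`, and the mirror-image `- 128` loop in the output path (next hunk) all rely on the same identity. TFLite's int8 and NNAPI's `ANEURALNETWORKS_TENSOR_QUANT8_ASYMM` both decode as `real = scale * (q - zero_point)`, so shifting the raw values and the zero point by the same amount leaves every represented value unchanged. A self-contained check (not part of the patch), with illustrative scale and zero-point values:

```cpp
// Verify the int8 -> uint8 rewrite is value-preserving for every
// representable int8 code.
#include <cassert>
#include <cstdint>

int main() {
  const float scale = 0.5f;                // illustrative
  const int32_t zp_int8 = -3;              // illustrative TFLite zero point
  const int32_t zp_uint8 = zp_int8 + 128;  // what AddTensor() now registers
  for (int32_t q = -128; q <= 127; ++q) {
    const uint8_t q_uint8 = static_cast<uint8_t>(q + 128);  // input rewrite
    // Same (q - zero_point) difference, hence the same real value.
    assert(scale * (q - zp_int8) == scale * (q_uint8 - zp_uint8));
  }
  return 0;
}
```
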
@@ -1917,6 +2000,17 @@ class NNAPIDelegateKernel {
       if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
         continue;
       }
+      TfLiteType ann_type_equivalent =
+          operand_mapping_.lite_index_to_ann_type_conversion(output_index);
+      if (tensor->type == kTfLiteInt8 && ann_type_equivalent == kTfLiteUInt8) {
+        // Explicitly convert uint8 values to int8 values.
+        uint8_t* output_ptr = reinterpret_cast<uint8_t*>(
+            nn_output_memory_->get_data_ptr() + output_offset);
+        for (int i = 0; i < NumElements(tensor); ++i) {
+          output_ptr[i] =
+              static_cast<uint8_t>(static_cast<int32_t>(output_ptr[i]) - 128);
+        }
+      }
       memcpy(tensor->data.raw,
              nn_output_memory_->get_data_ptr() + output_offset, tensor->bytes);
       output_offset += tensor->bytes;
@@ -1996,7 +2090,7 @@ class NNAPIDelegateKernel {
       const TfLiteType type = context->tensors[tensor_id].type;
 
       // Nothing to do for this tensor if it's not quantized.
-      if (type != kTfLiteUInt8) continue;
+      if (!IsQuantized(type)) continue;
 
       // Insert Dequantize operator if it hasn't been done already and change
       // the node's input accordingly.
@@ -2020,10 +2114,25 @@ class NNAPIDelegateKernel {
       const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node);
       const bool scalar_as_tensor = IsScalarInputSupported(reg->builtin_code);
+      const bool need_int8_conversion =
+          NeedInt8Conversion(context, reg->builtin_code, node);
+      int input_tensor_flags = 0;
+      if (scalar_as_tensor) {
+        input_tensor_flags |= NN_TENSOR_FLAG_SCALAR_AS_TENSOR;
+      }
 
       // Map inputs to NN API tensor indices.
       for (int input_pos = 0; input_pos < node->inputs->size; ++input_pos) {
         const auto input_index = node->inputs->data[input_pos];
+        if (need_int8_conversion &&
+            (input_pos == 0 ||
+             reg->builtin_code == kTfLiteBuiltinFullyConnected)) {
+          // Only selected inputs require int8 conversion.
+          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(
+              input_index, hybrid_op,
+              input_tensor_flags | NN_TENSOR_FLAG_INT8_CONVERSION));
+          continue;
+        }
         if (reg->builtin_code == kTfLiteBuiltinLstm && input_pos >= 20) {
           // Skip layer normalization weights. They are added in the Map
           // function (after all the other inputs added there) since layer
@@ -2031,8 +2140,8 @@ class NNAPIDelegateKernel {
           // NNAPI.
           continue;
         }
-        // Pad and Padv2 have an optional parameter for a pad value which has to
-        // be converted to a scalar type in NN API.
+        // Pad and Padv2 have an optional parameter for a pad value which has
+        // to be converted to a scalar type in NN API.
         if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||
              reg->builtin_code == kTfLiteBuiltinPad) &&
             node->inputs->size == 3 && input_pos == 2) {
@@ -2080,10 +2189,10 @@ class NNAPIDelegateKernel {
           TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
         } else if (reg->builtin_code == kTfLiteBuiltinResizeBilinear) {
           if (input_pos == 0) {
-            // Only the first input tensor is added. The second one, specifying
-            // the output height and width, is not added and instead the height
-            // and width will be added individually as scalars by the mapping
-            // function returned by Map().
+            // Only the first input tensor is added. The second one,
+            // specifying the output height and width, is not added and
+            // instead the height and width will be added individually as
+            // scalars by the mapping function returned by Map().
             TF_LITE_ENSURE_STATUS(
                 builder.AddTensorInput(input_index, hybrid_op));
           }
@@ -2097,8 +2206,8 @@ class NNAPIDelegateKernel {
           // have different order.
           continue;
         } else {
-          TF_LITE_ENSURE_STATUS(
-              builder.AddTensorInput(input_index, hybrid_op, scalar_as_tensor));
+          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op,
+                                                       input_tensor_flags));
         }
       }
       // Get op type and operands
@@ -2107,8 +2216,13 @@ class NNAPIDelegateKernel {
                            node)({context, &builder, node, &model_state_outputs_,
                                   &model_state_tfl_inputs_});
       // Map outputs to NN API tensor indices.
+      int output_tensor_flags = 0;
+      if (need_int8_conversion) {
+        output_tensor_flags |= NN_TENSOR_FLAG_INT8_CONVERSION;
+      }
       for (auto output_index : TfLiteIntArrayView(node->outputs)) {
-        TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index));
+        TF_LITE_ENSURE_STATUS(
+            builder.AddTensorOutput(output_index, output_tensor_flags));
       }
 
       // Dequantize operators may have to be added in case inputs are to be
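
The kernel-test changes below stop specifying weight tensors via a min/max range and instead pass an explicit scale with zero point 0, i.e. symmetric quantization. The constants are consistent with a symmetric 8-bit scheme that maps the largest absolute weight to code 127: the fully-connected weights visible below go up to 10, matching `10.0 / 127.0`, and the conv tests' `4.0 / 127.0` and `2.0 / 127.0` presumably follow the same rule for their (elided) filter data. A minimal illustration, not part of the patch:

```cpp
#include <cmath>
#include <cstdio>

// Symmetric 8-bit scale: the largest |weight| maps to the largest code, 127.
float SymmetricScale(float max_abs_weight) { return max_abs_weight / 127.0f; }

int main() {
  std::printf("conv filters (max |w| = 4):  %f\n", SymmetricScale(4.0f));
  std::printf("pointwise    (max |w| = 2):  %f\n", SymmetricScale(2.0f));
  std::printf("fully conn.  (max |w| = 10): %f\n", SymmetricScale(10.0f));
  // e.g. a weight of 10.0 quantizes to lround(10.0 / (10.0 / 127.0)) = 127.
  return 0;
}
```
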
diff --git a/tensorflow/lite/kernels/conv_test.cc b/tensorflow/lite/kernels/conv_test.cc
index 835c274933b..12ed8e088f5 100644
--- a/tensorflow/lite/kernels/conv_test.cc
+++ b/tensorflow/lite/kernels/conv_test.cc
@@ -971,7 +971,8 @@ class HybridConvolutionOpModel : public BaseConvolutionOpModel {
 TEST_P(ConvolutionOpTest, SimpleTestHybridUint8) {
   HybridConvolutionOpModel m(
       GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
-      {TensorType_UINT8, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
+      {TensorType_UINT8, {3, 2, 2, 1}, 0, 0, 4.0 / 127.0, 0},
+      {TensorType_FLOAT32, {}});
 
   m.SetInput({
       // First batch
@@ -1030,7 +1031,8 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridUint8) {
 TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannelsUint8) {
   HybridConvolutionOpModel m(
       GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
-      {TensorType_UINT8, {3, 2, 2, 2}}, {TensorType_FLOAT32, {}});
+      {TensorType_UINT8, {3, 2, 2, 2}, 0, 0, 4.0 / 127.0, 0},
+      {TensorType_FLOAT32, {}});
 
   m.SetInput({
       // First batch
@@ -1062,7 +1064,8 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannelsUint8) {
 TEST_P(ConvolutionOpTest, PointwiseHybridUint8) {
   HybridConvolutionOpModel m(
       GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
-      {TensorType_UINT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+      {TensorType_UINT8, {1, 1, 1, 2}, 0, 0, 2.0 / 127.0, 0},
+      {TensorType_FLOAT32, {}}, 1, 1);
 
   m.SetInput({
       // First batch
@@ -1104,7 +1107,8 @@ TEST_P(ConvolutionOpTest, PointwiseHybridUint8) {
 TEST_P(ConvolutionOpTest, SimpleTestHybridInt8) {
   HybridConvolutionOpModel m(
       GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
-      {TensorType_INT8, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
+      {TensorType_INT8, {3, 2, 2, 1}, 0, 0, 4.0 / 127.0, 0},
+      {TensorType_FLOAT32, {}});
 
   m.SetInput({
       // First batch
@@ -1161,9 +1165,20 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridInt8) {
 //
 // 2 * (A/2) * B = A * B, where the left side is this new test.
 TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannelsInt8) {
-  HybridConvolutionOpModel m(
-      GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
-      {TensorType_INT8, {3, 2, 2, 2}}, {TensorType_FLOAT32, {}});
+  HybridConvolutionOpModel m(GetRegistration(),
+                             {TensorType_FLOAT32, {2, 2, 4, 2}},
+                             {TensorType_INT8,
+                              {3, 2, 2, 2},
+                              0,
+                              0,
+                              0,
+                              0,
+                              /*per_channel_quantization=*/true,
+                              /*per_channel_quantization_scales=*/
+                              {4.0 / 127.0, 4.0 / 127.0, 4.0 / 127.0},
+                              /*per_channel_quantization_offsets=*/{0, 0, 0},
+                              /*channel_index=*/0},
+                             {TensorType_FLOAT32, {}});
 
   m.SetInput({
       // First batch
@@ -1195,7 +1210,8 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannelsInt8) {
 TEST_P(ConvolutionOpTest, PointwiseHybridInt8) {
   HybridConvolutionOpModel m(
       GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
-      {TensorType_INT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+      {TensorType_INT8, {1, 1, 1, 2}, 0, 0, 2.0 / 127.0, 0},
+      {TensorType_FLOAT32, {}}, 1, 1);
 
   m.SetInput({
       // First batch
diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc
index 83d119767e7..637ee6b2736 100644
--- a/tensorflow/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/kernels/fully_connected_test.cc
@@ -708,7 +708,8 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedUint8) {
   HybridFullyConnectedOpModel m(
       /*units=*/3, /*batches=*/2,
       /*input=*/{TensorType_FLOAT32, {2, 10}},
-      /*weights=*/{TensorType_UINT8, {3, 10}, -63.5, 64});  // Hybrid
+      /*weights=*/
+      {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0});  // Hybrid
 
   m.SetWeights({
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
@@ -736,7 +737,7 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
   HybridFullyConnectedOpModel m(
       /*units=*/3, /*batches=*/2,
       /*input=*/{TensorType_FLOAT32, {2, 10}},
-      /*weights=*/{TensorType_INT8, {3, 10}, -63.5, 64});  // Hybrid
+      /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0});  // Hybrid
 
   m.SetSignedWeights({
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
diff --git a/tensorflow/lite/nnapi/NeuralNetworksTypes.h b/tensorflow/lite/nnapi/NeuralNetworksTypes.h
index 1199c571d71..fe4db24543c 100644
--- a/tensorflow/lite/nnapi/NeuralNetworksTypes.h
+++ b/tensorflow/lite/nnapi/NeuralNetworksTypes.h
@@ -41,6 +41,7 @@ enum {
   ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5,
   ANEURALNETWORKS_BOOL = 6,
   ANEURALNETWORKS_TENSOR_BOOL8 = 9,
+  ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL = 11,
   ANEURALNETWORKS_TENSOR_QUANT8_SYMM = 13,
 };
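
The enum value `11` added above matches the `OperandCode` numbering in Android's `NeuralNetworks.h` for NNAPI 1.2. To close the loop on what the new `SimpleTestHybridWithChannelsInt8` parameters describe (one scale per output channel, `channel_index` 0), here is a sketch of per-channel symmetric quantization; the helper is hypothetical, not from the patch, and assumes the quantized dimension is the outermost one:

```cpp
// One symmetric scale per output channel, weights laid out channel-major
// (quantized dimension 0), zero points implicitly 0.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> QuantizePerChannel(const std::vector<float>& weights,
                                       int num_channels,
                                       std::vector<float>* scales) {
  const int per_channel = static_cast<int>(weights.size()) / num_channels;
  std::vector<int8_t> q(weights.size());
  scales->resize(num_channels);
  for (int c = 0; c < num_channels; ++c) {
    float max_abs = 0.0f;
    for (int i = 0; i < per_channel; ++i) {
      max_abs = std::fmax(max_abs, std::fabs(weights[c * per_channel + i]));
    }
    // Guard all-zero channels; otherwise map the channel's max |w| to 127.
    const float scale = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
    (*scales)[c] = scale;
    for (int i = 0; i < per_channel; ++i) {
      q[c * per_channel + i] = static_cast<int8_t>(
          std::lround(weights[c * per_channel + i] / scale));
    }
  }
  return q;
}
```
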