Add proper support for per-channel quantized operations and int8 input/output tensors.

Also update the tests to use the correct quantization parameters.

PiperOrigin-RevId: 255265635
A. Unique TensorFlower 2019-06-26 14:32:47 -07:00 committed by TensorFlower Gardener
parent 45eb6337c6
commit ff9bb17674
4 changed files with 173 additions and 41 deletions
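Background for the delegate hunks below: NNAPI 1.2 has no signed asymmetric quantized tensor type, so for the operators covered by the new NeedInt8Conversion() check the delegate re-expresses int8 tensors as ANEURALNETWORKS_TENSOR_QUANT8_ASYMM operands, keeping the scale and shifting both the zero point and the stored values by +128 (and back by -128 on outputs). A minimal standalone sketch of that shift, with illustrative names that are not part of the commit:

#include <cstdint>

// Sketch (not part of this commit) of the int8 <-> uint8 shift applied for
// operators flagged by NeedInt8Conversion(): the scale is unchanged while the
// stored values and the zero point both move by 128.
void ShiftInt8ToUint8(const int8_t* src, uint8_t* dst, int num_elements,
                      int32_t* zero_point) {
  for (int i = 0; i < num_elements; ++i) {
    dst[i] = static_cast<uint8_t>(static_cast<int32_t>(src[i]) + 128);
  }
  *zero_point += 128;  // real_value = scale * (quantized_value - zero_point)
}

// Outputs come back as uint8 and are shifted down before being copied into
// the int8 TfLite tensor (see the Invoke-time hunk further below).
void ShiftUint8ToInt8(const uint8_t* src, int8_t* dst, int num_elements) {
  for (int i = 0; i < num_elements; ++i) {
    dst[i] = static_cast<int8_t>(static_cast<int32_t>(src[i]) - 128);
  }
}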


@@ -117,6 +117,32 @@ bool IsFloatOrUint8Operator(const TfLiteContext* context,
return IsFloatOrUInt8(input_type);
}
// Check if the operation requires explicit conversion from int8 to uint8 values.
bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
const TfLiteNode* node) {
switch (builtin_code) {
case kTfLiteBuiltinConv2d:
case kTfLiteBuiltinDepthwiseConv2d:
case kTfLiteBuiltinFullyConnected:
case kTfLiteBuiltinL2Normalization: {
const int input_id = node->inputs->data[0];
const TfLiteType input_type = context->tensors[input_id].type;
if (input_type == kTfLiteInt8) {
const int weights_id = node->inputs->data[1];
const auto& weights_tensor = context->tensors[weights_id];
if ((weights_tensor.type == kTfLiteInt8 ||
weights_tensor.type == kTfLiteUInt8) &&
weights_tensor.quantization.type == kTfLiteAffineQuantization) {
return true;
}
}
return false;
}
default:
return false;
}
}
bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
const TfLiteNode* node) {
switch (builtin_code) {
@@ -228,6 +254,12 @@ bool HasZeroes(TfLiteIntArrayView array) {
return false;
}
// Bit mask for tensor flags.
enum {
NN_TENSOR_FLAG_SCALAR_AS_TENSOR = 1U << 0,
NN_TENSOR_FLAG_INT8_CONVERSION = 1U << 1,
};
} // namespace
// RAII NN API Model Destructor for use with std::unique_ptr
@@ -428,13 +460,13 @@ class NNAPIOpBuilder {
}
TfLiteStatus AddTensorInput(int tensor_index, bool hybrid_op,
bool scalar_as_tensor = false) {
return AddTensor(tensor_index, hybrid_op, &augmented_inputs_,
scalar_as_tensor);
int tensor_flags = 0) {
return AddTensor(tensor_index, hybrid_op, &augmented_inputs_, tensor_flags);
}
TfLiteStatus AddTensorOutput(int tensor_index) {
return AddTensor(tensor_index, /*hybrid_op=*/false, &augmented_outputs_);
TfLiteStatus AddTensorOutput(int tensor_index, int tensor_flags = 0) {
return AddTensor(tensor_index, /*hybrid_op=*/false, &augmented_outputs_,
tensor_flags);
}
TfLiteStatus AddAdditionalFloat32OutputTensor(uint32_t dimension_count) {
@@ -464,7 +496,8 @@ class NNAPIOpBuilder {
// Dequantize operation is added, yielding a new tensor.
const TfLiteTensor& tensor = context_->tensors[lite_index];
ANeuralNetworksOperandType operand_type{
dequantized_type, static_cast<uint32_t>(tensor.dims->size),
ANEURALNETWORKS_TENSOR_FLOAT32,
static_cast<uint32_t>(tensor.dims->size),
reinterpret_cast<uint32_t*>(tensor.dims->data), 0.f, 0};
RETURN_TFLITE_ERROR_IF_NN_ERROR(
context_,
@@ -608,8 +641,11 @@ class NNAPIOpBuilder {
// If another caller previously created a NN API tensor for `tensor_index`
// then the existing one is returned.
TfLiteStatus AddTensor(int tensor_index, bool hybrid_op,
std::vector<uint32_t>* indices,
bool scalar_as_tensor = false) {
std::vector<uint32_t>* indices, int tensor_flags = 0) {
const bool scalar_as_tensor =
tensor_flags & NN_TENSOR_FLAG_SCALAR_AS_TENSOR;
const bool need_int8_conversion =
tensor_flags & NN_TENSOR_FLAG_INT8_CONVERSION;
int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index);
if (ann_tensor_index != -1) {
indices->push_back(ann_tensor_index);
@@ -640,11 +676,17 @@
break;
case kTfLiteUInt8:
case kTfLiteInt8:
nn_type = (tensor_type == kTfLiteUInt8)
// If explicit int8 conversion is needed, we still need the
// ANEURALNETWORKS_TENSOR_QUANT8_ASYMM type.
nn_type = (tensor_type == kTfLiteUInt8 || need_int8_conversion)
? ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
: ANEURALNETWORKS_TENSOR_QUANT8_SYMM;
scale = tensor->params.scale;
zeroPoint = tensor->params.zero_point;
if (need_int8_conversion) {
zeroPoint += 128;
operand_mapping_->add_type_conversion(tensor_index, kTfLiteUInt8);
}
if (scale == 0) {
// TENSOR_QUANT8_ASYMM and ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
// with zero scale are not valid in NNAPI.
@@ -672,6 +714,26 @@
tensor_rank = 1;
tensor_dims = &tensor_rank;
}
ANeuralNetworksSymmPerChannelQuantParams ann_perchannel_params;
if (tensor_type == kTfLiteInt8 || tensor_type == kTfLiteUInt8) {
if (tensor->quantization.type == kTfLiteAffineQuantization) {
TfLiteAffineQuantization* quantization_params =
static_cast<TfLiteAffineQuantization*>(tensor->quantization.params);
if (quantization_params->scale->size > 1) {
// Set up per-channel quantization.
ann_perchannel_params = {
.channelDim = static_cast<uint32_t>(
quantization_params->quantized_dimension),
.scaleCount =
static_cast<uint32_t>(quantization_params->scale->size),
.scales = quantization_params->scale->data,
};
nn_type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL;
scale = 0.0f;
zeroPoint = 0;
}
}
}
ANeuralNetworksOperandType operand_type{nn_type, tensor_rank, tensor_dims,
scale, zeroPoint};
@@ -679,6 +741,12 @@
context_,
nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
if (nn_type == ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL) {
RETURN_TFLITE_ERROR_IF_NN_ERROR(
context_,
nnapi_->ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
nn_model_, ann_tensor_index, &ann_perchannel_params));
}
if (tensor->allocation_type == kTfLiteMmapRo) {
// TODO(b/80630405): Use NNAPIAllocation.
RETURN_TFLITE_ERROR_IF_NN_ERROR(
@@ -816,13 +884,24 @@ class NNAPIDelegateKernel {
break;
case kTfLiteBuiltinConv2d:
if (version <= 2) {
// TODO(b/134571155): Fix hybrid and per-channel support in API 29.
if ((android_sdk_version <= kMinSdkVersionForNNAPI12) &&
if ((android_sdk_version < kMinSdkVersionForNNAPI12) &&
(IsHybridOperator(context, builtin_code, node) ||
!IsFloatOrUint8Operator(context, node))) {
// Hybrid operators not supported before NNAPI 1.2.
return nullptr;
}
if (android_sdk_version < kMinSdkVersionForNNAPI12) {
// Per-channel quantized convolution not supported before NNAPI 1.2.
const auto& filter_tensor = context->tensors[node->inputs->data[1]];
if (filter_tensor.quantization.type == kTfLiteAffineQuantization) {
TfLiteAffineQuantization* quantization_params =
static_cast<TfLiteAffineQuantization*>(
filter_tensor.quantization.params);
if (quantization_params->scale->size > 1) {
return nullptr;
}
}
}
const auto input_type = context->tensors[node->inputs->data[0]].type;
if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
input_type == kTfLiteUInt8 &&
@@ -874,8 +953,7 @@ class NNAPIDelegateKernel {
break;
case kTfLiteBuiltinDepthwiseConv2d:
if (version == 1) {
// TODO(b/134571155): Fix per-channel support in API 29.
if (android_sdk_version <= kMinSdkVersionForNNAPI12 &&
if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
!IsFloatOrUint8Operator(context, node)) {
return nullptr;
}
@@ -927,8 +1005,7 @@ class NNAPIDelegateKernel {
if (output_type == kTfLiteInt16) {
return nullptr;
}
// TODO(b/134571155): Fix hybrid and per-channel support in API 29.
if (android_sdk_version <= kMinSdkVersionForNNAPI12 &&
if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
(IsHybridOperator(context, builtin_code, node) ||
!IsFloatOrUint8Operator(context, node))) {
// Hybrid operators not supported before NNAPI 1.2.
@@ -1055,9 +1132,6 @@ class NNAPIDelegateKernel {
!IsFloatOperator(context, node)) {
return nullptr;
}
if (!IsFloatOrUint8Operator(context, node)) {
return nullptr;
}
auto builtin =
reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
if (builtin->activation == kTfLiteActNone) {
@@ -1224,9 +1298,9 @@ class NNAPIDelegateKernel {
(context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 ||
android_sdk_version >= kMinSdkVersionForNNAPI12)) {
// NNAPI does not support specifying the padding value.
// Before 1.2, NNAPI pads physical zero for quantized tensors,
// so only delegate float pad to NNAPI. NNAPI 1.2 onwards pads
// with zero-point, so delegate quantized pad as well.
// Before 1.2, NNAPI pads physical zero for quantized tensors, so
// only delegate float pad to NNAPI. NNAPI 1.2 onwards pads with
// zero-point, so delegate quantized pad as well.
return BasicMappingFn<ANEURALNETWORKS_PAD>;
} else if (node->inputs->size == 3 &&
android_sdk_version >= kMinSdkVersionForNNAPI12) {
@@ -1822,6 +1896,15 @@ class NNAPIDelegateKernel {
input_offset)[i] =
static_cast<const int32_t>(tensor->data.raw_const[i]);
}
} else if (tensor->type == kTfLiteInt8 &&
ann_type_equivalent == kTfLiteUInt8) {
// Explicitly convert int8 values to uint8 values.
uint8_t* input_ptr = reinterpret_cast<uint8_t*>(
nn_input_memory_->get_data_ptr() + input_offset);
for (int i = 0; i < NumElements(tensor); ++i) {
input_ptr[i] = static_cast<const uint8_t>(
static_cast<int32_t>(tensor->data.int8[i]) + 128);
}
} else {
context->ReportError(
context,
@@ -1917,6 +2000,17 @@ class NNAPIDelegateKernel {
if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
continue;
}
TfLiteType ann_type_equivalent =
operand_mapping_.lite_index_to_ann_type_conversion(output_index);
if (tensor->type == kTfLiteInt8 && ann_type_equivalent == kTfLiteUInt8) {
// Explicitly convert uint8 values to int8 values.
uint8_t* output_ptr = reinterpret_cast<uint8_t*>(
nn_output_memory_->get_data_ptr() + output_offset);
for (int i = 0; i < NumElements(tensor); ++i) {
output_ptr[i] =
static_cast<uint8_t>(static_cast<int32_t>(output_ptr[i]) - 128);
}
}
memcpy(tensor->data.raw,
nn_output_memory_->get_data_ptr() + output_offset, tensor->bytes);
output_offset += tensor->bytes;
@@ -1996,7 +2090,7 @@ class NNAPIDelegateKernel {
const TfLiteType type = context->tensors[tensor_id].type;
// Nothing to do for this tensor if it's not quantized.
if (type != kTfLiteUInt8) continue;
if (!IsQuantized(type)) continue;
// Insert Dequantize operator if it hasn't been done already and change
// the node's input accordingly.
@@ -2020,10 +2114,25 @@ class NNAPIDelegateKernel {
const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node);
const bool scalar_as_tensor = IsScalarInputSupported(reg->builtin_code);
const bool need_int8_conversion =
NeedInt8Conversion(context, reg->builtin_code, node);
int input_tensor_flags = 0;
if (scalar_as_tensor) {
input_tensor_flags |= NN_TENSOR_FLAG_SCALAR_AS_TENSOR;
}
// Map inputs to NN API tensor indices.
for (int input_pos = 0; input_pos < node->inputs->size; ++input_pos) {
const auto input_index = node->inputs->data[input_pos];
if (need_int8_conversion &&
(input_pos == 0 ||
reg->builtin_code == kTfLiteBuiltinFullyConnected)) {
// Only selected inputs require int8 conversion.
TF_LITE_ENSURE_STATUS(builder.AddTensorInput(
input_index, hybrid_op,
input_tensor_flags | NN_TENSOR_FLAG_INT8_CONVERSION));
continue;
}
if (reg->builtin_code == kTfLiteBuiltinLstm && input_pos >= 20) {
// Skip layer normalization weights. They are added in the Map
// function (after all the other inputs added there) since layer
@@ -2031,8 +2140,8 @@ class NNAPIDelegateKernel {
// NNAPI.
continue;
}
// Pad and Padv2 have an optional parameter for a pad value which has to
// be converted to a scalar type in NN API.
// Pad and Padv2 have an optional parameter for a pad value which has
// to be converted to a scalar type in NN API.
if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||
reg->builtin_code == kTfLiteBuiltinPad) &&
node->inputs->size == 3 && input_pos == 2) {
@@ -2080,10 +2189,10 @@ class NNAPIDelegateKernel {
TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
} else if (reg->builtin_code == kTfLiteBuiltinResizeBilinear) {
if (input_pos == 0) {
// Only the first input tensor is added. The second one, specifying
// the output height and width, is not added and instead the height
// and width will be added individually as scalars by the mapping
// function returned by Map().
// Only the first input tensor is added. The second one,
// specifying the output height and width, is not added and
// instead the height and width will be added individually as
// scalars by the mapping function returned by Map().
TF_LITE_ENSURE_STATUS(
builder.AddTensorInput(input_index, hybrid_op));
}
@@ -2097,8 +2206,8 @@ class NNAPIDelegateKernel {
// have different order.
continue;
} else {
TF_LITE_ENSURE_STATUS(
builder.AddTensorInput(input_index, hybrid_op, scalar_as_tensor));
TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op,
input_tensor_flags));
}
}
// Get op type and operands
@@ -2107,8 +2216,13 @@ class NNAPIDelegateKernel {
node)({context, &builder, node, &model_state_outputs_,
&model_state_tfl_inputs_});
// Map outputs to NN API tensor indices.
int output_tensor_flags = 0;
if (need_int8_conversion) {
output_tensor_flags |= NN_TENSOR_FLAG_INT8_CONVERSION;
}
for (auto output_index : TfLiteIntArrayView(node->outputs)) {
TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index));
TF_LITE_ENSURE_STATUS(
builder.AddTensorOutput(output_index, output_tensor_flags));
}
// Dequantize operators may have to be added in case inputs are to be
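
The test diffs that follow replace the old min/max constructor arguments with an explicit symmetric scale and a zero point of 0, so the hybrid kernels see weights quantized the way the delegate expects. The scales appear to be chosen as max |weight| / 127; a small sketch of that computation, with an illustrative helper name that is not part of the test code:

#include <algorithm>
#include <cmath>
#include <vector>

// Symmetric weight scale assumed by the updated hybrid tests: zero point 0
// and scale = max |w| / 127, so the weights span the full quantized range.
float SymmetricWeightScale(const std::vector<float>& weights) {
  float max_abs = 0.0f;
  for (float w : weights) max_abs = std::max(max_abs, std::fabs(w));
  return max_abs / 127.0f;
}

// E.g. a weight range of [-4, 4] yields 4.0 / 127.0, [-2, 2] yields
// 2.0 / 127.0, and [-10, 10] yields 10.0 / 127.0, the scales used below.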


@@ -971,7 +971,8 @@ class HybridConvolutionOpModel : public BaseConvolutionOpModel {
TEST_P(ConvolutionOpTest, SimpleTestHybridUint8) {
HybridConvolutionOpModel m(
GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
{TensorType_UINT8, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
{TensorType_UINT8, {3, 2, 2, 1}, 0, 0, 4.0 / 127.0, 0},
{TensorType_FLOAT32, {}});
m.SetInput({
// First batch
@@ -1030,7 +1031,8 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridUint8) {
TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannelsUint8) {
HybridConvolutionOpModel m(
GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
{TensorType_UINT8, {3, 2, 2, 2}}, {TensorType_FLOAT32, {}});
{TensorType_UINT8, {3, 2, 2, 2}, 0, 0, 4.0 / 127.0, 0},
{TensorType_FLOAT32, {}});
m.SetInput({
// First batch
@@ -1062,7 +1064,8 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannelsUint8) {
TEST_P(ConvolutionOpTest, PointwiseHybridUint8) {
HybridConvolutionOpModel m(
GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
{TensorType_UINT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
{TensorType_UINT8, {1, 1, 1, 2}, 0, 0, 2.0 / 127.0, 0},
{TensorType_FLOAT32, {}}, 1, 1);
m.SetInput({
// First batch
@@ -1104,7 +1107,8 @@ TEST_P(ConvolutionOpTest, PointwiseHybridUint8) {
TEST_P(ConvolutionOpTest, SimpleTestHybridInt8) {
HybridConvolutionOpModel m(
GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
{TensorType_INT8, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
{TensorType_INT8, {3, 2, 2, 1}, 0, 0, 4.0 / 127.0, 0},
{TensorType_FLOAT32, {}});
m.SetInput({
// First batch
@@ -1161,9 +1165,20 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridInt8) {
//
// 2 * (A/2) * B = A * B, where the left side is this new test.
TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannelsInt8) {
HybridConvolutionOpModel m(
GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
{TensorType_INT8, {3, 2, 2, 2}}, {TensorType_FLOAT32, {}});
HybridConvolutionOpModel m(GetRegistration(),
{TensorType_FLOAT32, {2, 2, 4, 2}},
{TensorType_INT8,
{3, 2, 2, 2},
0,
0,
0,
0,
/*per_channel_quantization=*/true,
/*per_channel_quantization_scales=*/
{4.0 / 127.0, 4.0 / 127.0, 4.0 / 127.0},
/*per_channel_quantization_offsets=*/{0, 0, 0},
/*channel_index=*/0},
{TensorType_FLOAT32, {}});
m.SetInput({
// First batch
@@ -1195,7 +1210,8 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannelsInt8) {
TEST_P(ConvolutionOpTest, PointwiseHybridInt8) {
HybridConvolutionOpModel m(
GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
{TensorType_INT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
{TensorType_INT8, {1, 1, 1, 2}, 0, 0, 2.0 / 127.0, 0},
{TensorType_FLOAT32, {}}, 1, 1);
m.SetInput({
// First batch


@@ -708,7 +708,8 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedUint8) {
HybridFullyConnectedOpModel m(
/*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 10}},
/*weights=*/{TensorType_UINT8, {3, 10}, -63.5, 64}); // Hybrid
/*weights=*/
{TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}); // Hybrid
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
@@ -736,7 +737,7 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
HybridFullyConnectedOpModel m(
/*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 10}},
/*weights=*/{TensorType_INT8, {3, 10}, -63.5, 64}); // Hybrid
/*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}); // Hybrid
m.SetSignedWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0


@@ -41,6 +41,7 @@ enum {
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5,
ANEURALNETWORKS_BOOL = 6,
ANEURALNETWORKS_TENSOR_BOOL8 = 9,
ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL = 11,
ANEURALNETWORKS_TENSOR_QUANT8_SYMM = 13,
};
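
For completeness, a minimal sketch of how a per-channel quantized filter operand is declared against the raw NNAPI C API whose operand type codes the enum above mirrors. It assumes NNAPI 1.2 (API level 29) headers and an already-created model; the shapes, scales, and operand index are illustrative, not taken from the commit:

#include <android/NeuralNetworks.h>

// Declare a [3, 2, 2, 2] int8 conv filter quantized per output channel
// (channel dimension 0), which is what the delegate change above does
// through its nnapi_ wrapper. Error handling is reduced to a single check.
int AddPerChannelFilterOperand(ANeuralNetworksModel* model) {
  static const uint32_t kDims[4] = {3, 2, 2, 2};
  // Per-channel operands carry scale = 0 and zeroPoint = 0 in the operand
  // type; the real per-channel scales are attached separately below.
  ANeuralNetworksOperandType filter_type = {
      ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL,
      /*dimensionCount=*/4, kDims, /*scale=*/0.0f, /*zeroPoint=*/0};
  if (ANeuralNetworksModel_addOperand(model, &filter_type) !=
      ANEURALNETWORKS_NO_ERROR) {
    return -1;
  }
  static const float kScales[3] = {4.0f / 127.0f, 4.0f / 127.0f,
                                   4.0f / 127.0f};
  ANeuralNetworksSymmPerChannelQuantParams per_channel = {
      .channelDim = 0,  // quantized (output-channel) dimension of the filter
      .scaleCount = 3,
      .scales = kScales,
  };
  // Operand indices follow creation order; index 0 assumes this was the
  // first operand added to the model.
  return ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
      model, /*index=*/0, &per_channel);
}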