Merge pull request #35946 from wwwind:conv2d_16x8
PiperOrigin-RevId: 307851318 Change-Id: Ie0d1c6dfcb3b6eca6c7b55a86c3a1b8fc9d407a9
commit 114b8ef31a
tensorflow/lite/kernels
@@ -320,9 +320,9 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,

  // Check types. (We assume that UINT8 refers to quantized tensors)
  TfLiteType input_type = input->type;
  TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
                              input_type == kTfLiteUInt8 ||
                              input_type == kTfLiteInt8);
  TF_LITE_ENSURE(context,
                 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
                     input_type == kTfLiteInt8 || input_type == kTfLiteInt16);
  TF_LITE_ENSURE_EQ(context, output->type, input_type);

  const TfLiteTensor* bias = nullptr;

@@ -336,6 +336,11 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
    if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
      TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
    } else if (input_type == kTfLiteInt16) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt64);
      TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
      TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
    } else {
      TF_LITE_ENSURE_EQ(context, bias->type, input_type);
    }
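Requiring a 64-bit bias (and zero zero points) for the int16 path follows directly from accumulator width: a single int16 x int8 product already needs about 23 bits, so summing over a realistic filter overflows int32, while int64 leaves ample headroom. A rough, illustrative bound check (the filter size below is an arbitrary assumption, not taken from the patch):

// Back-of-the-envelope width check; illustrative only.
#include <cstdint>

constexpr int64_t kMaxProduct = 32767LL * 127LL;    // largest |int16 * int8| product, ~2^22
constexpr int64_t kMacsPerOutput = 3LL * 3 * 1024;  // assumed 3x3 filter over 1024 input channels
constexpr int64_t kWorstCaseAcc = kMaxProduct * kMacsPerOutput;

static_assert(kWorstCaseAcc > INT32_MAX, "an int32 accumulator/bias could overflow");
static_assert(kWorstCaseAcc < (INT64_MAX >> 16), "int64 keeps a wide safety margin");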
@@ -677,6 +682,42 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
  }
}

template <KernelType kernel_type>
void EvalQuantizedPerChannel16x8(TfLiteContext* context, TfLiteNode* node,
                                 TfLiteConvParams* params, OpData* data,
                                 const TfLiteTensor* input,
                                 const TfLiteTensor* filter,
                                 const TfLiteTensor* bias, TfLiteTensor* output,
                                 TfLiteTensor* im2col) {
  ConvParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

  switch (kernel_type) {
    case kGenericOptimized:
    case kMultithreadOptimized:
    case kCblasOptimized:
    case kReference: {
      reference_integer_ops::ConvPerChannel(
          op_params, data->per_channel_output_multiplier.data(),
          data->per_channel_output_shift.data(), GetTensorShape(input),
          GetTensorData<int16>(input), GetTensorShape(filter),
          GetTensorData<int8>(filter), GetTensorShape(bias),
          GetTensorData<std::int64_t>(bias), GetTensorShape(output),
          GetTensorData<int16>(output));
      break;
    }
  }
}

template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
               TfLiteConvParams* params, OpData* data,
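All kernel-type variants fall through to the reference per-channel kernel here, since no optimized 16x8 path exists yet. The per-channel multipliers and shifts consumed above come from the usual requantization scale input_scale * filter_scale[c] / output_scale; a hedged sketch of that derivation using the existing QuantizeMultiplier helper (the function and local names here are illustrative, not from the patch):

// Illustrative derivation of per-channel requantization parameters.
#include <cstdint>
#include <vector>

#include "tensorflow/lite/kernels/internal/quantization_util.h"

void PopulatePerChannelMultipliers(float input_scale, float output_scale,
                                   const std::vector<float>& filter_scales,
                                   std::vector<int32_t>* multipliers,
                                   std::vector<int32_t>* shifts) {
  multipliers->resize(filter_scales.size());
  shifts->resize(filter_scales.size());
  for (int c = 0; c < static_cast<int>(filter_scales.size()); ++c) {
    const double real_multiplier =
        static_cast<double>(input_scale) * filter_scales[c] / output_scale;
    int32_t quantized_multiplier;
    int shift;
    tflite::QuantizeMultiplier(real_multiplier, &quantized_multiplier, &shift);
    (*multipliers)[c] = quantized_multiplier;
    (*shifts)[c] = shift;
  }
}

In the kernel itself these values arrive pre-computed in OpData, filled in by PopulateConvolutionQuantizationParams (whose int16 extension appears further down in this commit).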
@@ -938,6 +979,10 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
      EvalQuantizedPerChannel<kernel_type>(context, node, params, data, input,
                                           filter, bias, output, im2col);
      break;
    case kTfLiteInt16:
      EvalQuantizedPerChannel16x8<kernel_type>(
          context, node, params, data, input, filter, bias, output, im2col);
      break;
    default:
      context->ReportError(context, "Type %s currently not supported.",
                           TfLiteTypeGetName(input->type));
@@ -957,6 +1002,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      return EvalImpl<kernel_type, kTfLiteUInt8>(context, node);
    case kTfLiteInt8:
      return EvalImpl<kernel_type, kTfLiteInt8>(context, node);
    case kTfLiteInt16:
      return EvalImpl<kernel_type, kTfLiteInt16>(context, node);
    default:
      context->ReportError(context, "Type %d not currently supported.",
                           input->type);
@@ -70,7 +70,12 @@ class BaseConvolutionOpModel : public SingleOpModel {
            input.scale * filter.per_channel_quantization_scales[i];
        bias_zero_points[i] = 0;
      }
      TensorData bias{TensorType_INT32,
      tflite::TensorType bias_type = TensorType_INT32;
      if (input.type == TensorType_INT16) {
        // In case of 16-bit, the bias type is set to be int64.
        bias_type = TensorType_INT64;
      }
      TensorData bias{bias_type,
                      {bias_size},
                      /*min=*/0,
                      /*max=*/0,
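For context, a hypothetical tensor setup that would drive this 16x8 path through the test model (the shapes, scales, and surrounding OpModel construction are illustrative assumptions, following the TensorData aggregate used above; zero points are 0 because the 16-bit scheme is symmetric):

// Hypothetical 16x8 tensor declarations inside a conv test body.
TensorData input{TensorType_INT16, {1, 3, 3, 2}, /*min=*/0, /*max=*/0,
                 /*scale=*/1.0f / 32768, /*zero_point=*/0};
TensorData filter{TensorType_INT8, {2, 2, 2, 2}, 0, 0, 0, 0,
                  /*per_channel_quantization=*/true,
                  /*per_channel_quantization_scales=*/{0.5f, 0.25f},
                  /*per_channel_quantization_offsets=*/{0, 0},
                  /*channel_index=*/0};
TensorData output{TensorType_INT16, {}, /*min=*/0, /*max=*/0,
                  /*scale=*/1.0f / 128, /*zero_point=*/0};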
@@ -894,6 +894,23 @@ cc_test(
    ],
)

cc_test(
    name = "conv_per_channel_quantized_16x8_test",
    srcs = [
        "conv_per_channel_quantized_16x8_test.cc",
    ],
    shard_count = 2,
    deps = [
        ":common",
        ":optimized_base",
        ":quantization_util",
        ":reference_base",
        ":test_util",
        ":types",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_test(
    name = "resize_bilinear_test",
    srcs = ["resize_bilinear_test.cc"],
@@ -160,6 +160,27 @@ inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier,
                                  right_shift);
}

inline int32 MultiplyByQuantizedMultiplier(int64_t x,
                                           int32 quantized_multiplier,
                                           int shift) {
  // Inputs:
  // - quantized_multiplier has fixed point at bit 31
  // - shift is -31 to +7 (negative for right shift)
  //
  // Assumptions: The following input ranges are assumed
  // - quantized_multiplier >= 0 (the usual range is (1<<30) to (1<<31)-1)
  // - scaling is chosen so the final scaled result fits in int32
  // - input x is in the range -(1<<47) <= x < (1<<47)
  assert(quantized_multiplier >= 0);
  assert(shift >= -31 && shift < 8);

  int32_t reduced_multiplier = (quantized_multiplier + (1 << 15)) >> 16;
  int total_shift = 15 - shift;
  x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
  int32_t result = x >> total_shift;
  return result;
}

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
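To make the reduced-multiplier arithmetic concrete, here is a self-contained restatement of the same scheme with a couple of spot checks (values chosen arbitrarily, not taken from the patch): a Q0.31 multiplier of 1<<30 represents 0.5; it is rounded down to 15 fractional bits, and total_shift = 15 - shift applies the remaining scaling with round-to-nearest.

// Illustrative spot check mirroring the int64-input overload above.
#include <cassert>
#include <cstdint>

int32_t MultiplyByQuantizedMultiplierSketch(int64_t x,
                                            int32_t quantized_multiplier,
                                            int shift) {
  // Reduce the multiplier from 31 to 15 fractional bits so that its product
  // with a ~48-bit accumulator still fits in int64.
  const int32_t reduced_multiplier = (quantized_multiplier + (1 << 15)) >> 16;
  const int total_shift = 15 - shift;
  x = x * static_cast<int64_t>(reduced_multiplier) +
      (static_cast<int64_t>(1) << (total_shift - 1));  // rounding bias
  return static_cast<int32_t>(x >> total_shift);
}

int main() {
  assert(MultiplyByQuantizedMultiplierSketch(1000000, 1 << 30, 0) == 500000);   // x * 0.5
  assert(MultiplyByQuantizedMultiplierSketch(1000000, 1 << 30, -1) == 250000);  // x * 0.5 * 2^-1
  return 0;
}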
@@ -0,0 +1,332 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdio.h>
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/test_util.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace {

void PickOutputMultiplier(
    const ConvParams& params, const RuntimeShape& input_shape,
    const int16* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const std::int64_t* bias_data, const RuntimeShape& output_shape,
    float* output_multiplier) {
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;

  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_height = input_shape.Dims(1);
  const int input_width = input_shape.Dims(2);
  const int input_depth = input_shape.Dims(3);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int output_height = output_shape.Dims(1);
  const int output_width = output_shape.Dims(2);
  const int output_depth = output_shape.Dims(3);

  std::int64_t output_accu_min = std::numeric_limits<std::int64_t>::max();
  std::int64_t output_accu_max = std::numeric_limits<std::int64_t>::min();

  for (int batch = 0; batch < batches; ++batch) {
    for (int out_y = 0; out_y < output_height; ++out_y) {
      for (int out_x = 0; out_x < output_width; ++out_x) {
        for (int output_channel = 0; output_channel < output_depth;
             ++output_channel) {
          const int in_x_origin = (out_x * stride_width) - pad_width;
          const int in_y_origin = (out_y * stride_height) - pad_height;
          std::int64_t acc = 0;
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
              for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
                const int in_x = in_x_origin + dilation_width_factor * filter_x;
                const int in_y =
                    in_y_origin + dilation_height_factor * filter_y;
                // Zero padding by omitting the areas outside the image.
                const bool is_point_inside_image =
                    (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                    (in_y < input_height);
                if (is_point_inside_image) {
                  int32 input_val = input_data[Offset(input_shape, batch, in_y,
                                                      in_x, in_channel)];
                  int32 filter_val =
                      filter_data[Offset(filter_shape, output_channel, filter_y,
                                         filter_x, in_channel)];
                  acc += static_cast<std::int64_t>(filter_val) *
                         static_cast<std::int64_t>(input_val);
                }
              }
            }
          }
          if (bias_data) {
            acc += bias_data[output_channel];
          }
          output_accu_max = std::max(acc, output_accu_max);
          output_accu_min = std::min(acc, output_accu_min);
        }
      }
    }
  }

  // Since int16 ranges from -32768 to 32767, we need to scale the accumulator
  // min/max so that it maps into that range as closely as possible.
  if (std::abs(output_accu_max) > std::abs(output_accu_min)) {
    *output_multiplier = 32767.0f / std::abs(output_accu_max);
  } else {
    *output_multiplier = 32768.0f / std::abs(output_accu_min);
  }
}

void PickReasonableMultiplier(
    const ConvParams& params, int output_activation_min,
    int output_activation_max, int output_depth,
    const RuntimeShape& input_shape_inference, const std::int16_t* input_data,
    const RuntimeShape& filter_shape_inference, const std::int8_t* filter_data,
    const RuntimeShape& bias_shape_inference, const std::int64_t* bias_data,
    const RuntimeShape& output_shape_inference,
    std::int32_t* output_multiplier_ptr, std::int32_t* output_shift_ptr,
    std::int16_t* output_data) {
  float output_multiplier;
  PickOutputMultiplier(params, input_shape_inference, input_data,
                       filter_shape_inference, filter_data,
                       bias_shape_inference, bias_data, output_shape_inference,
                       &output_multiplier);

  int base_multiplier;
  int base_shift;
  QuantizeMultiplier(output_multiplier, &base_multiplier, &base_shift);
  for (int i = 0; i < output_depth; ++i) {
    // multipliers typically range in [2^30 ; 2^31 - 1].
    // Values in [0, 2^30 - 1] are normally unused, but harmless.
    // Thus a good way to randomize multipliers is to subtract from them
    // a random value smaller than 2^30 but still significant compared to it.
    output_multiplier_ptr[i] = base_multiplier - (std::rand() % (1 << 26));
    output_shift_ptr[i] = base_shift - 1 + (std::rand() % 4);
  }
}

bool GenerateValidShapeConfigurations(
    int filter_width, int filter_height, int dilation_width_factor,
    int dilation_height_factor, RuntimeShape* input_shape_inference,
    RuntimeShape* filter_shape_inference, RuntimeShape* output_shape_inference,
    int* pad_width, int* pad_height, int* stride) {
  const int batch = UniformRandomInt(1, 3);
  const int input_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50);
  const int input_width = UniformRandomInt(5, 50);
  const int input_height = UniformRandomInt(5, 50);
  *stride = UniformRandomInt(1, 2);
  const bool test_pad = UniformRandomInt(0, 1);
  const auto padding_type = test_pad ? PaddingType::kValid : PaddingType::kSame;

  const int output_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50);

  input_shape_inference->BuildFrom(
      {batch, input_height, input_width, input_depth});

  filter_shape_inference->BuildFrom(
      {output_depth, filter_height, filter_width, input_depth});

  EXPECT_TRUE(ComputeConvSizes(
      *input_shape_inference, output_depth, filter_width, filter_height,
      *stride, dilation_width_factor, dilation_height_factor, padding_type,
      output_shape_inference, pad_width, pad_height));

  return true;
}

void IntToFloat(std::vector<float>* d, std::vector<std::int8_t>* s) {
  for (unsigned int i = 0; i < s->size(); i++) {
    d->data()[i] = (float)s->data()[i];
  }
}

void IntToFloat(std::vector<float>* d, std::vector<std::int64_t>* s) {
  for (unsigned int i = 0; i < s->size(); i++) {
    d->data()[i] = (float)s->data()[i];
  }
}

void TryTestOneConvFilter(int test_num) {
  const int filter_width = UniformRandomInt(2, 5);
  const int filter_height = UniformRandomInt(2, 5);
  std::cout << "Test number " << test_num << " (" << filter_width << ","
            << filter_height << ")\n";
  // We don't support dilations in the 3x3 filter.
  const int dilation_width_factor = 1;
  const int dilation_height_factor = 1;

  const int output_activation_min = -32768;
  const int output_activation_max = 32767;

  RuntimeShape input_shape_inference;
  RuntimeShape filter_shape_inference;
  RuntimeShape output_shape_inference;
  int pad_width, pad_height;
  int stride;

  // Keeps trying until we get valid shapes/configurations for the 3x3 filter case.
  bool generated_valid_configurations_for_3x3_kernel = false;
  while (!generated_valid_configurations_for_3x3_kernel) {
    generated_valid_configurations_for_3x3_kernel =
        GenerateValidShapeConfigurations(
            filter_width, filter_height, dilation_width_factor,
            dilation_height_factor, &input_shape_inference,
            &filter_shape_inference, &output_shape_inference, &pad_width,
            &pad_height, &stride);
  }

  const int output_depth = output_shape_inference.Dims(3);

  RuntimeShape bias_shape_inference({1, 1, 1, output_depth});
  const int input_buffer_size = input_shape_inference.FlatSize();
  const int filter_buffer_size = filter_shape_inference.FlatSize();
  const int output_buffer_size = output_shape_inference.FlatSize();
  std::vector<std::int16_t> input_data(input_buffer_size);
  std::vector<std::int8_t> filter_data(filter_buffer_size);
  std::vector<std::int64_t> bias_data(output_depth);

  if (test_num & 1) {
    // Use high-valued samples to give a large accumulator.
    FillRandom(&input_data, (std::int16_t)32700, (std::int16_t)32767);
    FillRandom(&filter_data, (std::int8_t)120, (std::int8_t)127);
  } else {
    FillRandom(&input_data);
    FillRandom(&filter_data);
  }
  for (int i = 0; i < output_depth; i++) {
    bias_data.data()[i] = 0;
  }

  ConvParams params;
  params.stride_width = stride;
  params.stride_height = stride;
  params.dilation_height_factor = dilation_height_factor;
  params.dilation_width_factor = dilation_width_factor;
  params.padding_values.width = pad_width;
  params.padding_values.height = pad_height;
  params.weights_offset = 0;
  params.quantized_activation_min = output_activation_min;
  params.quantized_activation_max = output_activation_max;
  params.float_activation_max = (float)(1LL << 40);
  params.float_activation_min = -params.float_activation_max;

  std::vector<std::int16_t> reference_output_data(output_buffer_size);
  std::vector<std::int16_t> neon_output_data(output_buffer_size);

  std::vector<std::int32_t> output_multiplier(output_depth);
  std::vector<std::int32_t> output_shift(output_depth);

  // It's hard to come up with the right multiplier: a random guess basically
  // makes all the results saturate and become meaningless, so we first use the
  // reference impl to probe the min/max value of the accumulation, then use
  // that value as a guide to populate a meaningful multiplier & shift.
  PickReasonableMultiplier(
      params, output_activation_min, output_activation_max, output_depth,
      input_shape_inference, input_data.data(), filter_shape_inference,
      filter_data.data(), bias_shape_inference, bias_data.data(),
      output_shape_inference, output_multiplier.data(), output_shift.data(),
      reference_output_data.data());

  // The following tests check that the reference impl and the Neon general
  // impl agree, and that the reference impl loosely agrees with the fast
  // kernel, since they use different rounding strategies.
  reference_integer_ops::ConvPerChannel(
      params, output_multiplier.data(), output_shift.data(),
      input_shape_inference, input_data.data(), filter_shape_inference,
      filter_data.data(), bias_shape_inference, bias_data.data(),
      output_shape_inference, reference_output_data.data());

  std::vector<float> input_data_float(input_buffer_size);
  std::vector<float> filter_data_float(filter_buffer_size);
  std::vector<float> bias_data_float(output_depth);
  std::vector<float> output_data_float(output_buffer_size);

  for (int i = 0; i < input_buffer_size; i++) {
    input_data_float.data()[i] = (float)(input_data.data()[i]);
  }
  IntToFloat(&filter_data_float, &filter_data);
  IntToFloat(&bias_data_float, &bias_data);
  RuntimeShape im2col_shape;
  float im2col_data;

  reference_ops::Conv(params, input_shape_inference, input_data_float.data(),
                      filter_shape_inference, filter_data_float.data(),
                      bias_shape_inference, bias_data_float.data(),
                      output_shape_inference, output_data_float.data(),
                      im2col_shape, &im2col_data);

  for (int n = 0; n < output_shape_inference.Dims(0); n++) {
    for (int h = 0; h < output_shape_inference.Dims(1); h++) {
      for (int w = 0; w < output_shape_inference.Dims(2); w++) {
        for (int c = 0; c < output_shape_inference.Dims(3); c++) {
          int offset = Offset(output_shape_inference, n, h, w, c);
          float float_res = output_data_float.data()[offset];
          int16 int16_res = reference_output_data.data()[offset];
          int32 output_mul = output_multiplier.data()[c];
          int shift = output_shift.data()[c];
          float scale = (float)output_mul / (float)(1ULL << 31);
          if (shift > 0) scale = scale * (float)(1 << shift);
          if (shift < 0) scale = scale / (float)(1 << -shift);
          int ref_res = floor(float_res * scale + 0.5);
          if (ref_res < output_activation_min) ref_res = output_activation_min;
          if (ref_res > output_activation_max) ref_res = output_activation_max;
          int e = (ref_res - int16_res);
          if (e < 0) e = -e;
          if (e > 2) {
            ADD_FAILURE() << "(" << n << ", " << h << ", " << w << ", " << c
                          << ")"
                          << " scale=" << output_mul << " shift=" << shift
                          << " res=" << int16_res
                          << " float=" << float_res * scale << " (" << float_res
                          << ", " << scale << ")";
            EXPECT_TRUE(false);
          }
        }
      }
    }
  }
}

TEST(QuantizedConvPerChannelTest, FastKernelTest) {
  for (int i = 0; i < 30; ++i) {
    TryTestOneConvFilter(i);
  }
}

} // namespace
} // namespace tflite
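To see what PickOutputMultiplier and PickReasonableMultiplier produce in practice (numbers purely illustrative): if the probed accumulator extreme is |acc| = 8,322,945, the float multiplier becomes 32767 / 8,322,945 ≈ 0.003937; QuantizeMultiplier factors that as roughly 0.5039 * 2^-7, i.e. a Q0.31 base multiplier of about 1.08e9 with a base shift of -7, and each output channel then gets a small random perturbation of that baseline so the per-channel code paths are genuinely exercised.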
@@ -122,6 +122,95 @@ inline void ConvPerChannel(
  }
}

// Fixed-point per-channel-quantization convolution reference kernel.
// 16-bit data and 8-bit filter
inline void ConvPerChannel(
    const ConvParams& params, const int32* output_multiplier,
    const int32* output_shift, const RuntimeShape& input_shape,
    const int16* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const std::int64_t* bias_data, const RuntimeShape& output_shape,
    int16* output_data) {
  // Get parameters.
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;

  // Set min and max value of the output.
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;

  // Sanity check.
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
  if (bias_data) {
    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
  }

  // Check dimensions of the tensors.
  const int input_height = input_shape.Dims(1);
  const int input_width = input_shape.Dims(2);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int output_height = output_shape.Dims(1);
  const int output_width = output_shape.Dims(2);
  for (int batch = 0; batch < batches; ++batch) {
    for (int out_y = 0; out_y < output_height; ++out_y) {
      for (int out_x = 0; out_x < output_width; ++out_x) {
        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
          const int in_x_origin = (out_x * stride_width) - pad_width;
          const int in_y_origin = (out_y * stride_height) - pad_height;
          std::int64_t acc = 0;
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
              for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
                const int in_x = in_x_origin + dilation_width_factor * filter_x;
                const int in_y =
                    in_y_origin + dilation_height_factor * filter_y;
                // Zero padding by omitting the areas outside the image.
                const bool is_point_inside_image =
                    (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                    (in_y < input_height);
                if (is_point_inside_image) {
                  int32 input_val = input_data[Offset(input_shape, batch, in_y,
                                                      in_x, in_channel)];
                  int32 filter_val =
                      filter_data[Offset(filter_shape, out_channel, filter_y,
                                         filter_x, in_channel)];
                  // Accumulate with 64 bits accumulator.
                  // int64 += int8 * int16, so the highest value we can get
                  // from each accumulation is [-127, 127] * ([-32768, 32767] -
                  // [-32768, 32767]), which is [-8322945, 8322945].
                  // log2(8322945) = 22.99.
                  acc += filter_val * input_val;
                }
              }
            }
          }
          if (bias_data) {
            acc += bias_data[out_channel];
          }
          int32_t scaled_acc = MultiplyByQuantizedMultiplier(
              acc, output_multiplier[out_channel], output_shift[out_channel]);
          scaled_acc = std::max(scaled_acc, output_activation_min);
          scaled_acc = std::min(scaled_acc, output_activation_max);
          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
              static_cast<int16_t>(scaled_acc);
        }
      }
    }
  }
}

} // namespace reference_integer_ops
} // namespace tflite
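A minimal, illustrative driver for the 16x8 reference kernel above (shapes, data values, and quantization parameters are arbitrary examples, not taken from the patch):

// Illustrative standalone call into the new 16x8 ConvPerChannel overload.
#include <cstdint>
#include <vector>

#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/types.h"

void RunConvPerChannel16x8Example() {
  using tflite::ConvParams;
  using tflite::RuntimeShape;

  RuntimeShape input_shape({1, 4, 4, 2});   // NHWC
  RuntimeShape filter_shape({1, 3, 3, 2});  // OHWI: one output channel
  RuntimeShape bias_shape({1, 1, 1, 1});
  RuntimeShape output_shape({1, 2, 2, 1});  // VALID padding, stride 1

  std::vector<int16_t> input(input_shape.FlatSize(), 1000);
  std::vector<int8_t> filter(filter_shape.FlatSize(), 2);
  std::vector<int64_t> bias(1, 0);
  std::vector<int16_t> output(output_shape.FlatSize());

  ConvParams params;
  params.padding_values.width = 0;
  params.padding_values.height = 0;
  params.stride_width = 1;
  params.stride_height = 1;
  params.dilation_width_factor = 1;
  params.dilation_height_factor = 1;
  params.quantized_activation_min = -32768;
  params.quantized_activation_max = 32767;

  // One requantization multiplier/shift pair per output channel.
  std::vector<int32_t> output_multiplier(1, 1 << 30);  // 0.5 in Q0.31
  std::vector<int32_t> output_shift(1, -4);            // times 2^-4

  tflite::reference_integer_ops::ConvPerChannel(
      params, output_multiplier.data(), output_shift.data(), input_shape,
      input.data(), filter_shape, filter.data(), bias_shape, bias.data(),
      output_shape, output.data());
  // Each output is (18 * 1000 * 2) scaled by 0.5 * 2^-4, i.e. 1125.
}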
@@ -62,8 +62,9 @@ TfLiteStatus PopulateConvolutionQuantizationParams(
  TF_LITE_ENSURE(context, affine_quantization->scale);
  const bool is_per_channel = affine_quantization->scale->size > 1;
  if (is_per_channel) {
    // Currently only Int8 is supported for per channel quantization.
    TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt8);
    // Currently only Int8/Int16 is supported for per channel quantization.
    TF_LITE_ENSURE(context,
                   input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
    TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteInt8);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, num_channels);
    TF_LITE_ENSURE_EQ(

@@ -104,7 +105,8 @@ TfLiteStatus PopulateConvolutionQuantizationParams(
    QuantizeMultiplier(real_multiplier, multiplier, &exponent);
    *shift = -exponent;
  }
  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8 ||
      input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, activation, output, output_activation_min,
        output_activation_max));
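Routing int16 into CalculateActivationRangeQuantized assumes that helper knows the int16 output range; conceptually the calculation looks like the simplified sketch below (not the library code, and the function name here is made up for illustration):

// Simplified sketch of the quantized activation-range computation for int16.
#include <algorithm>
#include <cstdint>

void ActivationRangeInt16Sketch(bool fused_relu, int32_t output_zero_point,
                                int32_t* act_min, int32_t* act_max) {
  int32_t qmin = -32768;  // full int16 range by default
  int32_t qmax = 32767;
  if (fused_relu) {
    // A fused ReLU clamps the real-valued output at 0, which in the
    // quantized domain corresponds to the zero point.
    qmin = std::max(qmin, output_zero_point);
  }
  *act_min = qmin;
  *act_max = qmax;
}

For the symmetric 16-bit scheme the zero point is 0, so a fused ReLU simply narrows the range to [0, 32767].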
@@ -251,32 +251,44 @@ class SingleOpModel {
                   quantized_output.data() + quantized_output.size());
  }

  template <typename T>
  void PerChannelQuantizeBiasPopulateTensor(
      const std::vector<float>& input_data, int index,
      TfLiteAffineQuantization* params) {
    const int32_t num_inputs = input_data.size();
    std::vector<T> quantized_output(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      const float scale = params->scale->size == 1 ? params->scale->data[0]
                                                   : params->scale->data[i];
      quantized_output[i] = input_data[i] / scale;
    }
  }

  template <typename T>
  void PerChannelQuantizeBiasPopulateTensor(
      int index, const std::vector<float>& input_data,
      const TfLiteAffineQuantization* params) {
    const int32_t num_inputs = input_data.size();
    std::vector<T> quantized_output(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      const float scale = params->scale->size == 1 ? params->scale->data[0]
                                                   : params->scale->data[i];
      quantized_output[i] = input_data[i] / scale;
    }
    PopulateTensor(index, /*offset=*/0, quantized_output.data(),
                   quantized_output.data() + quantized_output.size());
  }

  // Quantize and populate data for bias with per channel quantization.
  void PerChannelQuantizeBias(int index, const std::vector<float>& input_data) {
    const int32_t num_inputs = input_data.size();
    std::vector<int32_t> quantized_output(num_inputs);
    TfLiteTensor* t = interpreter_->tensor(index);
    auto* params =
        reinterpret_cast<TfLiteAffineQuantization*>(t->quantization.params);
    CHECK(t->type == kTfLiteInt32 || t->type == kTfLiteInt64);
    if (t->type == kTfLiteInt32) {
      std::vector<int32_t> quantized_output(num_inputs);
      for (int i = 0; i < num_inputs; ++i) {
        const float scale = params->scale->size == 1 ? params->scale->data[0]
                                                     : params->scale->data[i];
        quantized_output[i] = input_data[i] / scale;
      }
      PopulateTensor(index, /*offset=*/0, quantized_output.data(),
                     quantized_output.data() + quantized_output.size());
      PerChannelQuantizeBiasPopulateTensor<int32_t>(index, input_data, params);
    } else {
      std::vector<int64_t> quantized_output(num_inputs);
      for (int i = 0; i < num_inputs; ++i) {
        const float scale = params->scale->size == 1 ? params->scale->data[0]
                                                     : params->scale->data[i];
        quantized_output[i] = input_data[i] / scale;
      }
      PopulateTensor(index, /*offset=*/0, quantized_output.data(),
                     quantized_output.data() + quantized_output.size());
      PerChannelQuantizeBiasPopulateTensor<int64_t>(index, input_data, params);
    }
  }
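With the dispatch above, a test only needs to declare the bias tensor as int64 for the 16x8 case; the call site stays the same. A hypothetical usage sketch (the model fixture and its bias accessor are illustrative, not part of the patch):

// 'model' is some SingleOpModel-derived test fixture whose bias tensor was
// declared with TensorType_INT64 for the 16x8 convolution path.
model.PerChannelQuantizeBias(model.bias(), {1.0f, 2.0f, 3.0f});
// The helper now checks the tensor type and routes to
// PerChannelQuantizeBiasPopulateTensor<int64_t>(...) internally.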
@@ -368,6 +380,14 @@ class SingleOpModel {
  template <typename T>
  void PopulateTensor(int index, int offset, T* begin, T* end) {
    T* v = interpreter_->typed_tensor<T>(index);
    if (!v) {
      auto* t = interpreter_->tensor(index);
      CHECK(t) << "No tensor with index " << index << ".";
      CHECK(t->data.raw) << "Empty data for tensor with index " << index << ".";
      CHECK(v) << "Type mismatch for tensor with index " << index
               << ". Requested " << typeToTfLiteType<T>() << ", got "
               << t->type;
    }
    memcpy(v + offset, begin, (end - begin) * sizeof(T));
  }