Internal change on quantization.

PiperOrigin-RevId: 351876369
Change-Id: If0990e2fca9e90b4eee6dea23a3e2eeaeb6623fb
This commit is contained in:
Hyeonjong Ryu 2021-01-14 14:22:55 -08:00 committed by TensorFlower Gardener
parent 94f9284bfb
commit 7325daf339
3 changed files with 76 additions and 4 deletions
tensorflow/lite/kernels

View File

@ -965,8 +965,8 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
op_params.padding_values.height = data->padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = 1;
op_params.dilation_height_factor = 1;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
optimized_ops::HybridConv(

View File

@ -1295,6 +1295,69 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridInt8) {
0.16)));
}
TEST_P(ConvolutionOpTest, SimpleTestHybridInt8WithDilation) {
const int stride_width = 1;
const int stride_height = 1;
const Padding padding = Padding_VALID;
const int dilation_width_factor = 2;
const int dilation_height_factor = 1;
HybridConvolutionOpModel m(
GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
{TensorType_INT8, {3, 2, 2, 1}, 0, 0, 4.0 / 127.0, 0},
{TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
ActivationFunctionType_NONE, dilation_width_factor,
dilation_height_factor);
m.SetInput({
// First batch
1, 1, 1, 1, // row = 1
2, 2, 2, 2, // row = 2
// Second batch
1, 2, 3, 4, // row = 1
1, 2, 3, 4, // row = 2
});
m.SetSignedFilter({
1, 2, 3, 4, // first 2x2 filter
-1, 1, -1, 1, // second 2x2 filter
-1, -1, 1, 1, // third 2x2 filter
});
m.SetBias({1, 2, 3});
m.Invoke();
// Example: we get 17.1577 instead of 17.
//
// Second batch:
// 1 2 3 4 -> 32 64 95 127 with scale factor 127/4.
// 1 2 3 4 32 64 95 127
//
// First filter:
// 1 2 -> 32 64 with scale factor of 127/4.
// 3 4 95 127
//
// The left half of the input gives us 16288. Multiply by (4/127)^2 for
// dequantization and adding 1 for the bias gives us the result. and adding
// the bias gives us the result.
//
// The optimized kernel converts the input into this matrix via Im2Col
//
// 1 1 2 2
// 1 1 2 2
// 1 3 1 3
// 2 4 2 4
//
// and multiplies it with the filter directly.
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
18, 2, 5, // first batch, left
18, 2, 5, // first batch, right
23, 6, 3, // second batch, left
33, 6, 3, // second batch, right
},
0.16)));
}
TEST_P(ConvolutionOpTest, SimpleTestHybridInt8Big) {
// A bigger variant of the simple hybrid test to ensure coverage on
// optimized paths that are only enabled at larger matrix sizes.

View File

@ -1342,6 +1342,8 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
int8_t* im2col_data, CpuBackendContext* context) {
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
const int dilation_width_factor = params.dilation_width_factor;
const int dilation_height_factor = params.dilation_height_factor;
const float output_activation_min = params.float_activation_min;
const float output_activation_max = params.float_activation_max;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
@ -1352,15 +1354,22 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
const int filter_width = filter_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int input_zero_point = 0;
const int8_t* gemm_input_data = nullptr;
int num_input;
const bool need_dilated_im2col =
dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
if (need_im2col) {
if (need_dilated_im2col) {
DilatedIm2col(params, input_zero_point, input_shape, input_data,
filter_shape, output_shape, im2col_data);
gemm_input_data = im2col_data;
num_input = im2col_shape.FlatSize();
} else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
// symmetric quantization assumes zero point of 0.
const int input_zero_point = 0;
Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
input_data, im2col_shape, im2col_data);