Optimized per-channel convolution.
PiperOrigin-RevId: 239279531
commit 3a3b1132d7
parent bc6db71576
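With per-channel quantization, every output channel of the convolution carries its own scale, so each int32 accumulator is rescaled with that channel's multiplier and shift before being cast to int8. A plain-arithmetic sketch of that requantization step (illustrative only: RequantizePerChannel is not part of this change, and the kernels below use 32-bit fixed-point rounding rather than doubles):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Floating-point model of the per-output-channel requantization the int8
// conv kernels perform on each accumulator.
inline int8_t RequantizePerChannel(int32_t acc, int32_t bias,
                                   int32_t output_multiplier,  // Q31 fixed point
                                   int output_shift,           // per-channel exponent
                                   int32_t output_offset) {
  acc += bias;
  const double scale =
      output_multiplier / 2147483648.0 * std::pow(2.0, output_shift);  // 2^31
  int32_t scaled =
      static_cast<int32_t>(std::lround(acc * scale)) + output_offset;
  scaled = std::min<int32_t>(127, std::max<int32_t>(-128, scaled));  // clamp
  return static_cast<int8_t>(scaled);                                // int8 cast
}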
@@ -24,12 +24,14 @@ limitations under the License.
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/eigen_support.h"
 #include "tensorflow/lite/kernels/gemm_support.h"
+#include "tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h"
 #include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
@@ -495,27 +497,70 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   }
 }
 
+template <KernelType kernel_type>
 void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
                              TfLiteConvParams* params, OpData* data,
                              TfLiteTensor* input, TfLiteTensor* filter,
-                             TfLiteTensor* bias, TfLiteTensor* output) {
-  ConvParams op_params;
-  op_params.input_offset = input->params.zero_point;
-  op_params.output_offset = output->params.zero_point;
-  op_params.stride_height = params->stride_height;
-  op_params.stride_width = params->stride_width;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.padding_values.height = data->padding.height;
-  op_params.padding_values.width = data->padding.width;
-
-  reference_integer_ops::ConvPerChannel(
-      op_params, data->per_channel_output_multiplier.data(),
-      data->per_channel_output_shift.data(), GetTensorShape(input),
-      GetTensorData<int8>(input), GetTensorShape(filter),
-      GetTensorData<int8>(filter), GetTensorShape(bias),
-      GetTensorData<int32>(bias), GetTensorShape(output),
-      GetTensorData<int8>(output));
+                             TfLiteTensor* bias, TfLiteTensor* output,
+                             TfLiteTensor* im2col) {
+  KernelType effective_kernel_type;
+  effective_kernel_type = kernel_type;
+
+  // If not running on NEON we force a fallback to the reference kernels, until
+  // we have optimized support on other platforms.
+#ifndef GEMMLOWP_NEON
+  effective_kernel_type = kReference;
+#endif
+
+  switch (effective_kernel_type) {
+    case kReference: {
+      ConvParams op_params;
+      op_params.input_offset = -input->params.zero_point;
+      op_params.output_offset = output->params.zero_point;
+      op_params.stride_height = params->stride_height;
+      op_params.stride_width = params->stride_width;
+      op_params.dilation_height_factor = params->dilation_height_factor;
+      op_params.dilation_width_factor = params->dilation_width_factor;
+      op_params.padding_values.height = data->padding.height;
+      op_params.padding_values.width = data->padding.width;
+
+      reference_integer_ops::ConvPerChannel(
+          op_params, data->per_channel_output_multiplier.data(),
+          data->per_channel_output_shift.data(), GetTensorShape(input),
+          GetTensorData<int8>(input), GetTensorShape(filter),
+          GetTensorData<int8>(filter), GetTensorShape(bias),
+          GetTensorData<int32>(bias), GetTensorShape(output),
+          GetTensorData<int8>(output));
+      break;
+    }
+    case kGenericOptimized:
+    case kMultithreadOptimized:
+    case kCblasOptimized: {
+#ifdef GEMMLOWP_NEON
+      gemmlowp::GemmContext* gemm_context =
+          gemm_support::GetFromContext(context);
+      ConvParams op_params;
+      op_params.input_offset = -input->params.zero_point;
+      op_params.output_offset = output->params.zero_point;
+      op_params.stride_height = params->stride_height;
+      op_params.stride_width = params->stride_width;
+      op_params.dilation_height_factor = params->dilation_height_factor;
+      op_params.dilation_width_factor = params->dilation_width_factor;
+      op_params.padding_values.height = data->padding.height;
+      op_params.padding_values.width = data->padding.width;
+
+      optimized_integer_ops::ConvPerChannel(
+          op_params, data->per_channel_output_multiplier.data(),
+          data->per_channel_output_shift.data(), GetTensorShape(input),
+          GetTensorData<int8>(input), GetTensorShape(filter),
+          GetTensorData<int8>(filter), GetTensorShape(bias),
+          GetTensorData<int32>(bias), GetTensorShape(output),
+          GetTensorData<int8>(output), GetTensorShape(im2col),
+          GetTensorData<int8>(im2col), gemm_context);
+#endif
+      break;
+    }
+  }
 }
 
 template <KernelType kernel_type>
@@ -707,8 +752,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                                  bias, im2col, hwcn_weights, output);
       break;
     case kTfLiteInt8:
-      EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
-                              output);
+      EvalQuantizedPerChannel<kernel_type>(context, node, params, data, input,
+                                           filter, bias, output, im2col);
       break;
     default:
       context->ReportError(context, "Type %d not currently supported.",
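The im2col temporary threaded through this call is consumed only by the optimized path; it holds, for every output position, that position's flattened filter_height x filter_width x input_depth patch, as produced by the Im2col routine added below. A hypothetical helper spelling out the expected buffer shape (ComputeIm2colShape is not part of the change):

#include "tensorflow/lite/kernels/internal/types.h"

// Shape of the im2col scratch tensor the optimized int8 kernel reads:
// one row of size filter_h * filter_w * input_depth per output position.
inline tflite::RuntimeShape ComputeIm2colShape(int batches, int output_height,
                                               int output_width,
                                               int filter_height,
                                               int filter_width,
                                               int input_depth) {
  return tflite::RuntimeShape(
      {batches, output_height, output_width,
       input_depth * filter_height * filter_width});
}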
@@ -1130,7 +1130,7 @@ class PerChannelQuantizedConvolutionOpModel : public BaseConvolutionOpModel {
   }
 };
 
-TEST_P(ConvolutionOpTest, SimpleTest) {
+TEST_P(ConvolutionOpTest, SimplePerChannelTest) {
   PerChannelQuantizedConvolutionOpModel m(
       GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1},
       {TensorType_INT8,
@@ -175,6 +175,8 @@ cc_library(
         "optimized/depthwiseconv_float.h",
         "optimized/depthwiseconv_uint8.h",
         "optimized/depthwiseconv_uint8_3x3_filter.h",
+        "optimized/im2col_utils.h",
+        "optimized/integer_ops/conv.h",
         "optimized/optimized_ops.h",
     ],
     copts = tflite_copts(),
@@ -209,6 +211,7 @@ cc_library(
         "optimized/depthwiseconv_float.h",
         "optimized/depthwiseconv_uint8.h",
         "optimized/depthwiseconv_uint8_3x3_filter.h",
+        "optimized/im2col_utils.h",
         "optimized/legacy_optimized_ops.h",
         "optimized/optimized_ops.h",
     ],
new file: tensorflow/lite/kernels/internal/optimized/im2col_utils.h (235 lines)

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

template <typename T>
inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
                                         int h, int b, int kheight, int kwidth,
                                         int stride_width, int stride_height,
                                         int pad_width, int pad_height,
                                         int in_width, int in_height,
                                         int in_depth, int single_buffer_length,
                                         int buffer_id, const T* in_data,
                                         T* conv_buffer_data, uint8 zero_byte) {
  gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  // This chunk of code reshapes all the inputs corresponding to
  // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
  const int kwidth_times_indepth = kwidth * in_depth;
  const int inwidth_times_indepth = in_width * in_depth;
  const int ih_ungated_start = h * stride_height - pad_height;
  const int ih_ungated_end = (ih_ungated_start + kheight);
  const int ih_end = std::min(ih_ungated_end, in_height);
  const int iw_ungated_start = w * stride_width - pad_width;
  const int iw_ungated_end = (iw_ungated_start + kwidth);
  const int iw_end = std::min(iw_ungated_end, in_width);
  // If the patch is off the edge of the input image, skip writing those rows
  // and columns from the patch into the output array.
  const int h_offset = std::max(0, -ih_ungated_start);
  const int w_offset = std::max(0, -iw_ungated_start);
  const int ih_start = std::max(0, ih_ungated_start);
  const int iw_start = std::max(0, iw_ungated_start);
  const int single_row_num =
      std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
  const int output_row_offset = (buffer_id * single_buffer_length);
  int out_offset =
      output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
  int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);

  // Express all of the calculations as padding around the input patch.
  const int top_padding = h_offset;
  const int bottom_padding = (ih_ungated_end - ih_end);
  const int left_padding = w_offset;
  const int right_padding = (iw_ungated_end - iw_end);
  assert(single_row_num ==
         ((kwidth - (left_padding + right_padding)) * in_depth));

  // Write out zeroes to the elements representing the top rows of the input
  // patch that are off the edge of the input image.
  if (top_padding > 0) {
    const int top_row_elements = (top_padding * kwidth * in_depth);
    memset(conv_buffer_data + output_row_offset, zero_byte,
           (top_row_elements * sizeof(T)));
  }

  // If the patch is on the interior of the input image horizontally, just copy
  // over the rows sequentially, otherwise add zero padding at the start or end.
  if ((left_padding == 0) && (right_padding == 0)) {
    for (int ih = ih_start; ih < ih_end; ++ih) {
      memcpy(conv_buffer_data + out_offset, in_data + in_offset,
             single_row_num * sizeof(T));
      out_offset += kwidth_times_indepth;
      in_offset += inwidth_times_indepth;
    }
  } else {
    for (int ih = ih_start; ih < ih_end; ++ih) {
      if (left_padding > 0) {
        const int left_start = (out_offset - (left_padding * in_depth));
        memset(conv_buffer_data + left_start, zero_byte,
               (left_padding * in_depth * sizeof(T)));
      }
      memcpy(conv_buffer_data + out_offset, in_data + in_offset,
             single_row_num * sizeof(T));
      if (right_padding > 0) {
        const int right_start = (out_offset + single_row_num);
        memset(conv_buffer_data + right_start, zero_byte,
               (right_padding * in_depth * sizeof(T)));
      }
      out_offset += kwidth_times_indepth;
      in_offset += inwidth_times_indepth;
    }
  }

  // If the bottom of the patch falls off the input image, pad the values
  // representing those input rows with zeroes.
  if (bottom_padding > 0) {
    const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
    const int bottom_start =
        output_row_offset +
        ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
    memset(conv_buffer_data + bottom_start, zero_byte,
           (bottom_row_elements * sizeof(T)));
  }
}

template <typename T>
void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
                   const RuntimeShape& input_shape, const T* input_data,
                   const RuntimeShape& filter_shape,
                   const RuntimeShape& output_shape, T* im2col_data) {
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  // For dilated convolution, the input pixels are not contiguous therefore we
  // can't use the same optimizations as Im2Col(). Though note this code would
  // work fine for the non-dilated case too (though likely a bit slower).
  gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
  TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
  TFLITE_DCHECK(im2col_data);
  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_height = input_shape.Dims(1);
  const int input_width = input_shape.Dims(2);
  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int output_height = output_shape.Dims(1);
  const int output_width = output_shape.Dims(2);
  MatchingDim(output_shape, 3, filter_shape, 0);

  // Construct the MxN sized im2col matrix.
  // The rows M, are sub-ordered B x H x W
  const RuntimeShape row_shape({1, batches, output_height, output_width});
  // The columns, N, are sub-ordered Kh x Kw x Din
  const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
  // Use dimensions M and N to construct dims for indexing directly into im2col
  const RuntimeShape im2col_shape(
      {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});

  // Loop through the output rows (B x H x W)
  for (int batch = 0; batch < batches; ++batch) {
    for (int out_y = 0; out_y < output_height; ++out_y) {
      for (int out_x = 0; out_x < output_width; ++out_x) {
        // Each im2col row is an output pixel. Arrange the input data in this
        // row in an order we can conveniently multiply with the filter data.
        int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
        const int in_x_origin = (out_x * stride_width) - pad_width;
        const int in_y_origin = (out_y * stride_height) - pad_height;
        // Loop through all the pixels of the filter (Kh x Kw)
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + dilation_height_factor * filter_y;
          if ((in_y >= 0) && (in_y < input_height)) {
            // Filter row is within the input data.
            // Loop through all the filter pixels in this row.
            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
              const int in_x = in_x_origin + dilation_width_factor * filter_x;
              int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
              T* dst = im2col_data +
                       Offset(im2col_shape, 0, 0, row_offset, col_offset);
              if ((in_x >= 0) && (in_x < input_width)) {
                // Filter pixel is within the input, copy the input data.
                T const* src =
                    input_data + Offset(input_shape, batch, in_y, in_x, 0);
                memcpy(dst, src, input_depth * sizeof(T));
              } else {
                // Filter pixel is outside the input, zero it out.
                memset(dst, zero_byte, input_depth * sizeof(T));
              }
            }
          } else {
            // Filter row is outside the input, zero out the entire filter row.
            int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
            T* dst = im2col_data +
                     Offset(im2col_shape, 0, 0, row_offset, col_offset);
            memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
          }
        }
      }
    }
  }
}

template <typename T>
void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
            const RuntimeShape& input_shape, const T* input_data,
            const RuntimeShape& output_shape, T* output_data) {
  gemmlowp::ScopedProfilingLabel label("Im2col");
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_depth = input_shape.Dims(3);
  const int input_width = input_shape.Dims(2);
  const int input_height = input_shape.Dims(1);
  const int output_depth = output_shape.Dims(3);
  const int output_width = output_shape.Dims(2);
  const int output_height = output_shape.Dims(1);

  int buffer_id = 0;
  // Loop over the output nodes.
  for (int b = 0; b < batches; ++b) {
    for (int h = 0; h < output_height; ++h) {
      for (int w = 0; w < output_width; ++w) {
        ExtractPatchIntoBufferColumn(
            input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
            pad_width, pad_height, input_width, input_height, input_depth,
            output_depth, buffer_id, input_data, output_data, zero_byte);
        ++buffer_id;
      }
    }
  }
}

}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
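For illustration, a minimal usage sketch of the Im2col template above on a 3x3 single-channel input with a 2x2 filter, stride 1 and no padding (the function and literal values are hypothetical, assuming the header and gemmlowp are on the include path):

#include <vector>
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"

void Im2colToyExample() {
  // 1x3x3x1 float input, 2x2 filter, stride 1, no padding -> 2x2 output,
  // so the im2col buffer has shape {1, 2, 2, 2*2*1}.
  const tflite::RuntimeShape input_shape({1, 3, 3, 1});
  const tflite::RuntimeShape im2col_shape({1, 2, 2, 4});
  std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  std::vector<float> im2col(im2col_shape.FlatSize());
  tflite::ConvParams params = {};
  params.stride_width = 1;
  params.stride_height = 1;
  params.padding_values.width = 0;
  params.padding_values.height = 0;
  // Each output position becomes one flattened patch:
  //   {1,2,4,5}, {2,3,5,6}, {4,5,7,8}, {5,6,8,9}.
  tflite::optimized_ops::Im2col(params, /*kheight=*/2, /*kwidth=*/2,
                                /*zero_byte=*/0, input_shape, input.data(),
                                im2col_shape, im2col.data());
}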
new file: tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h (159 lines)

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_

#ifdef GEMMLOWP_NEON

#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "public/map.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_integer_ops {

struct GemmlowpOutputPipelineFixedPointPCLhs {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
                         gemmlowp::VectorShape::Col>,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToInt8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          const int32 output_offset,
                          const int32* output_multiplier,
                          const int* output_left_shift,
                          int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;

    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
        gemmlowp::VectorShape::Col>
        quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier =
        ColVectorMap(output_multiplier, output_rows);
    quantize_down_stage.result_exponent =
        ColVectorMap(output_left_shift, output_rows);

    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToInt8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};

// Fixed-point per-channel-quantization convolution kernel.
inline void ConvPerChannel(
    const ConvParams& params, const int32* output_multiplier,
    const int32* output_shift, const RuntimeShape& input_shape,
    const int8* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
    const RuntimeShape& im2col_shape, int8* im2col_data,
    gemmlowp::GemmContext* gemm_context) {
  gemmlowp::ScopedProfilingLabel label("Conv/8bit");
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  const int32 input_offset = params.input_offset;
  const int32 output_offset = params.output_offset;
  // Set min and max value of the output.
  static constexpr int32 output_activation_min =
      std::numeric_limits<int8_t>::min();
  static constexpr int32 output_activation_max =
      std::numeric_limits<int8_t>::max();
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int8* gemm_input_data = nullptr;
  const RuntimeShape* gemm_input_shape = nullptr;
  const int filter_width = filter_shape.Dims(2);
  const int filter_height = filter_shape.Dims(1);
  const bool need_dilated_im2col =
      dilation_width_factor != 1 || dilation_height_factor != 1;
  const bool need_im2col = stride_width != 1 || stride_height != 1 ||
                           filter_width != 1 || filter_height != 1;
  const int8 input_zero_point = -input_offset;
  TFLITE_DCHECK_GE(input_zero_point, output_activation_min);
  TFLITE_DCHECK_LE(input_zero_point, output_activation_max);
  const uint8 zero_point_byte =
      *reinterpret_cast<const uint8*>(&input_zero_point);
  if (need_dilated_im2col) {
    TFLITE_DCHECK(im2col_data);
    optimized_ops::DilatedIm2col(params, zero_point_byte, input_shape,
                                 input_data, filter_shape, output_shape,
                                 im2col_data);
    gemm_input_data = im2col_data;
    gemm_input_shape = &im2col_shape;
  } else if (need_im2col) {
    TFLITE_DCHECK(im2col_data);
    optimized_ops::Im2col(params, filter_height, filter_width, zero_point_byte,
                          input_shape, input_data, im2col_shape, im2col_data);
    gemm_input_data = im2col_data;
    gemm_input_shape = &im2col_shape;
  } else {
    TFLITE_DCHECK(!im2col_data);
    gemm_input_data = input_data;
    gemm_input_shape = &input_shape;
  }

  const int gemm_input_rows = gemm_input_shape->Dims(3);
  const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
  const int filter_rows = filter_shape.Dims(0);
  const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
  const int output_rows = output_shape.Dims(3);
  // See b/79927784.
  // const int output_cols = FlatSizeSkipDim(output_shape, 3);
  const int output_cols =
      output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
  TFLITE_DCHECK_EQ(output_rows, filter_rows);
  TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
  TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
  gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::RowMajor> filter_matrix(
      filter_data, filter_rows, filter_cols);
  gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::ColMajor> input_matrix(
      gemm_input_data, gemm_input_rows, gemm_input_cols);
  gemmlowp::MatrixMap<int8, gemmlowp::MapOrder::ColMajor> output_matrix(
      output_data, output_rows, output_cols);

  const auto& output_pipeline = GemmlowpOutputPipelineFixedPointPCLhs::MakeExp(
      bias_data, output_rows, output_offset, output_multiplier, output_shift,
      output_activation_min, output_activation_max);

  gemmlowp::GemmWithOutputPipeline<
      int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
      gemm_context, filter_matrix, input_matrix, &output_matrix,
      /*filter_offset*/ 0, input_offset, output_pipeline);
}

}  // namespace optimized_integer_ops
}  // namespace tflite

#endif  // GEMMLOWP_NEON

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_
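The shape bookkeeping above reduces the whole convolution to a single gemmlowp GEMM, with the per-channel bias/scale/clamp/int8-cast pipeline applied row-wise. An illustrative helper (not part of the change) spelling out the mapping:

// GEMM dimensions used by ConvPerChannel above for an NHWC convolution:
//   filter_matrix (row-major) : m x k
//   im2col matrix (col-major) : k x n
//   output_matrix (col-major) : m x n
struct ConvGemmDims {
  int m;  // output channels
  int k;  // filter_height * filter_width * input_depth
  int n;  // batches * output_height * output_width
};

inline ConvGemmDims GetConvGemmDims(int batches, int output_height,
                                    int output_width, int output_depth,
                                    int filter_height, int filter_width,
                                    int input_depth) {
  return {output_depth, filter_height * filter_width * input_depth,
          batches * output_height * output_width};
}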
@@ -35,6 +35,7 @@ limitations under the License.
 #include "fixedpoint/fixedpoint.h"
 #include "public/gemmlowp.h"
 #include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/round.h"
@@ -1969,214 +1970,6 @@ inline void Mean(const tflite::MeanParams& op_params,
   }
 }
 
 [208 removed lines elided: the ExtractPatchIntoBufferColumn, DilatedIm2col,
  and Im2col templates formerly defined here, identical to the definitions now
  added in optimized/im2col_utils.h above.]
 
 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  const float* input_data, const RuntimeShape& filter_shape,
                  const float* filter_data, const RuntimeShape& bias_shape,
@@ -100,7 +100,7 @@ inline void ConvPerChannel(
                   // we have seen so far.
                   // TODO(jianlijianli): Add a check to make sure the
                   // accumulator depth is smaller than 2^16.
-                  acc += filter_val * (input_val - input_offset);
+                  acc += filter_val * (input_val + input_offset);
                 }
               }
             }
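The sign flip above pairs with the callers shown earlier, which now store the negated zero point (op_params.input_offset = -input->params.zero_point), the same convention the existing uint8 kernels use. The accumulated term is unchanged:

  filter_val * (input_val + input_offset)
      == filter_val * (input_val - input_zero_point)  // since input_offset == -input_zero_point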
@@ -216,11 +216,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
 
     tf_http_archive(
         name = "gemmlowp",
-        sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
-        strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
+        sha256 = "4da5404de25eeda40e7ceb18cf4ac1ce935db91c61ca2b4b84ef9d03e0ad1d4c",
+        strip_prefix = "gemmlowp-1bf3b9c582c70bddb07b8004fc031d9765684f79",
         urls = [
-            "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
-            "https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
+            "https://mirror.bazel.build/github.com/google/gemmlowp/archive/1bf3b9c582c70bddb07b8004fc031d9765684f79.zip",
+            "https://github.com/google/gemmlowp/archive/1bf3b9c582c70bddb07b8004fc031d9765684f79.zip",
         ],
     )
 