Copy depthwise_conv uint8 3x3 filter implementation

PiperOrigin-RevId: 247210491
This commit is contained in:
Renjie Liu 2019-05-08 07:39:54 -07:00 committed by TensorFlower Gardener
parent 61c8837163
commit cc0ad492d2
4 changed files with 2985 additions and 1 deletions

View File

@ -180,6 +180,7 @@ cc_library(
"optimized/integer_ops/add.h", "optimized/integer_ops/add.h",
"optimized/integer_ops/conv.h", "optimized/integer_ops/conv.h",
"optimized/integer_ops/depthwise_conv.h", "optimized/integer_ops/depthwise_conv.h",
"optimized/integer_ops/depthwise_conv_3x3_filter.h",
"optimized/integer_ops/fully_connected.h", "optimized/integer_ops/fully_connected.h",
"optimized/integer_ops/mul.h", "optimized/integer_ops/mul.h",
"optimized/integer_ops/pooling.h", "optimized/integer_ops/pooling.h",

View File

@ -319,12 +319,20 @@ template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
int32 kStrideWidth, int32 kStrideHeight> int32 kStrideWidth, int32 kStrideHeight>
struct DepthwiseConvWindow {}; struct DepthwiseConvWindow {};
template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
int32 kStrideWidth, int32 kStrideHeight>
struct DepthwiseConvWindowPerChannel {};
enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter }; enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter };
template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType, template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
int kPadWidth, int kPadHeight> int kPadWidth, int kPadHeight>
struct DepthwiseConvPartial {}; struct DepthwiseConvPartial {};
template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
int kPadWidth, int kPadHeight>
struct DepthwiseConvPartialPerChannel {};
// Copies a subset of the input designated by |input_ptr| into |output_ptr| // Copies a subset of the input designated by |input_ptr| into |output_ptr|
// with the specified output dimensions. Supports output depths of 64 only as // with the specified output dimensions. Supports output depths of 64 only as
// this is the cache line size. // this is the cache line size.
@ -367,12 +375,19 @@ struct ShuffleParams {
input_height(get_shuffle_input_size(stride_height, output_height)) {} input_height(get_shuffle_input_size(stride_height, output_height)) {}
}; };
enum class QuantizationType {
kNonPerChannelUint8 = 0,
kPerChannelInt8 = 1,
};
template <
QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
inline bool Fast3x3FilterKernelSupported( inline bool Fast3x3FilterKernelSupported(
const RuntimeShape& input_shape, const RuntimeShape& filter_shape, const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
int32 stride_width, int32 stride_height, int32 dilation_width_factor, int32 stride_width, int32 stride_height, int32 dilation_width_factor,
int32 dilation_height_factor, int32 pad_width, int32 pad_height, int32 dilation_height_factor, int32 pad_width, int32 pad_height,
int32 depth_multiplier, const RuntimeShape& output_shape, int32 depth_multiplier, const RuntimeShape& output_shape,
int32 output_shift) { int32 output_shift, const int32* output_shift_ptr = nullptr) {
const int32 input_height = input_shape.Dims(1); const int32 input_height = input_shape.Dims(1);
const int32 input_width = input_shape.Dims(2); const int32 input_width = input_shape.Dims(2);
const int32 input_depth = input_shape.Dims(3); const int32 input_depth = input_shape.Dims(3);
@ -380,6 +395,7 @@ inline bool Fast3x3FilterKernelSupported(
const int32 filter_width = filter_shape.Dims(2); const int32 filter_width = filter_shape.Dims(2);
const int32 output_height = output_shape.Dims(1); const int32 output_height = output_shape.Dims(1);
const int32 output_width = output_shape.Dims(2); const int32 output_width = output_shape.Dims(2);
const int32 output_depth = output_shape.Dims(3);
bool supported = bool supported =
filter_width == 3 && filter_height == 3 && depth_multiplier == 1 && filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
@ -394,6 +410,14 @@ inline bool Fast3x3FilterKernelSupported(
return false; return false;
} }
if (quantization_type == QuantizationType::kPerChannelInt8) {
for (int i = 0; i < output_depth; ++i) {
if (output_shift_ptr[i] <= 0) {
return false;
}
}
}
// Handle case where padding is zero but padding type is not kValid. // Handle case where padding is zero but padding type is not kValid.
// This would require special boundary case handling that is not supported. // This would require special boundary case handling that is not supported.

View File

@ -19,6 +19,8 @@ limitations under the License.
#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/internal/types.h"