Copy depthwise_conv uint8 3x3 filter implementation
PiperOrigin-RevId: 247210491
This commit is contained in:
parent
61c8837163
commit
cc0ad492d2
@ -180,6 +180,7 @@ cc_library(
|
||||
"optimized/integer_ops/add.h",
|
||||
"optimized/integer_ops/conv.h",
|
||||
"optimized/integer_ops/depthwise_conv.h",
|
||||
"optimized/integer_ops/depthwise_conv_3x3_filter.h",
|
||||
"optimized/integer_ops/fully_connected.h",
|
||||
"optimized/integer_ops/mul.h",
|
||||
"optimized/integer_ops/pooling.h",
|
||||
|
@ -319,12 +319,20 @@ template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
|
||||
int32 kStrideWidth, int32 kStrideHeight>
|
||||
struct DepthwiseConvWindow {};
|
||||
|
||||
template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
|
||||
int32 kStrideWidth, int32 kStrideHeight>
|
||||
struct DepthwiseConvWindowPerChannel {};
|
||||
|
||||
enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter };
|
||||
|
||||
template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
|
||||
int kPadWidth, int kPadHeight>
|
||||
struct DepthwiseConvPartial {};
|
||||
|
||||
template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
|
||||
int kPadWidth, int kPadHeight>
|
||||
struct DepthwiseConvPartialPerChannel {};
|
||||
|
||||
// Copies a subset of the input designated by |input_ptr| into |output_ptr|
|
||||
// with the specified output dimensions. Supports output depths of 64 only as
|
||||
// this is the cache line size.
|
||||
@ -367,12 +375,19 @@ struct ShuffleParams {
|
||||
input_height(get_shuffle_input_size(stride_height, output_height)) {}
|
||||
};
|
||||
|
||||
enum class QuantizationType {
|
||||
kNonPerChannelUint8 = 0,
|
||||
kPerChannelInt8 = 1,
|
||||
};
|
||||
|
||||
template <
|
||||
QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
|
||||
inline bool Fast3x3FilterKernelSupported(
|
||||
const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
|
||||
int32 stride_width, int32 stride_height, int32 dilation_width_factor,
|
||||
int32 dilation_height_factor, int32 pad_width, int32 pad_height,
|
||||
int32 depth_multiplier, const RuntimeShape& output_shape,
|
||||
int32 output_shift) {
|
||||
int32 output_shift, const int32* output_shift_ptr = nullptr) {
|
||||
const int32 input_height = input_shape.Dims(1);
|
||||
const int32 input_width = input_shape.Dims(2);
|
||||
const int32 input_depth = input_shape.Dims(3);
|
||||
@ -380,6 +395,7 @@ inline bool Fast3x3FilterKernelSupported(
|
||||
const int32 filter_width = filter_shape.Dims(2);
|
||||
const int32 output_height = output_shape.Dims(1);
|
||||
const int32 output_width = output_shape.Dims(2);
|
||||
const int32 output_depth = output_shape.Dims(3);
|
||||
|
||||
bool supported =
|
||||
filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
|
||||
@ -394,6 +410,14 @@ inline bool Fast3x3FilterKernelSupported(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (quantization_type == QuantizationType::kPerChannelInt8) {
|
||||
for (int i = 0; i < output_depth; ++i) {
|
||||
if (output_shift_ptr[i] <= 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle case where padding is zero but padding type is not kValid.
|
||||
// This would require special boundary case handling that is not supported.
|
||||
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user