Copy depthwise_conv uint8 3x3 filter implementation
PiperOrigin-RevId: 247210491
This commit is contained in:
parent
61c8837163
commit
cc0ad492d2
@ -180,6 +180,7 @@ cc_library(
|
|||||||
"optimized/integer_ops/add.h",
|
"optimized/integer_ops/add.h",
|
||||||
"optimized/integer_ops/conv.h",
|
"optimized/integer_ops/conv.h",
|
||||||
"optimized/integer_ops/depthwise_conv.h",
|
"optimized/integer_ops/depthwise_conv.h",
|
||||||
|
"optimized/integer_ops/depthwise_conv_3x3_filter.h",
|
||||||
"optimized/integer_ops/fully_connected.h",
|
"optimized/integer_ops/fully_connected.h",
|
||||||
"optimized/integer_ops/mul.h",
|
"optimized/integer_ops/mul.h",
|
||||||
"optimized/integer_ops/pooling.h",
|
"optimized/integer_ops/pooling.h",
|
||||||
|
@ -319,12 +319,20 @@ template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
|
|||||||
int32 kStrideWidth, int32 kStrideHeight>
|
int32 kStrideWidth, int32 kStrideHeight>
|
||||||
struct DepthwiseConvWindow {};
|
struct DepthwiseConvWindow {};
|
||||||
|
|
||||||
|
template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
|
||||||
|
int32 kStrideWidth, int32 kStrideHeight>
|
||||||
|
struct DepthwiseConvWindowPerChannel {};
|
||||||
|
|
||||||
enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter };
|
enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter };
|
||||||
|
|
||||||
template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
|
template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
|
||||||
int kPadWidth, int kPadHeight>
|
int kPadWidth, int kPadHeight>
|
||||||
struct DepthwiseConvPartial {};
|
struct DepthwiseConvPartial {};
|
||||||
|
|
||||||
|
template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
|
||||||
|
int kPadWidth, int kPadHeight>
|
||||||
|
struct DepthwiseConvPartialPerChannel {};
|
||||||
|
|
||||||
// Copies a subset of the input designated by |input_ptr| into |output_ptr|
|
// Copies a subset of the input designated by |input_ptr| into |output_ptr|
|
||||||
// with the specified output dimensions. Supports output depths of 64 only as
|
// with the specified output dimensions. Supports output depths of 64 only as
|
||||||
// this is the cache line size.
|
// this is the cache line size.
|
||||||
@ -367,12 +375,19 @@ struct ShuffleParams {
|
|||||||
input_height(get_shuffle_input_size(stride_height, output_height)) {}
|
input_height(get_shuffle_input_size(stride_height, output_height)) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum class QuantizationType {
|
||||||
|
kNonPerChannelUint8 = 0,
|
||||||
|
kPerChannelInt8 = 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
|
||||||
inline bool Fast3x3FilterKernelSupported(
|
inline bool Fast3x3FilterKernelSupported(
|
||||||
const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
|
const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
|
||||||
int32 stride_width, int32 stride_height, int32 dilation_width_factor,
|
int32 stride_width, int32 stride_height, int32 dilation_width_factor,
|
||||||
int32 dilation_height_factor, int32 pad_width, int32 pad_height,
|
int32 dilation_height_factor, int32 pad_width, int32 pad_height,
|
||||||
int32 depth_multiplier, const RuntimeShape& output_shape,
|
int32 depth_multiplier, const RuntimeShape& output_shape,
|
||||||
int32 output_shift) {
|
int32 output_shift, const int32* output_shift_ptr = nullptr) {
|
||||||
const int32 input_height = input_shape.Dims(1);
|
const int32 input_height = input_shape.Dims(1);
|
||||||
const int32 input_width = input_shape.Dims(2);
|
const int32 input_width = input_shape.Dims(2);
|
||||||
const int32 input_depth = input_shape.Dims(3);
|
const int32 input_depth = input_shape.Dims(3);
|
||||||
@ -380,6 +395,7 @@ inline bool Fast3x3FilterKernelSupported(
|
|||||||
const int32 filter_width = filter_shape.Dims(2);
|
const int32 filter_width = filter_shape.Dims(2);
|
||||||
const int32 output_height = output_shape.Dims(1);
|
const int32 output_height = output_shape.Dims(1);
|
||||||
const int32 output_width = output_shape.Dims(2);
|
const int32 output_width = output_shape.Dims(2);
|
||||||
|
const int32 output_depth = output_shape.Dims(3);
|
||||||
|
|
||||||
bool supported =
|
bool supported =
|
||||||
filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
|
filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
|
||||||
@ -394,6 +410,14 @@ inline bool Fast3x3FilterKernelSupported(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (quantization_type == QuantizationType::kPerChannelInt8) {
|
||||||
|
for (int i = 0; i < output_depth; ++i) {
|
||||||
|
if (output_shift_ptr[i] <= 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Handle case where padding is zero but padding type is not kValid.
|
// Handle case where padding is zero but padding type is not kValid.
|
||||||
// This would require special boundary case handling that is not supported.
|
// This would require special boundary case handling that is not supported.
|
||||||
|
|
||||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
|
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
|
||||||
#include "tensorflow/lite/kernels/internal/common.h"
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
|
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user