Fix depthwise_conv padding case

PiperOrigin-RevId: 249798490
This commit is contained in:
Renjie Liu 2019-05-24 02:17:12 -07:00 committed by TensorFlower Gardener
parent 353e06c3a3
commit 5d51c3a37c

View File

@ -1936,8 +1936,10 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
"saddw v0.8h, v25.8h, v0.8b\n"
// Loads output_multiplier & output_shift.
"ld1 {v6.4s, v7.4s}, [%[output_multiplier_ptr]] \n"
"ld1 {v10.4s, v11.4s}, [%[output_shift_ptr]] \n"
"ld1 {v6.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v10.4s}, [%[output_shift_ptr]], #16\n"
"ld1 {v7.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v11.4s}, [%[output_shift_ptr]], #16\n"
"blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
@ -1965,6 +1967,10 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
"ld1 {v16.4s}, [%[bias_ptr]], #16\n"
"saddw v0.8h, v25.8h, v0.8b\n"
"ld1 {v17.4s}, [%[bias_ptr]], #16\n"
"ld1 {v6.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v10.4s}, [%[output_shift_ptr]], #16\n"
"ld1 {v7.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v11.4s}, [%[output_shift_ptr]], #16\n"
"bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
@ -1987,11 +1993,11 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
:
// Outputs.
[filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
[output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
[output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr),
[output_multiplier_ptr] "+r"(output_multiplier_ptr),
[output_shift_ptr] "+r"(output_shift_ptr)
:
// Inputs.
[output_multiplier_ptr] "r"(output_multiplier_ptr),
[output_shift_ptr] "r"(output_shift_ptr),
[params_ptr] "r"(params_ptr)
:
// Clobbers.
@ -2058,8 +2064,10 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
"dup v25.8h, w6\n"
// Loads output_multiplier & output_shift.
"ld1 {v4.4s, v5.4s}, [%[output_multiplier_ptr]] \n"
"ld1 {v6.4s, v7.4s}, [%[output_shift_ptr]] \n"
"ld1 {v4.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v6.4s}, [%[output_shift_ptr]], #16\n"
"ld1 {v5.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v7.4s}, [%[output_shift_ptr]], #16\n"
// Add input and filter offsets.
"saddw v8.8h, v26.8h, v8.8b\n"
@ -2118,6 +2126,10 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
"saddw v1.8h, v25.8h, v1.8b\n"
"saddw v2.8h, v25.8h, v2.8b\n"
"saddw v3.8h, v25.8h, v3.8b\n"
"ld1 {v4.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v6.4s}, [%[output_shift_ptr]], #16\n"
"ld1 {v5.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v7.4s}, [%[output_shift_ptr]], #16\n"
"bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
@ -2146,11 +2158,11 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
:
// Outputs.
[filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
[output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
[output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr),
[output_multiplier_ptr] "+r"(output_multiplier_ptr),
[output_shift_ptr] "+r"(output_shift_ptr)
:
// Inputs.
[output_multiplier_ptr] "r"(output_multiplier_ptr),
[output_shift_ptr] "r"(output_shift_ptr),
[params_ptr] "r"(params_ptr)
:
// Clobbers.
@ -2223,8 +2235,10 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
"dup v25.8h, w12\n"
// Loads output_multiplier & output_shift.
"ld1 {v6.4s, v7.4s}, [%[output_multiplier_ptr]] \n"
"ld1 {v14.4s, v15.4s}, [%[output_shift_ptr]] \n"
"ld1 {v6.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v14.4s}, [%[output_shift_ptr]], #16\n"
"ld1 {v7.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v15.4s}, [%[output_shift_ptr]], #16\n"
// Add input and filter offsets.
"saddw v8.8h, v26.8h, v8.8b\n"
@ -2307,6 +2321,10 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
"ld1 {v17.4s}, [%[bias_ptr]], #16\n"
"saddw v4.8h, v25.8h, v4.8b\n"
"saddw v5.8h, v25.8h, v5.8b\n"
"ld1 {v6.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v14.4s}, [%[output_shift_ptr]], #16\n"
"ld1 {v7.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v15.4s}, [%[output_shift_ptr]], #16\n"
"bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
@ -2338,11 +2356,11 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
:
// Outputs.
[filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
[output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
[output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr),
[output_multiplier_ptr] "+r"(output_multiplier_ptr),
[output_shift_ptr] "+r"(output_shift_ptr)
:
// Inputs.
[output_multiplier_ptr] "r"(output_multiplier_ptr),
[output_shift_ptr] "r"(output_shift_ptr),
[params_ptr] "r"(params_ptr)
:
// Clobbers.
@ -2417,8 +2435,10 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
"dup v25.8h, w12\n"
// Loads output_multiplier & output_shift.
"ld1 {v6.4s, v7.4s}, [%[output_multiplier_ptr]] \n"
"ld1 {v14.4s, v15.4s}, [%[output_shift_ptr]] \n"
"ld1 {v6.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v14.4s}, [%[output_shift_ptr]], #16\n"
"ld1 {v7.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v15.4s}, [%[output_shift_ptr]], #16\n"
// Add input and filter offsets.
"saddw v8.8h, v26.8h, v8.8b\n"
@ -2503,6 +2523,10 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
"ld1 {v17.4s}, [%[bias_ptr]], #16\n"
"saddw v4.8h, v25.8h, v4.8b\n"
"saddw v5.8h, v25.8h, v5.8b\n"
"ld1 {v6.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v14.4s}, [%[output_shift_ptr]], #16\n"
"ld1 {v7.4s}, [%[output_multiplier_ptr]], #16\n"
"ld1 {v15.4s}, [%[output_shift_ptr]], #16\n"
"bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
@ -2535,11 +2559,11 @@ struct DepthwiseConvPartialPerChannel<DepthwiseConvOutputRounding::kUpward,
:
// Outputs.
[filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
[output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
[output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr),
[output_multiplier_ptr] "+r"(output_multiplier_ptr),
[output_shift_ptr] "+r"(output_shift_ptr)
:
// Inputs.
[output_multiplier_ptr] "r"(output_multiplier_ptr),
[output_shift_ptr] "r"(output_shift_ptr),
[params_ptr] "r"(params_ptr)
:
// Clobbers.
@ -2599,6 +2623,8 @@ struct DepthwiseConvThroughDepthPerChannel {
output_ptr += 8;
filter_ptr += 8;
bias_ptr += 8;
output_multiplier_ptr += 8;
output_shift_ptr += 8;
}
}
};
@ -2610,8 +2636,8 @@ struct DepthwiseConvMultiRowPerChannel {
DepthwiseConvThroughDepthPerChannel<output_rounding, kStrideWidth,
kStrideHeight>;
static inline void Run(const int32* output_multiplier_ptr,
const int32* output_shift_ptr, const int8* input_data,
static inline void Run(const int32* output_multiplier,
const int32* output_shift, const int8* input_data,
int32 start_x, int32 end_x, const int8* filter_data,
const int32* bias_data, int8* output_data,
const DepthwiseConvParams& params,
@ -2639,6 +2665,8 @@ struct DepthwiseConvMultiRowPerChannel {
out_x += shuffle_params.output_width) {
const int8* input_ptr = input_data;
const int32* bias_ptr = bias_data;
const int32* output_multiplier_ptr = output_multiplier;
const int32* output_shift_ptr = output_shift;
const int8* filter_ptr = filter_data;
int8* output_ptr = output_data;
int64_t depth = 0;
@ -2669,6 +2697,8 @@ struct DepthwiseConvMultiRowPerChannel {
output_ptr += 64;
filter_ptr += 64;
bias_ptr += 64;
output_multiplier_ptr += 64;
output_shift_ptr += 64;
}
// Preload.
@ -2697,11 +2727,11 @@ struct DepthwiseConvMultiRowPerChannel {
const int32 output_leftover_width = end_x - out_x;
if (output_leftover_width > 0) {
ConvKernel::Run(output_multiplier_ptr, output_shift_ptr, input_data,
filter_data, bias_data, output_data, 0,
params.output_depth, params.input_depth,
params.input_row_size, shuffle_params.output_height,
output_leftover_width, params);
ConvKernel::Run(output_multiplier, output_shift, input_data, filter_data,
bias_data, output_data, 0, params.output_depth,
params.input_depth, params.input_row_size,
shuffle_params.output_height, output_leftover_width,
params);
}
}
};