From 152095e319a0b6b79a13e689ea8decee819dbfcd Mon Sep 17 00:00:00 2001 From: Lu Wang Date: Sat, 23 Mar 2019 10:31:22 -0700 Subject: [PATCH] Add gemmlowp-threadpool multithreading to the depthwiseconv implementation for the quantized path. PiperOrigin-RevId: 239959051 --- tensorflow/lite/kernels/depthwise_conv.cc | 33 ++-- .../lite/kernels/depthwise_conv_test.cc | 167 ++++++++++++++++++ .../internal/depthwiseconv_quantized_test.cc | 12 +- .../internal/optimized/depthwiseconv_uint8.h | 162 +++++++++++++++-- .../depthwiseconv_uint8_3x3_filter.h | 36 +++- .../internal/optimized/legacy_optimized_ops.h | 11 +- tensorflow/lite/kernels/test_util.h | 1 + 7 files changed, 380 insertions(+), 42 deletions(-) diff --git a/tensorflow/lite/kernels/depthwise_conv.cc b/tensorflow/lite/kernels/depthwise_conv.cc index a349b279053..e29969f6488 100644 --- a/tensorflow/lite/kernels/depthwise_conv.cc +++ b/tensorflow/lite/kernels/depthwise_conv.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/gemm_support.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" @@ -66,6 +67,7 @@ struct OpData { }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { + gemm_support::IncrementUsageCounter(context); // This is a builtin op, so we don't use the contents in 'buffer', if any. // Instead, we allocate a new object to carry information from Prepare() to // Eval(). @@ -73,6 +75,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); delete reinterpret_cast(buffer); } @@ -230,17 +233,6 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, auto filter_offset = -filter->params.zero_point; auto output_offset = output->params.zero_point; - void (*depthwise_conv)(const DepthwiseParams&, const RuntimeShape&, - const uint8*, const RuntimeShape&, const uint8*, - const RuntimeShape&, const int32*, const RuntimeShape&, - uint8*); - - if (kernel_type == kReference) { - depthwise_conv = &reference_ops::DepthwiseConv; - } else { - depthwise_conv = &optimized_ops::DepthwiseConv; - } - DepthwiseParams op_params; op_params.padding_type = PaddingType::kSame; op_params.padding_values.width = data->padding.width; @@ -257,11 +249,20 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, op_params.output_shift = -data->output_shift; op_params.quantized_activation_min = data->output_activation_min; op_params.quantized_activation_max = data->output_activation_max; - depthwise_conv(op_params, GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), - GetTensorData(output)); + if (kernel_type == kReference) { + reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); + } else { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + optimized_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), gemm_context); + } } void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, diff --git a/tensorflow/lite/kernels/depthwise_conv_test.cc b/tensorflow/lite/kernels/depthwise_conv_test.cc index 2413a95b5d6..8394b3bb573 100644 --- a/tensorflow/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/kernels/depthwise_conv_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include "absl/memory/memory.h" #include "tensorflow/lite/interpreter.h" @@ -501,6 +502,172 @@ TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) { ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1})); } +TEST_P(DepthwiseConvolutionOpTest, MultithreadOnRowUint8GeneralTest) { + const int depth = 1; + const int image_width = 4; + const int image_height = 28; + const int image_batch_count = 3; + const int filter_size = 3; + const int filter_count = 1; + + QuantizedDepthwiseConvolutionOpModel m( + GetRegistration(), + {TensorType_UINT8, + {image_batch_count, image_height, image_width, depth}, + 0, + 255}, + {TensorType_UINT8, + {depth, filter_size, filter_size, filter_count}, + 0, + 255}, + {TensorType_UINT8, {}, 0, 255}, Padding_VALID); + + // clang-format off + m.SetInput({ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }); + // clang-format on + + // The filter matrix is: + // | 1 | 2 | 3 | + // | 4 | 5 | 6 | + // | 7 | 8 | 9 | + m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); + // No bias for this test. + m.SetBias({0}); + m.SetNumThreads(4); + m.Invoke(); + + // clang-format off + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({ + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 24, 24, 39, 39, + 45, 45, 45, 45, 45, 45, 45, 45, + 45, 45, 45, 45, 45, 45, 45, 45, + 45, 45, 45, 45, 21, 21, 6, 6, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 48, 48, 78, 78, + 90, 90, 90, 90, 90, 90, 90, 90, + 90, 90, 90, 90, 90, 90, 90, 90, + 90, 90, 90, 90, 42, 42, 12, 12, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 72, 72, 117, 117, + 135, 135, 135, 135, 135, 135, 135, 135, + 135, 135, 135, 135, 135, 135, 135, 135, + 135, 135, 135, 135, 63, 63, 18, 18, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + })); + // clang-format on +} + +TEST_P(DepthwiseConvolutionOpTest, MultithreadOnBatchUint8GeneralTest) { + const int depth = 1; + const int image_width = 8; + const int image_height = 4; + const int image_batch_count = 6; + const int filter_size = 3; + const int filter_count = 1; + + QuantizedDepthwiseConvolutionOpModel m( + GetRegistration(), + {TensorType_UINT8, + {image_batch_count, image_height, image_width, depth}, + 0, + 255}, + {TensorType_UINT8, + {depth, filter_size, filter_size, filter_count}, + 0, + 255}, + {TensorType_UINT8, {}, 0, 255}, Padding_VALID); + + // clang-format off + m.SetInput({ + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 + }); + // clang-format on + + // The filter matrix is: + // | 1 | 2 | 3 | + // | 4 | 5 | 6 | + // | 7 | 8 | 9 | + m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); + // No bias for this test. + m.SetBias({0}); + m.SetNumThreads(4); + m.Invoke(); + + // clang-format off + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({ + 39, 39, 39, 39, 39, 39, + 21, 21, 21, 21, 21, 21, + + 39, 39, 39, 39, 39, 39, + 21, 21, 21, 21, 21, 21, + + 39, 39, 39, 39, 39, 39, + 21, 21, 21, 21, 21, 21, + + 39, 39, 39, 39, 39, 39, + 21, 21, 21, 21, 21, 21, + + 39, 39, 39, 39, 39, 39, + 21, 21, 21, 21, 21, 21, + + 39, 39, 39, 39, 39, 39, + 21, 21, 21, 21, 21, 21 + })); + // clang-format on +} + class PerChannelQuantizedDepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel { public: diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index 3e48d95a082..6cb1eba9822 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -139,7 +139,8 @@ inline void DispatchDepthwiseConv( // Call kernel optimized for depthwise convolutions using 3x3 filters. optimized_ops::depthwise_conv::DepthwiseConv3x3Filter( params, input_shape, input_data, filter_shape, filter_data, - bias_shape, bias_data, output_shape, output_data); + bias_shape, bias_data, output_shape, output_data, /*thread_start=*/0, + /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1); return; #else break; @@ -242,7 +243,8 @@ inline void DispatchDepthwiseConv( case DepthwiseConvImplementation::kUseGenericKernel: { optimized_ops::depthwise_conv::DepthwiseConvGeneral( params, input_shape, input_data, filter_shape, filter_data, - bias_shape, bias_data, output_shape, output_data); + bias_shape, bias_data, output_shape, output_data, /*thread_start=*/0, + /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1); return; } case DepthwiseConvImplementation::kNone: @@ -271,13 +273,15 @@ inline void DispatchDepthwiseConv( optimized_ops::DepthwiseConvWithRounding< DepthwiseConvOutputRounding::kAwayFromZero>( params, input_shape, input_data, filter_shape, filter_data, - bias_shape, bias_data, output_shape, output_data); + bias_shape, bias_data, output_shape, output_data, /*thread_start=*/0, + /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1); return; case DepthwiseConvOutputRounding::kUpward: optimized_ops::DepthwiseConvWithRounding< DepthwiseConvOutputRounding::kUpward>( params, input_shape, input_data, filter_shape, filter_data, - bias_shape, bias_data, output_shape, output_data); + bias_shape, bias_data, output_shape, output_data, /*thread_start=*/0, + /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1); return; default: break; diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h index d1a9d65aae8..44e0a26f5ae 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -1662,7 +1662,7 @@ inline void DepthwiseConvGeneral( const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data) { + uint8* output_data, int thread_start, int thread_end, int thread_dim) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int pad_width = params.padding_values.width; @@ -1684,7 +1684,7 @@ inline void DepthwiseConvGeneral( const int input_depth = input_shape.Dims(3); const int filter_height = filter_shape.Dims(1); const int filter_width = filter_shape.Dims(2); - const int output_height = output_shape.Dims(1); + const int output_rows = output_shape.Dims(1); const int output_width = output_shape.Dims(2); #ifdef USE_NEON const bool shift_left = (output_shift > 0); @@ -1700,6 +1700,7 @@ inline void DepthwiseConvGeneral( kAccBufferActualSize); TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1); + TFLITE_DCHECK(thread_dim == 0 || thread_dim == 1); // row_accum_func will point to the core accumulation function to be used // for this DepthwiseConv op. @@ -1766,9 +1767,34 @@ inline void DepthwiseConvGeneral( const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2); // Now that we have determined row_accum_func, we can start work. - uint8* output_ptr = output_data; - for (int b = 0; b < batches; ++b) { - for (int out_y = 0; out_y < output_height; ++out_y) { + int batch_start = 0; + int batch_end = batches; + int row_start = 0; + int row_end = output_rows; + int output_ptr_offset = 0; + + switch (thread_dim) { + case 0: + TFLITE_DCHECK_GE(thread_start, 0); + TFLITE_DCHECK_LE(thread_end, batches); + batch_start = thread_start; + batch_end = thread_end; + output_ptr_offset = batch_start * FlatSizeSkipDim(output_shape, 0); + break; + case 1: + TFLITE_DCHECK_GE(thread_start, 0); + TFLITE_DCHECK_LE(thread_end, output_rows); + row_start = thread_start; + row_end = thread_end; + output_ptr_offset = row_start * output_width * output_depth; + break; + } + + uint8* output_ptr = output_data + output_ptr_offset; + int batch_step = + (output_rows + row_start - row_end) * output_width * output_depth; + for (int b = batch_start; b < batch_end; ++b) { + for (int out_y = row_start; out_y < row_end; ++out_y) { const int in_y_origin = (out_y * stride_height) - pad_height; const int filter_y_start = std::max(0, (-in_y_origin + dilation_height_factor - 1) / @@ -1944,6 +1970,7 @@ inline void DepthwiseConvGeneral( } } } + output_ptr += batch_step; } } @@ -1955,7 +1982,7 @@ inline void DepthwiseConvWithRounding( const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data) { + uint8* output_data, int thread_start, int thread_end, int thread_dim) { gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit"); const int depth_multiplier = params.depth_multiplier; const int32 output_activation_min = params.quantized_activation_min; @@ -1991,7 +2018,8 @@ inline void DepthwiseConvWithRounding( gemmlowp::ScopedProfilingLabel specialized_label("DepthwiseConv/8bit/3x3"); depthwise_conv::DepthwiseConv3x3Filter( params, input_shape, input_data, filter_shape, filter_data, bias_shape, - bias_data, output_shape, output_data); + bias_data, output_shape, output_data, thread_start, thread_end, + thread_dim); return; } #endif @@ -2000,7 +2028,77 @@ inline void DepthwiseConvWithRounding( "DepthwiseConv/8bit/General"); depthwise_conv::DepthwiseConvGeneral(params, input_shape, input_data, filter_shape, filter_data, bias_shape, - bias_data, output_shape, output_data); + bias_data, output_shape, output_data, + thread_start, thread_end, thread_dim); +} + +inline void DepthwiseConvImpl( + const DepthwiseParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data, int thread_start, int thread_end, int thread_dim) { + return DepthwiseConvWithRounding( + params, input_shape, input_data, filter_shape, filter_data, bias_shape, + bias_data, output_shape, output_data, thread_start, thread_end, + thread_dim); +} + +template +struct DepthwiseConvWorkerTask : public gemmlowp::Task { + DepthwiseConvWorkerTask(const DepthwiseParams& params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& filter_shape, + const T* filter_data, const RuntimeShape& bias_shape, + const TS* bias_data, const RuntimeShape& output_shape, + T* output_data, int thread_start, int thread_end, + int thread_dim) + : params_(params), + input_shape_(input_shape), + input_data_(input_data), + filter_shape_(filter_shape), + filter_data_(filter_data), + bias_shape_(bias_shape), + bias_data_(bias_data), + output_shape_(output_shape), + output_data_(output_data), + thread_start_(thread_start), + thread_end_(thread_end), + thread_dim_(thread_dim) {} + + void Run() override { + DepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_, + filter_data_, bias_shape_, bias_data_, output_shape_, + output_data_, thread_start_, thread_end_, thread_dim_); + } + + private: + const DepthwiseParams& params_; + const RuntimeShape& input_shape_; + const T* input_data_; + const RuntimeShape& filter_shape_; + const T* filter_data_; + const RuntimeShape& bias_shape_; + const TS* bias_data_; + const RuntimeShape& output_shape_; + T* output_data_; + int thread_start_; + int thread_end_; + int thread_dim_; +}; + +inline int HowManyConvThreads(const RuntimeShape& output_shape, + const RuntimeShape& filter_shape, + int thread_dim) { + constexpr int kMinMulPerThread = 8; + const int output_units = output_shape.Dims(thread_dim); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int num_mul_per_unit = + FlatSizeSkipDim(output_shape, thread_dim) * filter_height * filter_width; + const int min_units_per_thread = kMinMulPerThread / num_mul_per_unit + 1; + int thread_count = output_units / min_units_per_thread; + return thread_count; } inline void DepthwiseConv( @@ -2008,10 +2106,50 @@ inline void DepthwiseConv( const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data) { - return DepthwiseConvWithRounding( - params, input_shape, input_data, filter_shape, filter_data, bias_shape, - bias_data, output_shape, output_data); + uint8* output_data, gemmlowp::GemmContext* gemm_context = nullptr) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv"); + + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const int output_batches = output_shape.Dims(0); + const int output_rows = output_shape.Dims(1); + int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0); + int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1); + int thread_dim, thread_count, thread_dim_size; + if (thread_count_batch > thread_count_row) { + thread_dim = 0; + thread_dim_size = output_batches; + thread_count = thread_count_batch; + } else { + thread_dim = 1; + thread_dim_size = output_rows; + thread_count = thread_count_row; + } + + const int max_threads = gemm_context ? gemm_context->max_num_threads() : 1; + thread_count = std::max(1, std::min(thread_count, max_threads)); + + if (thread_count == 1) { + DepthwiseConvImpl(params, input_shape, input_data, filter_shape, + filter_data, bias_shape, bias_data, output_shape, + output_data, /*thread_start=*/0, + /*thread_end=*/output_rows, /*thread_dim=*/1); + } else { + std::vector tasks(thread_count); + int thread_start = 0; + for (int i = 0; i < thread_count; ++i) { + int thread_end = + thread_start + (thread_dim_size - thread_start) / (thread_count - i); + tasks[i] = new DepthwiseConvWorkerTask( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data, thread_start, + thread_end, thread_dim); + thread_start = thread_end; + } + gemm_context->workers_pool()->Execute(tasks); + } } } // namespace optimized_ops diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 6bc4fb60325..3ed8b221882 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -3535,7 +3535,7 @@ inline void DepthwiseConv3x3Filter( const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data) { + uint8* output_data, int thread_start, int thread_end, int thread_dim) { gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); DepthwiseConvParams params; @@ -3586,6 +3586,7 @@ inline void DepthwiseConv3x3Filter( TFLITE_DCHECK(pad_height == 0 || pad_height == 1); TFLITE_DCHECK(pad_width == 0 || pad_width == 1); TFLITE_DCHECK(pad_width == pad_height); + TFLITE_DCHECK(thread_dim == 0 || thread_dim == 1); const int32 batches = MatchingDim(input_shape, 0, output_shape, 0); const int64_t input_batch_size = params.input_row_size * params.input_height; @@ -3619,14 +3620,35 @@ inline void DepthwiseConv3x3Filter( // used in gemmlowp. uint8 shuffle_workspace[kDepthwiseConvScratchWorkspaceSize]; - for (int32 b = 0; b < batches; ++b) { + int batch_start = 0; + int batch_end = batches; + int row_start = 0; + int row_end = params.output_height; + + switch (thread_dim) { + case 0: + TFLITE_DCHECK_GE(thread_start, 0); + TFLITE_DCHECK_LE(thread_end, batches); + batch_start = thread_start; + batch_end = thread_end; + break; + case 1: + TFLITE_DCHECK_GE(thread_start, 0); + TFLITE_DCHECK_LE(thread_end, params.output_height); + row_start = thread_start; + row_end = thread_end; + break; + } + + for (int32 b = batch_start; b < batch_end; ++b) { const uint8* input_ptr = input_data + b * input_batch_size; - uint8* output_ptr = output_data + b * output_batch_size; + uint8* output_ptr = output_data + b * output_batch_size + + row_start * params.output_width * params.output_depth; int32 out_x = 0; - int32 out_y = 0; + int32 out_y = row_start; int32 end_x = params.output_width; - int32 end_y = params.output_height; + int32 end_y = row_end; if (pad_width == 1 && pad_height == 1) { DepthwiseConvHandlePadding(input_ptr, filter_data, bias_data, output_ptr, @@ -3635,8 +3657,8 @@ inline void DepthwiseConv3x3Filter( // Update extents now that the edges have been handled. out_x = 1; end_x = params.output_width - 1; - out_y = 1; - end_y = params.output_height - 1; + out_y = std::max(1, out_y); + end_y = std::min(params.output_height - 1, end_y); const int in_x = (out_x * stride_width) - pad_width; const int in_y = (out_y * stride_height) - pad_height; input_ptr += in_y * params.input_row_size + in_x * params.input_depth; diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index 5485d907c29..9589436c9fe 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -234,9 +234,14 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. op_params.output_shift = kDepthwiseReverseShift * output_shift; - DepthwiseConv(op_params, DimsToShape(input_dims), input_data, - DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims), - bias_data, DimsToShape(output_dims), output_data); + const RuntimeShape output_shape = DimsToShape(output_dims); + const int output_height = output_shape.Dims(1); + + DepthwiseConvImpl(op_params, DimsToShape(input_dims), input_data, + DimsToShape(filter_dims), filter_data, + DimsToShape(bias_dims), bias_data, DimsToShape(output_dims), + output_data, /*thread_start=*/0, + /*thread_end=*/output_height, /*thread_dim=*/1); } inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 3a6b181daee..5f0a6550eb7 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -327,6 +327,7 @@ class SingleOpModel { } void SetNumThreads(int num_threads) { + CHECK(interpreter_ != nullptr); interpreter_->SetNumThreads(num_threads); }