Avoid im2col creation for multi-threaded conv
PiperOrigin-RevId: 241424360
parent d424ed5fdc
commit 7362d37521
@@ -68,6 +68,17 @@ cc_library(
     ],
 )
 
+cc_test(
+    name = "eigen_support_test",
+    size = "small",
+    srcs = ["eigen_support_test.cc"],
+    deps = [
+        ":eigen_support",
+        "//tensorflow/lite/kernels/internal:optimized",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "gemm_support",
     srcs = [
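Assuming this hunk belongs to tensorflow/lite/kernels/BUILD (the new test source sits in that directory), the added target would typically be run with the standard Bazel invocation; the exact flags are not part of the commit:

bazel test //tensorflow/lite/kernels:eigen_support_test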
@@ -90,7 +90,7 @@ struct OpData {
   bool have_weights_been_transposed;
   bool need_im2col;
 
-  bool run_multithreaded_kernel;
+  bool supports_multithreaded_kernel;
 };
 
 inline PaddingType RuntimePaddingType(TfLitePadding padding) {
@@ -153,14 +153,6 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   int filter_width = filter->dims->data[2];
   int filter_height = filter->dims->data[1];
 
-  // We don't always need to allocate im2col. It is only used in some versions
-  // of the optimized Conv. This test just mimics something that happens inside
-  // optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
-  data->need_im2col =
-      (params->stride_width != 1 || params->stride_height != 1 ||
-       params->dilation_width_factor != 1 ||
-       params->dilation_height_factor != 1 || filter_width != 1 ||
-       filter_height != 1);
   // If we're using the optimized multithreaded EigenTensor implementation of
   // convolution, it expects the filter weights to be transposed compared to
   // the normal TF Lite buffer format. Typical TF Lite weights are
@@ -171,7 +163,17 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   // This path is only used for float processing, so only create the buffer if
   // we're running with that data type.
   data->need_hwcn_weights = (input->type == kTfLiteFloat32 &&
-                             data->run_multithreaded_kernel && !is_hybrid);
+                             data->supports_multithreaded_kernel && !is_hybrid);
+
+  // We don't always need to allocate im2col. It is only used in some versions
+  // of the optimized Conv. This test just mimics something that happens inside
+  // optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
+  data->need_im2col =
+      !data->need_hwcn_weights &&
+      (params->stride_width != 1 || params->stride_height != 1 ||
+       params->dilation_width_factor != 1 ||
+       params->dilation_height_factor != 1 || filter_width != 1 ||
+       filter_height != 1);
 
   int temporaries_count = 0;
   if (data->need_im2col) {
@@ -214,7 +216,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   return kTfLiteOk;
 }
 
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
+                     TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
@@ -260,11 +263,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       (input->type == kTfLiteFloat32 &&
        (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
 
-  data->run_multithreaded_kernel = context->recommended_num_threads != 1;
-  // Hybrid kernels don't support multithreading yet.
-  if (is_hybrid) {
-    data->run_multithreaded_kernel = false;
-  }
+  // The multi-threaded kernel supports neither dilation nor hybrid kernels.
+  data->supports_multithreaded_kernel =
+      (kernel_type == kMultithreadOptimized) &&
+      (context->recommended_num_threads != 1) && !is_hybrid &&
+      (params->dilation_width_factor == 1) &&
+      (params->dilation_height_factor == 1);
 
   TF_LITE_ENSURE_STATUS(
       AllocateTemporaryTensorsIfRequired(context, node, is_hybrid));
@@ -418,6 +422,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+template <KernelType kernel_type>
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  return Prepare(kernel_type, context, node);
+}
+
 template <KernelType kernel_type>
 void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                    TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
@@ -547,18 +556,10 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
   KernelType effective_kernel_type = kernel_type;
-  if (kernel_type == kMultithreadOptimized) {
-    if (context->recommended_num_threads == 1) {
-      // Use of kMultithreadOptimized is precomputed during |Prepare()|, whereas
-      // the actual thread count can change at any time. If the client requests
-      // a single thread (after Prepare()), fall back to optimized.
-      effective_kernel_type = kGenericOptimized;
-    } else if ((params->dilation_width_factor != 1) ||
-               (params->dilation_height_factor != 1)) {
-      // kMultithreadOptimized does not support dilation.
-      // Therefore, fallback to optimized.
-      effective_kernel_type = kGenericOptimized;
-    }
-  }
+  // Fall back to the optimized path if multi-threaded conv is unsupported.
+  if ((kernel_type == kMultithreadOptimized) &&
+      !data->supports_multithreaded_kernel) {
+    effective_kernel_type = kGenericOptimized;
+  }
   ConvParams op_params;
   op_params.padding_type = RuntimePaddingType(params->padding);
@@ -714,7 +715,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
         EvalHybrid<kernel_type>(context, node, params, data, input, filter,
                                 bias, im2col, hwcn_weights, output);
-      } else if (data->run_multithreaded_kernel) {
+      } else if (data->supports_multithreaded_kernel) {
         EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
                                im2col, hwcn_weights, output);
       } else {
@@ -741,25 +742,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace conv
 
 TfLiteRegistration* Register_CONVOLUTION_REF() {
-  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+  static TfLiteRegistration r = {conv::Init, conv::Free,
+                                 conv::Prepare<conv::kReference>,
                                  conv::Eval<conv::kReference>};
   return &r;
 }
 
 TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() {
-  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+  static TfLiteRegistration r = {conv::Init, conv::Free,
+                                 conv::Prepare<conv::kGenericOptimized>,
                                  conv::Eval<conv::kGenericOptimized>};
   return &r;
 }
 
 TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT() {
-  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+  static TfLiteRegistration r = {conv::Init, conv::Free,
+                                 conv::Prepare<conv::kMultithreadOptimized>,
                                  conv::Eval<conv::kMultithreadOptimized>};
   return &r;
 }
 
 TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT() {
-  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+  static TfLiteRegistration r = {conv::Init, conv::Free,
+                                 conv::Prepare<conv::kCblasOptimized>,
                                  conv::Eval<conv::kCblasOptimized>};
   return &r;
 }
@@ -50,20 +50,35 @@ void SetEigenNbThreads(int threads) {
 // We have a single global threadpool for all convolution operations. This means
 // that inferences started from different threads may block each other, but
 // since the underlying resource of CPU cores should be consumed by the
-// operations anyway, it shouldn't affect overall performance.
+// operations anyway, it shouldn't affect overall performance. Note that we
+// also avoid ThreadPool creation if the target thread count is 1, avoiding
+// unnecessary overhead, and more closely mimicking Gemmlowp threadpool
+// behavior.
 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
  public:
-  // Takes ownership of 'pool'
-  explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+  explicit EigenThreadPoolWrapper(int num_threads) {
+    // Avoid creating any threads for the single-threaded case.
+    if (num_threads > 1) {
+      pool_.reset(new Eigen::ThreadPool(num_threads));
+    }
+  }
   ~EigenThreadPoolWrapper() override {}
 
   void Schedule(std::function<void()> fn) override {
-    pool_->Schedule(std::move(fn));
+    if (pool_) {
+      pool_->Schedule(std::move(fn));
+    } else {
+      fn();
+    }
   }
-  int NumThreads() const override { return pool_->NumThreads(); }
-  int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+  int NumThreads() const override { return pool_ ? pool_->NumThreads() : 1; }
+  int CurrentThreadId() const override {
+    return pool_ ? pool_->CurrentThreadId() : 0;
+  }
 
  private:
+  // May be null if num_threads <= 1.
  std::unique_ptr<Eigen::ThreadPool> pool_;
 };
 
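As an aside (not part of the commit): a minimal sketch of how the reworked wrapper is consumed. An Eigen::ThreadPoolDevice is built on top of it, and with a single recommended thread no Eigen::ThreadPool is created, so scheduled work runs inline on the calling thread; this is the same behavior the new SingleThreaded test exercises. The helper name ExampleSingleThreadUse is hypothetical, and the sketch assumes it lives in the same translation unit as EigenThreadPoolWrapper.

// Hypothetical illustration, assuming direct access to EigenThreadPoolWrapper.
void ExampleSingleThreadUse() {
  EigenThreadPoolWrapper wrapper(/*num_threads=*/1);  // No Eigen::ThreadPool is created.
  Eigen::ThreadPoolDevice device(&wrapper, wrapper.NumThreads());

  bool ran = false;
  // enqueue() routes through wrapper.Schedule(); with no pool the closure runs
  // inline, so the notification is already signalled when enqueue() returns.
  auto* done = device.enqueue([&ran]() { ran = true; });
  done->Wait();
  delete done;
  // 'ran' is now true.
}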
@@ -77,8 +92,8 @@ class LazyEigenThreadPoolHolder {
   // Gets the ThreadPoolDevice, creating if necessary.
   const Eigen::ThreadPoolDevice* GetThreadPoolDevice() {
     if (!device_) {
-      thread_pool_wrapper_.reset(new EigenThreadPoolWrapper(
-          new Eigen::ThreadPool(target_num_threads_)));
+      thread_pool_wrapper_.reset(
+          new EigenThreadPoolWrapper(target_num_threads_));
       device_.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper_.get(),
                                                 target_num_threads_));
     }
tensorflow/lite/kernels/eigen_support_test.cc (new file, 145 lines)
@@ -0,0 +1,145 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/eigen_support.h"

#include <string>

#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"

namespace tflite {
namespace eigen_support {

struct TestTfLiteContext : public TfLiteContext {
  TestTfLiteContext() {
    recommended_num_threads = -1;
    external_context = nullptr;
    GetExternalContext = GetExternalContextImpl;
    SetExternalContext = SetExternalContextImpl;
  }

  static void SetExternalContextImpl(TfLiteContext* context,
                                     TfLiteExternalContextType type,
                                     TfLiteExternalContext* external_context) {
    static_cast<TestTfLiteContext*>(context)->external_context =
        external_context;
  }

  static TfLiteExternalContext* GetExternalContextImpl(
      TfLiteContext* context, TfLiteExternalContextType type) {
    return static_cast<TestTfLiteContext*>(context)->external_context;
  }

  TfLiteExternalContext* external_context;
};

TEST(EigenSupport, Default) {
  TestTfLiteContext context;
  IncrementUsageCounter(&context);
  ASSERT_NE(context.external_context, nullptr);
  EXPECT_EQ(context.external_context->type, kTfLiteEigenContext);

  auto thread_pool_device = GetThreadPoolDevice(&context);
  ASSERT_NE(thread_pool_device, nullptr);
  EXPECT_EQ(thread_pool_device->numThreads(), 4);

  DecrementUsageCounter(&context);
}

TEST(EigenSupport, SingleThreaded) {
  TestTfLiteContext context;
  context.recommended_num_threads = 1;
  IncrementUsageCounter(&context);

  auto thread_pool_device = GetThreadPoolDevice(&context);
  ASSERT_NE(thread_pool_device, nullptr);
  EXPECT_EQ(thread_pool_device->numThreads(), 1);
  EXPECT_EQ(thread_pool_device->numThreadsInPool(), 1);

  bool executed = false;
  auto notification =
      thread_pool_device->enqueue([&executed]() { executed = true; });
  ASSERT_NE(notification, nullptr);
  notification->Wait();
  delete notification;
  EXPECT_TRUE(executed);

  DecrementUsageCounter(&context);
}

TEST(EigenSupport, MultiThreaded) {
  TestTfLiteContext context;
  context.recommended_num_threads = 2;
  IncrementUsageCounter(&context);

  auto thread_pool_device = GetThreadPoolDevice(&context);
  ASSERT_NE(thread_pool_device, nullptr);
  EXPECT_EQ(thread_pool_device->numThreads(), 2);

  bool executed = false;
  auto notification =
      thread_pool_device->enqueue([&executed]() { executed = true; });
  ASSERT_NE(notification, nullptr);
  notification->Wait();
  delete notification;
  EXPECT_TRUE(executed);

  DecrementUsageCounter(&context);
}

TEST(EigenSupport, NumThreadsChanged) {
  TestTfLiteContext context;
  context.recommended_num_threads = 1;
  IncrementUsageCounter(&context);

  auto thread_pool_device = GetThreadPoolDevice(&context);
  ASSERT_NE(thread_pool_device, nullptr);
  EXPECT_EQ(thread_pool_device->numThreads(), 1);

  context.recommended_num_threads = 3;
  ASSERT_NE(context.external_context, nullptr);
  context.external_context->Refresh(&context);
  thread_pool_device = GetThreadPoolDevice(&context);
  ASSERT_NE(thread_pool_device, nullptr);
  EXPECT_EQ(thread_pool_device->numThreads(), 3);

  DecrementUsageCounter(&context);
}

TEST(EigenSupport, RefCounting) {
  TestTfLiteContext context;
  EXPECT_EQ(context.external_context, nullptr);

  IncrementUsageCounter(&context);
  EXPECT_NE(context.external_context, nullptr);

  IncrementUsageCounter(&context);
  EXPECT_NE(context.external_context, nullptr);

  DecrementUsageCounter(&context);
  EXPECT_NE(context.external_context, nullptr);

  DecrementUsageCounter(&context);
  EXPECT_EQ(context.external_context, nullptr);
}

}  // namespace eigen_support
}  // namespace tflite

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
@@ -85,12 +85,12 @@ class EigenTensorConvFunctor {
 
  public:
   void operator()(const Eigen::ThreadPoolDevice& device, const T* input_data,
-                  T* im2col_buffer, int input_batches, int input_height,
-                  int input_width, int input_depth, const T* filter_data,
-                  int filter_height, int filter_width, int filter_count,
-                  int stride_rows, int stride_cols, int pad_width,
-                  int pad_height, PaddingType padding, T* output_data,
-                  int output_height, int output_width) {
+                  int input_batches, int input_height, int input_width,
+                  int input_depth, const T* filter_data, int filter_height,
+                  int filter_width, int filter_count, int stride_rows,
+                  int stride_cols, int pad_width, int pad_height,
+                  PaddingType padding, T* output_data, int output_height,
+                  int output_width) {
     const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
                                 stride_rows == 1 && stride_cols == 1);
     if (is_1x1_kernel) {
@@ -139,6 +139,9 @@ inline void Conv(const Eigen::ThreadPoolDevice& device,
                  const float* bias_data, const RuntimeShape& output_shape,
                  float* output_data, const RuntimeShape& im2col_shape,
                  float* im2col_data) {
+  // im2col data should not be generated for the multi-thread supporting case.
+  TFLITE_DCHECK(!im2col_data);
+  (void)im2col_shape;
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
   const PaddingType padding = params.padding_type;
@@ -160,11 +163,10 @@ inline void Conv(const Eigen::ThreadPoolDevice& device,
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
   EigenTensorConvFunctor<float> conv_functor;
-  conv_functor(device, input_data, im2col_data, batches, input_height,
-               input_width, input_depth, filter_data, filter_height,
-               filter_width, output_depth, stride_height, stride_width,
-               pad_height, pad_width, padding, output_data, output_height,
-               output_width);
+  conv_functor(device, input_data, batches, input_height, input_width,
+               input_depth, filter_data, filter_height, filter_width,
+               output_depth, stride_height, stride_width, pad_height, pad_width,
+               padding, output_data, output_height, output_width);
 
   optimized_ops::AddBiasAndEvalActivationFunction(
       output_activation_min, output_activation_max, bias_shape, bias_data,