Avoid im2col creation for multi-threaded conv

PiperOrigin-RevId: 241424360

parent d424ed5fdc
commit 7362d37521
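For context: "im2col" rewrites every convolution window of the input into one row of a scratch matrix so the convolution collapses into a single matrix multiply. The multithreaded EigenTensor path computes the convolution directly through Eigen's spatial convolution and never reads that scratch buffer, so allocating it for that path only wastes memory and bandwidth. A minimal illustrative sketch of the transform being avoided (not the TF Lite implementation):

#include <vector>

// Illustrative im2col for one NHWC batch with zero padding: each output row
// holds one (filter_h x filter_w x depth) patch, so the convolution becomes
// output = patches * reshaped_filter, i.e. a plain GEMM.
std::vector<float> Im2Col(const float* input, int height, int width, int depth,
                          int filter_h, int filter_w, int stride, int pad,
                          int out_h, int out_w) {
  std::vector<float> patches(
      static_cast<size_t>(out_h) * out_w * filter_h * filter_w * depth, 0.f);
  for (int oy = 0; oy < out_h; ++oy) {
    for (int ox = 0; ox < out_w; ++ox) {
      float* row = &patches[((oy * out_w) + ox) * filter_h * filter_w * depth];
      for (int fy = 0; fy < filter_h; ++fy) {
        for (int fx = 0; fx < filter_w; ++fx) {
          const int iy = oy * stride + fy - pad;
          const int ix = ox * stride + fx - pad;
          if (iy < 0 || iy >= height || ix < 0 || ix >= width) continue;
          for (int d = 0; d < depth; ++d) {
            row[((fy * filter_w) + fx) * depth + d] =
                input[((iy * width) + ix) * depth + d];
          }
        }
      }
    }
  }
  return patches;
}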
tensorflow/lite/kernels/BUILD

@@ -68,6 +68,17 @@ cc_library(
     ],
 )
 
+cc_test(
+    name = "eigen_support_test",
+    size = "small",
+    srcs = ["eigen_support_test.cc"],
+    deps = [
+        ":eigen_support",
+        "//tensorflow/lite/kernels/internal:optimized",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "gemm_support",
     srcs = [
tensorflow/lite/kernels/conv.cc

@@ -90,7 +90,7 @@ struct OpData {
   bool have_weights_been_transposed;
   bool need_im2col;
 
-  bool run_multithreaded_kernel;
+  bool supports_multithreaded_kernel;
 };
 
 inline PaddingType RuntimePaddingType(TfLitePadding padding) {
@@ -153,14 +153,6 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   int filter_width = filter->dims->data[2];
   int filter_height = filter->dims->data[1];
 
-  // We don't always need to allocate im2col. It is only used in some versions
-  // of the optimized Conv. This test just mimics something that happens inside
-  // optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
-  data->need_im2col =
-      (params->stride_width != 1 || params->stride_height != 1 ||
-       params->dilation_width_factor != 1 ||
-       params->dilation_height_factor != 1 || filter_width != 1 ||
-       filter_height != 1);
   // If we're using the optimized multithreaded EigenTensor implementation of
   // convolution, it expects the filter weights to be transposed compared to
   // the normal TF Lite buffer format. Typical TF Lite weights are
@@ -171,7 +163,17 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   // This path is only used for float processing, so only create the buffer if
   // we're running with that data type.
   data->need_hwcn_weights = (input->type == kTfLiteFloat32 &&
-                             data->run_multithreaded_kernel && !is_hybrid);
+                             data->supports_multithreaded_kernel && !is_hybrid);
 
+  // We don't always need to allocate im2col. It is only used in some versions
+  // of the optimized Conv. This test just mimics something that happens inside
+  // optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
+  data->need_im2col =
+      !data->need_hwcn_weights &&
+      (params->stride_width != 1 || params->stride_height != 1 ||
+       params->dilation_width_factor != 1 ||
+       params->dilation_height_factor != 1 || filter_width != 1 ||
+       filter_height != 1);
+
   int temporaries_count = 0;
   if (data->need_im2col) {
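The substantive change in the two hunks above: need_im2col is now computed after need_hwcn_weights and is additionally gated on !data->need_hwcn_weights. Since need_hwcn_weights is only set when the multithreaded EigenTensor kernel will run, and that kernel never consumes an im2col buffer, the two temporaries become mutually exclusive. A condensed restatement of the predicate (a sketch with a hypothetical helper name, not the literal TF Lite code):

// The multithreaded Eigen path (signalled by need_hwcn_weights) convolves
// directly, so im2col is only needed by the im2col-based kernels when the
// convolution is not a trivial 1x1/stride-1/no-dilation case.
bool NeedIm2Col(bool need_hwcn_weights, int stride_w, int stride_h,
                int dilation_w, int dilation_h, int filter_w, int filter_h) {
  if (need_hwcn_weights) return false;  // multithreaded Eigen path: no im2col
  return stride_w != 1 || stride_h != 1 || dilation_w != 1 ||
         dilation_h != 1 || filter_w != 1 || filter_h != 1;
}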
@@ -214,7 +216,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   return kTfLiteOk;
 }
 
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
+                     TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
@@ -260,11 +263,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       (input->type == kTfLiteFloat32 &&
        (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
 
-  data->run_multithreaded_kernel = context->recommended_num_threads != 1;
-  // Hybrid kernels don't support multithreading yet.
-  if (is_hybrid) {
-    data->run_multithreaded_kernel = false;
-  }
+  // The multi-threaded kernel supports neither dilation nor hybrid kernels.
+  data->supports_multithreaded_kernel =
+      (kernel_type == kMultithreadOptimized) &&
+      (context->recommended_num_threads != 1) && !is_hybrid &&
+      (params->dilation_width_factor == 1) &&
+      (params->dilation_height_factor == 1);
 
   TF_LITE_ENSURE_STATUS(
       AllocateTemporaryTensorsIfRequired(context, node, is_hybrid));
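Note that the decision is now taken once, inside Prepare, from the kernel type, the recommended thread count, the hybrid flag, and the dilation factors. The thread count must therefore be configured before tensor allocation, since Prepare runs during AllocateTensors(). A hedged usage sketch with the C++ interpreter API (error handling omitted; "model.tflite" is a placeholder path):

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Requesting one thread before AllocateTensors() means Prepare sees
// recommended_num_threads == 1 and disables the multithreaded conv kernel
// (and with it the HWCN weight transpose) up front.
void RunSingleThreaded() {
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  interpreter->SetNumThreads(1);   // must precede AllocateTensors()
  interpreter->AllocateTensors();  // runs each kernel's Prepare
  interpreter->Invoke();
}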
@@ -418,6 +422,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+template <KernelType kernel_type>
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  return Prepare(kernel_type, context, node);
+}
+
 template <KernelType kernel_type>
 void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                    TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
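The templated Prepare added here is a thin shim that bakes the runtime kernel_type argument into a compile-time parameter, so each registration below can hand TfLiteRegistration a distinct plain function pointer while the implementation stays shared. The generic shape of the trick (an illustrative sketch, not TF Lite code):

// A shared implementation takes the variant as an argument; a templated shim
// fixes the variant, yielding one ordinary function pointer per variant that
// fits a C-style registration struct.
enum class Variant { kReference, kOptimized };

int SharedImpl(Variant v, int x) { return v == Variant::kReference ? x : 2 * x; }

template <Variant v>
int Shim(int x) {
  return SharedImpl(v, x);
}

// &Shim<Variant::kReference> and &Shim<Variant::kOptimized> are distinct
// int(*)(int) pointers, analogous to conv::Prepare<conv::kReference> below.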
@@ -547,18 +556,10 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
   KernelType effective_kernel_type = kernel_type;
-  if (kernel_type == kMultithreadOptimized) {
-    if (context->recommended_num_threads == 1) {
-      // Use of kMultithreadOptimized is precomputed during |Prepare()|, whereas
-      // the actual thread count can change at any time. If the client requests
-      // a single thread (after Prepare()), fall back to optimized.
-      effective_kernel_type = kGenericOptimized;
-    } else if ((params->dilation_width_factor != 1) ||
-               (params->dilation_height_factor != 1)) {
-      // kMultithreadOptimized does not support dilation.
-      // Therefore, fallback to optimized.
-      effective_kernel_type = kGenericOptimized;
-    }
-  }
+  // Fall back to the optimized path if multi-threaded conv is unsupported.
+  if ((kernel_type == kMultithreadOptimized) &&
+      !data->supports_multithreaded_kernel) {
+    effective_kernel_type = kGenericOptimized;
+  }
   ConvParams op_params;
   op_params.padding_type = RuntimePaddingType(params->padding);
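With the Prepare-time flag as the single source of truth, the dilation and thread-count re-checks that used to live here are gone. The new fallback as a standalone sketch (hypothetical helper name; names follow the diff):

// kMultithreadOptimized silently degrades to the generic optimized kernel
// whenever Prepare concluded the multithreaded path is unusable.
KernelType EffectiveKernelType(KernelType requested, const OpData& data) {
  if (requested == kMultithreadOptimized &&
      !data.supports_multithreaded_kernel) {
    return kGenericOptimized;
  }
  return requested;
}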
@@ -714,7 +715,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
     EvalHybrid<kernel_type>(context, node, params, data, input, filter,
                             bias, im2col, hwcn_weights, output);
-  } else if (data->run_multithreaded_kernel) {
+  } else if (data->supports_multithreaded_kernel) {
     EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
                            im2col, hwcn_weights, output);
   } else {
@@ -741,25 +742,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace conv
 
 TfLiteRegistration* Register_CONVOLUTION_REF() {
-  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+  static TfLiteRegistration r = {conv::Init, conv::Free,
+                                 conv::Prepare<conv::kReference>,
                                  conv::Eval<conv::kReference>};
   return &r;
 }
 
 TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() {
-  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+  static TfLiteRegistration r = {conv::Init, conv::Free,
+                                 conv::Prepare<conv::kGenericOptimized>,
                                  conv::Eval<conv::kGenericOptimized>};
   return &r;
 }
 
 TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT() {
-  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+  static TfLiteRegistration r = {conv::Init, conv::Free,
+                                 conv::Prepare<conv::kMultithreadOptimized>,
                                  conv::Eval<conv::kMultithreadOptimized>};
   return &r;
 }
 
 TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT() {
-  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+  static TfLiteRegistration r = {conv::Init, conv::Free,
+                                 conv::Prepare<conv::kCblasOptimized>,
                                  conv::Eval<conv::kCblasOptimized>};
   return &r;
 }
tensorflow/lite/kernels/eigen_support.cc

@@ -50,20 +50,35 @@ void SetEigenNbThreads(int threads) {
 // We have a single global threadpool for all convolution operations. This means
 // that inferences started from different threads may block each other, but
 // since the underlying resource of CPU cores should be consumed by the
-// operations anyway, it shouldn't affect overall performance.
+// operations anyway, it shouldn't affect overall performance. Note that we
+// also avoid ThreadPool creation if the target thread count is 1, avoiding
+// unnecessary overhead, and more closely mimicking Gemmlowp threadpool
+// behavior.
 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
  public:
   // Takes ownership of 'pool'
-  explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+  explicit EigenThreadPoolWrapper(int num_threads) {
+    // Avoid creating any threads for the single-threaded case.
+    if (num_threads > 1) {
+      pool_.reset(new Eigen::ThreadPool(num_threads));
+    }
+  }
   ~EigenThreadPoolWrapper() override {}
 
   void Schedule(std::function<void()> fn) override {
-    pool_->Schedule(std::move(fn));
+    if (pool_) {
+      pool_->Schedule(std::move(fn));
+    } else {
+      fn();
+    }
+  }
+  int NumThreads() const override { return pool_ ? pool_->NumThreads() : 1; }
+  int CurrentThreadId() const override {
+    return pool_ ? pool_->CurrentThreadId() : 0;
   }
-  int NumThreads() const override { return pool_->NumThreads(); }
-  int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
 
  private:
+  // May be null if num_threads <= 1.
   std::unique_ptr<Eigen::ThreadPool> pool_;
 };
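The wrapper now degrades to inline execution when no pool exists: Schedule() runs the closure on the calling thread, and NumThreads()/CurrentThreadId() report a one-thread pool. A hedged sketch of how such a wrapper pairs with Eigen::ThreadPoolDevice (mirrors the class above and the lazy holder below; illustrative only):

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

// With num_threads == 1 no Eigen::ThreadPool is constructed, yet the device
// remains fully usable: scheduled work simply executes synchronously.
void SingleThreadedDeviceDemo() {
  EigenThreadPoolWrapper wrapper(/*num_threads=*/1);  // pool_ stays null
  Eigen::ThreadPoolDevice device(&wrapper, wrapper.NumThreads());
  wrapper.Schedule([] { /* runs immediately on the calling thread */ });
}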
@@ -77,8 +92,8 @@ class LazyEigenThreadPoolHolder {
   // Gets the ThreadPoolDevice, creating if necessary.
   const Eigen::ThreadPoolDevice* GetThreadPoolDevice() {
     if (!device_) {
-      thread_pool_wrapper_.reset(new EigenThreadPoolWrapper(
-          new Eigen::ThreadPool(target_num_threads_)));
+      thread_pool_wrapper_.reset(
+          new EigenThreadPoolWrapper(target_num_threads_));
       device_.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper_.get(),
                                                 target_num_threads_));
     }
tensorflow/lite/kernels/eigen_support_test.cc (new file, 145 lines)

@@ -0,0 +1,145 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/eigen_support.h"
+
+#include <string>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
+
+namespace tflite {
+namespace eigen_support {
+
+struct TestTfLiteContext : public TfLiteContext {
+  TestTfLiteContext() {
+    recommended_num_threads = -1;
+    external_context = nullptr;
+    GetExternalContext = GetExternalContextImpl;
+    SetExternalContext = SetExternalContextImpl;
+  }
+
+  static void SetExternalContextImpl(TfLiteContext* context,
+                                     TfLiteExternalContextType type,
+                                     TfLiteExternalContext* external_context) {
+    static_cast<TestTfLiteContext*>(context)->external_context =
+        external_context;
+  }
+
+  static TfLiteExternalContext* GetExternalContextImpl(
+      TfLiteContext* context, TfLiteExternalContextType type) {
+    return static_cast<TestTfLiteContext*>(context)->external_context;
+  }
+
+  TfLiteExternalContext* external_context;
+};
+
+TEST(EigenSupport, Default) {
+  TestTfLiteContext context;
+  IncrementUsageCounter(&context);
+  ASSERT_NE(context.external_context, nullptr);
+  EXPECT_EQ(context.external_context->type, kTfLiteEigenContext);
+
+  auto thread_pool_device = GetThreadPoolDevice(&context);
+  ASSERT_NE(thread_pool_device, nullptr);
+  EXPECT_EQ(thread_pool_device->numThreads(), 4);
+
+  DecrementUsageCounter(&context);
+}
+
+TEST(EigenSupport, SingleThreaded) {
+  TestTfLiteContext context;
+  context.recommended_num_threads = 1;
+  IncrementUsageCounter(&context);
+
+  auto thread_pool_device = GetThreadPoolDevice(&context);
+  ASSERT_NE(thread_pool_device, nullptr);
+  EXPECT_EQ(thread_pool_device->numThreads(), 1);
+  EXPECT_EQ(thread_pool_device->numThreadsInPool(), 1);
+
+  bool executed = false;
+  auto notification =
+      thread_pool_device->enqueue([&executed]() { executed = true; });
+  ASSERT_NE(notification, nullptr);
+  notification->Wait();
+  delete notification;
+  EXPECT_TRUE(executed);
+
+  DecrementUsageCounter(&context);
+}
+
+TEST(EigenSupport, MultiThreaded) {
+  TestTfLiteContext context;
+  context.recommended_num_threads = 2;
+  IncrementUsageCounter(&context);
+
+  auto thread_pool_device = GetThreadPoolDevice(&context);
+  ASSERT_NE(thread_pool_device, nullptr);
+  EXPECT_EQ(thread_pool_device->numThreads(), 2);
+
+  bool executed = false;
+  auto notification =
+      thread_pool_device->enqueue([&executed]() { executed = true; });
+  ASSERT_NE(notification, nullptr);
+  notification->Wait();
+  delete notification;
+  EXPECT_TRUE(executed);
+
+  DecrementUsageCounter(&context);
+}
+
+TEST(EigenSupport, NumThreadsChanged) {
+  TestTfLiteContext context;
+  context.recommended_num_threads = 1;
+  IncrementUsageCounter(&context);
+
+  auto thread_pool_device = GetThreadPoolDevice(&context);
+  ASSERT_NE(thread_pool_device, nullptr);
+  EXPECT_EQ(thread_pool_device->numThreads(), 1);
+
+  context.recommended_num_threads = 3;
+  ASSERT_NE(context.external_context, nullptr);
+  context.external_context->Refresh(&context);
+  thread_pool_device = GetThreadPoolDevice(&context);
+  ASSERT_NE(thread_pool_device, nullptr);
+  EXPECT_EQ(thread_pool_device->numThreads(), 3);
+
+  DecrementUsageCounter(&context);
+}
+
+TEST(EigenSupport, RefCounting) {
+  TestTfLiteContext context;
+  EXPECT_EQ(context.external_context, nullptr);
+
+  IncrementUsageCounter(&context);
+  EXPECT_NE(context.external_context, nullptr);
+
+  IncrementUsageCounter(&context);
+  EXPECT_NE(context.external_context, nullptr);
+
+  DecrementUsageCounter(&context);
+  EXPECT_NE(context.external_context, nullptr);
+
+  DecrementUsageCounter(&context);
+  EXPECT_EQ(context.external_context, nullptr);
+}
+
+}  // namespace eigen_support
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h

@@ -85,12 +85,12 @@ class EigenTensorConvFunctor {
 
  public:
   void operator()(const Eigen::ThreadPoolDevice& device, const T* input_data,
-                  T* im2col_buffer, int input_batches, int input_height,
-                  int input_width, int input_depth, const T* filter_data,
-                  int filter_height, int filter_width, int filter_count,
-                  int stride_rows, int stride_cols, int pad_width,
-                  int pad_height, PaddingType padding, T* output_data,
-                  int output_height, int output_width) {
+                  int input_batches, int input_height, int input_width,
+                  int input_depth, const T* filter_data, int filter_height,
+                  int filter_width, int filter_count, int stride_rows,
+                  int stride_cols, int pad_width, int pad_height,
+                  PaddingType padding, T* output_data, int output_height,
+                  int output_width) {
     const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
                                 stride_rows == 1 && stride_cols == 1);
     if (is_1x1_kernel) {
@@ -139,6 +139,9 @@ inline void Conv(const Eigen::ThreadPoolDevice& device,
                  const float* bias_data, const RuntimeShape& output_shape,
                  float* output_data, const RuntimeShape& im2col_shape,
                  float* im2col_data) {
+  // im2col data should not be generated for the multi-thread supporting case.
+  TFLITE_DCHECK(!im2col_data);
+  (void)im2col_shape;
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
   const PaddingType padding = params.padding_type;
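The new TFLITE_DCHECK codifies the tightened contract: callers must not allocate or pass an im2col buffer to the multithreaded Conv, and (void)im2col_shape silences the now-unused parameter. DCHECKs are typically compiled out in release builds, so a stray buffer would be silently ignored there; the guard only fires in debug builds. A caller-side restatement (hypothetical helper, assert-based for portability):

#include <cassert>

// Mirrors the kernel-side DCHECK: by the time the multithreaded path is
// chosen, need_im2col was computed as false, so no im2col tensor was ever
// allocated and only nullptr should arrive here.
inline void CheckNoIm2Col(const float* im2col_data) {
  assert(im2col_data == nullptr &&
         "multithreaded conv must not receive an im2col buffer");
}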
@@ -160,11 +163,10 @@ inline void Conv(const Eigen::ThreadPoolDevice& device,
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
   EigenTensorConvFunctor<float> conv_functor;
-  conv_functor(device, input_data, im2col_data, batches, input_height,
-               input_width, input_depth, filter_data, filter_height,
-               filter_width, output_depth, stride_height, stride_width,
-               pad_height, pad_width, padding, output_data, output_height,
-               output_width);
+  conv_functor(device, input_data, batches, input_height, input_width,
+               input_depth, filter_data, filter_height, filter_width,
+               output_depth, stride_height, stride_width, pad_height, pad_width,
+               padding, output_data, output_height, output_width);
 
   optimized_ops::AddBiasAndEvalActivationFunction(
       output_activation_min, output_activation_max, bias_shape, bias_data,