Don't access CpuBackendContext concurrently from multiple threads.
This fixes a race condition that caused crashes with 'illegal instruction' as the dot-product detection went wrong. When implementing dotprod detection in depthwiseconv code based on CpuBackendContext, I forgot to mention that CpuBackendContext should not be used concurrently from multiple threads (a limitation it inherits from the underlying gemmlowp / ruy contexts). To avoid that, the dotprod detection is moved to the top-level op kernel function called on the main thread, before the thread dispatch. A new data structure was needed to hold the results of the dotprod detection in a way that could be shared with threads: that's CpuFlags. It is placed in the existing cpu_check.h. It can't share code with the CPU checks already in that header, because dotprod detection is not currently supported by the OS features that those existing checks use. PiperOrigin-RevId: 250739702
This commit is contained in:
parent
583972dc4f
commit
0e43175921
@ -212,6 +212,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
|
":cpu_check",
|
||||||
":quantization_util",
|
":quantization_util",
|
||||||
":strided_slice_logic",
|
":strided_slice_logic",
|
||||||
":types",
|
":types",
|
||||||
@ -254,6 +255,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
|
":cpu_check",
|
||||||
":optimized_base",
|
":optimized_base",
|
||||||
":quantization_util",
|
":quantization_util",
|
||||||
":strided_slice_logic",
|
":strided_slice_logic",
|
||||||
@ -505,10 +507,10 @@ cc_library(
|
|||||||
":types",
|
":types",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//tensorflow/lite/kernels:activation_functor",
|
"//tensorflow/lite/kernels:activation_functor",
|
||||||
|
"//tensorflow/lite/kernels:cpu_backend_context",
|
||||||
"//tensorflow/lite/kernels:op_macros",
|
"//tensorflow/lite/kernels:op_macros",
|
||||||
"@arm_neon_2_x86_sse",
|
"@arm_neon_2_x86_sse",
|
||||||
"@gemmlowp//:fixedpoint",
|
"@gemmlowp//:fixedpoint",
|
||||||
"@gemmlowp//:profiler",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -560,12 +562,12 @@ cc_library(
|
|||||||
],
|
],
|
||||||
copts = NEON_FLAGS_IF_APPLICABLE,
|
copts = NEON_FLAGS_IF_APPLICABLE,
|
||||||
deps = [
|
deps = [
|
||||||
"@com_google_absl//absl/base:core_headers",
|
":cpu_check",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"@arm_neon_2_x86_sse",
|
"@arm_neon_2_x86_sse",
|
||||||
|
"//tensorflow/lite/kernels:cpu_backend_context",
|
||||||
"//tensorflow/lite/kernels:op_macros",
|
"//tensorflow/lite/kernels:op_macros",
|
||||||
"@gemmlowp//:fixedpoint",
|
"@gemmlowp//:fixedpoint",
|
||||||
"@gemmlowp//:profiler",
|
|
||||||
] + select({
|
] + select({
|
||||||
":aarch64": [
|
":aarch64": [
|
||||||
":neon_tensor_utils",
|
":neon_tensor_utils",
|
||||||
@ -650,6 +652,7 @@ cc_test(
|
|||||||
name = "depthwiseconv_float_test",
|
name = "depthwiseconv_float_test",
|
||||||
srcs = ["depthwiseconv_float_test.cc"],
|
srcs = ["depthwiseconv_float_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":cpu_check",
|
||||||
":optimized_base",
|
":optimized_base",
|
||||||
":reference_base",
|
":reference_base",
|
||||||
":test_util",
|
":test_util",
|
||||||
@ -811,6 +814,7 @@ cc_library(
|
|||||||
"optimized/cpu_check.h",
|
"optimized/cpu_check.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/lite/kernels:cpu_backend_context",
|
||||||
] + select(
|
] + select(
|
||||||
{
|
{
|
||||||
"//tensorflow:android": [
|
"//tensorflow:android": [
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
|
||||||
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
|
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
|
||||||
|
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
|
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
|
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
|
||||||
|
|
||||||
@ -42,7 +43,8 @@ void TestOneDepthwiseConv(
|
|||||||
reference_output_data.data());
|
reference_output_data.data());
|
||||||
optimized_ops::DepthwiseConvImpl(
|
optimized_ops::DepthwiseConvImpl(
|
||||||
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
|
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
|
||||||
bias_data, output_shape, output_data.data(), nullptr, /*thread_start=*/0,
|
bias_data, output_shape, output_data.data(), CpuFlags(),
|
||||||
|
/*thread_start=*/0,
|
||||||
/*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
|
/*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
|
||||||
|
|
||||||
double sum_abs_diff = 0;
|
double sum_abs_diff = 0;
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
|
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
|
||||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
|
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#ifdef __ANDROID__
|
||||||
@ -44,6 +46,18 @@ inline bool TestCPUFeatureNeon() { return false; }
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
struct CpuFlags {
|
||||||
|
bool neon_dotprod = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline void GetCpuFlags(CpuBackendContext* cpu_backend_context,
|
||||||
|
CpuFlags* cpu_flags) {
|
||||||
|
ruy::Context* ruy_context = cpu_backend_context->ruy_context();
|
||||||
|
cpu_flags->neon_dotprod =
|
||||||
|
ruy_context != nullptr && (ruy_context->GetRuntimeEnabledPaths() &
|
||||||
|
ruy::Path::kNeonDotprod) != ruy::Path::kNone;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
// NEON_OR_PORTABLE(SomeFunc, args) calls NeonSomeFunc(args) if Neon is both
|
// NEON_OR_PORTABLE(SomeFunc, args) calls NeonSomeFunc(args) if Neon is both
|
||||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
|
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
|
||||||
|
|
||||||
#include "profiling/instrumentation.h"
|
#include "profiling/instrumentation.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
|
||||||
#include "tensorflow/lite/kernels/internal/common.h"
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -897,7 +897,7 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
|
|||||||
// For example, assume thread_start = 2, thread_end = 6, and thread_dim = 1, it
|
// For example, assume thread_start = 2, thread_end = 6, and thread_dim = 1, it
|
||||||
// means that it will calculate DepthwiseConv for output_data[:, 2:5, :, :].
|
// means that it will calculate DepthwiseConv for output_data[:, 2:5, :, :].
|
||||||
//
|
//
|
||||||
// The cpu_backend_context may be supplied as a nullptr by some callers. This
|
// The cpu_flags is currently unused. This
|
||||||
// parameter is included so that the signature matches that required by a
|
// parameter is included so that the signature matches that required by a
|
||||||
// templated function. Other versions, such as quantized, need this parameter.
|
// templated function. Other versions, such as quantized, need this parameter.
|
||||||
inline void DepthwiseConvImpl(
|
inline void DepthwiseConvImpl(
|
||||||
@ -905,8 +905,8 @@ inline void DepthwiseConvImpl(
|
|||||||
const float* input_data, const RuntimeShape& filter_shape,
|
const float* input_data, const RuntimeShape& filter_shape,
|
||||||
const float* filter_data, const RuntimeShape& bias_shape,
|
const float* filter_data, const RuntimeShape& bias_shape,
|
||||||
const float* bias_data, const RuntimeShape& output_shape,
|
const float* bias_data, const RuntimeShape& output_shape,
|
||||||
float* output_data, CpuBackendContext* cpu_backend_context,
|
float* output_data, const CpuFlags& /* cpu_flags */, int thread_start,
|
||||||
int thread_start, int thread_end, int thread_dim) {
|
int thread_end, int thread_dim) {
|
||||||
gemmlowp::ScopedProfilingLabel label("DepthwiseConv/float/DepthwiseConvImpl");
|
gemmlowp::ScopedProfilingLabel label("DepthwiseConv/float/DepthwiseConvImpl");
|
||||||
|
|
||||||
const int stride_width = params.stride_width;
|
const int stride_width = params.stride_width;
|
||||||
@ -1116,10 +1116,9 @@ inline void DepthwiseConv(
|
|||||||
const float* input_data, const RuntimeShape& filter_shape,
|
const float* input_data, const RuntimeShape& filter_shape,
|
||||||
const float* filter_data, const RuntimeShape& bias_shape,
|
const float* filter_data, const RuntimeShape& bias_shape,
|
||||||
const float* bias_data, const RuntimeShape& output_shape,
|
const float* bias_data, const RuntimeShape& output_shape,
|
||||||
float* output_data, CpuBackendContext* cpu_backend_context) {
|
float* output_data, const CpuFlags& cpu_flags) {
|
||||||
DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
|
DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
|
||||||
bias_shape, bias_data, output_shape, output_data,
|
bias_shape, bias_data, output_shape, output_data, cpu_flags,
|
||||||
cpu_backend_context,
|
|
||||||
/*thread_start=*/0,
|
/*thread_start=*/0,
|
||||||
/*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
|
/*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
|
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
|
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
|
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
|
||||||
|
|
||||||
@ -36,8 +37,7 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
|
|||||||
const RuntimeShape& filter_shape,
|
const RuntimeShape& filter_shape,
|
||||||
const T* filter_data, const RuntimeShape& bias_shape,
|
const T* filter_data, const RuntimeShape& bias_shape,
|
||||||
const TS* bias_data, const RuntimeShape& output_shape,
|
const TS* bias_data, const RuntimeShape& output_shape,
|
||||||
T* output_data,
|
T* output_data, const CpuFlags& cpu_flags,
|
||||||
CpuBackendContext* cpu_backend_context,
|
|
||||||
int thread_start, int thread_end, int thread_dim)
|
int thread_start, int thread_end, int thread_dim)
|
||||||
: params_(params),
|
: params_(params),
|
||||||
input_shape_(input_shape),
|
input_shape_(input_shape),
|
||||||
@ -48,7 +48,7 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
|
|||||||
bias_data_(bias_data),
|
bias_data_(bias_data),
|
||||||
output_shape_(output_shape),
|
output_shape_(output_shape),
|
||||||
output_data_(output_data),
|
output_data_(output_data),
|
||||||
cpu_backend_context_(cpu_backend_context),
|
cpu_flags_(cpu_flags),
|
||||||
thread_start_(thread_start),
|
thread_start_(thread_start),
|
||||||
thread_end_(thread_end),
|
thread_end_(thread_end),
|
||||||
thread_dim_(thread_dim) {}
|
thread_dim_(thread_dim) {}
|
||||||
@ -56,8 +56,8 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
|
|||||||
void Run() override {
|
void Run() override {
|
||||||
DepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_,
|
DepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_,
|
||||||
filter_data_, bias_shape_, bias_data_, output_shape_,
|
filter_data_, bias_shape_, bias_data_, output_shape_,
|
||||||
output_data_, cpu_backend_context_, thread_start_,
|
output_data_, cpu_flags_, thread_start_, thread_end_,
|
||||||
thread_end_, thread_dim_);
|
thread_dim_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -70,7 +70,7 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
|
|||||||
const TS* bias_data_;
|
const TS* bias_data_;
|
||||||
const RuntimeShape& output_shape_;
|
const RuntimeShape& output_shape_;
|
||||||
T* output_data_;
|
T* output_data_;
|
||||||
CpuBackendContext* cpu_backend_context_;
|
const CpuFlags& cpu_flags_;
|
||||||
int thread_start_;
|
int thread_start_;
|
||||||
int thread_end_;
|
int thread_end_;
|
||||||
int thread_dim_;
|
int thread_dim_;
|
||||||
@ -143,10 +143,13 @@ inline void DepthwiseConv(const DepthwiseParams& params,
|
|||||||
const int output_batches = output_shape.Dims(0);
|
const int output_batches = output_shape.Dims(0);
|
||||||
const int output_height = output_shape.Dims(1);
|
const int output_height = output_shape.Dims(1);
|
||||||
|
|
||||||
|
CpuFlags cpu_flags;
|
||||||
|
GetCpuFlags(cpu_backend_context, &cpu_flags);
|
||||||
|
|
||||||
if (thread_count == 1) {
|
if (thread_count == 1) {
|
||||||
DepthwiseConvImpl(params, input_shape, input_data, filter_shape,
|
DepthwiseConvImpl(params, input_shape, input_data, filter_shape,
|
||||||
filter_data, bias_shape, bias_data, output_shape,
|
filter_data, bias_shape, bias_data, output_shape,
|
||||||
output_data, cpu_backend_context, /*thread_start=*/0,
|
output_data, cpu_flags, /*thread_start=*/0,
|
||||||
/*thread_end=*/output_height, /*thread_dim=*/1);
|
/*thread_end=*/output_height, /*thread_dim=*/1);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -170,8 +173,8 @@ inline void DepthwiseConv(const DepthwiseParams& params,
|
|||||||
thread_start + (thread_dim_size - thread_start) / (thread_count - i);
|
thread_start + (thread_dim_size - thread_start) / (thread_count - i);
|
||||||
tasks.emplace_back(params, input_shape, input_data, filter_shape,
|
tasks.emplace_back(params, input_shape, input_data, filter_shape,
|
||||||
filter_data, bias_shape, bias_data, output_shape,
|
filter_data, bias_shape, bias_data, output_shape,
|
||||||
output_data, cpu_backend_context, thread_start,
|
output_data, cpu_flags, thread_start, thread_end,
|
||||||
thread_end, thread_dim);
|
thread_dim);
|
||||||
thread_start = thread_end;
|
thread_start = thread_end;
|
||||||
}
|
}
|
||||||
cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
|
cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
|
||||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
|||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include "profiling/instrumentation.h"
|
#include "profiling/instrumentation.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
|
||||||
#include "tensorflow/lite/kernels/internal/common.h"
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h"
|
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
|
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
@ -1986,8 +1986,8 @@ inline void DepthwiseConvWithRounding(
|
|||||||
const uint8* input_data, const RuntimeShape& filter_shape,
|
const uint8* input_data, const RuntimeShape& filter_shape,
|
||||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||||
const int32* bias_data, const RuntimeShape& output_shape,
|
const int32* bias_data, const RuntimeShape& output_shape,
|
||||||
uint8* output_data, CpuBackendContext* cpu_backend_context,
|
uint8* output_data, const CpuFlags& cpu_flags, int thread_start,
|
||||||
int thread_start, int thread_end, int thread_dim) {
|
int thread_end, int thread_dim) {
|
||||||
gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
|
gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
|
||||||
const int depth_multiplier = params.depth_multiplier;
|
const int depth_multiplier = params.depth_multiplier;
|
||||||
const int32 output_activation_min = params.quantized_activation_min;
|
const int32 output_activation_min = params.quantized_activation_min;
|
||||||
@ -2009,12 +2009,7 @@ inline void DepthwiseConvWithRounding(
|
|||||||
// Jetson TX-2. This compiler does not support the offsetof() macro.
|
// Jetson TX-2. This compiler does not support the offsetof() macro.
|
||||||
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
|
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
|
||||||
// Dispatch to dot-product 3x3 kernels when supported.
|
// Dispatch to dot-product 3x3 kernels when supported.
|
||||||
|
if (cpu_flags.neon_dotprod) {
|
||||||
ruy::Context* ruy_context = cpu_backend_context->ruy_context();
|
|
||||||
const bool has_dot_product_instructions =
|
|
||||||
ruy_context != nullptr && (ruy_context->GetRuntimeEnabledPaths() &
|
|
||||||
ruy::Path::kNeonDotprod) != ruy::Path::kNone;
|
|
||||||
if (has_dot_product_instructions) {
|
|
||||||
using optimized_ops::depthwise_conv::DotProduct3x3KernelType;
|
using optimized_ops::depthwise_conv::DotProduct3x3KernelType;
|
||||||
DotProduct3x3KernelType kernel_type =
|
DotProduct3x3KernelType kernel_type =
|
||||||
optimized_ops::depthwise_conv::CategorizeDotProductKernel(
|
optimized_ops::depthwise_conv::CategorizeDotProductKernel(
|
||||||
@ -2067,12 +2062,12 @@ inline void DepthwiseConvImpl(
|
|||||||
const uint8* input_data, const RuntimeShape& filter_shape,
|
const uint8* input_data, const RuntimeShape& filter_shape,
|
||||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||||
const int32* bias_data, const RuntimeShape& output_shape,
|
const int32* bias_data, const RuntimeShape& output_shape,
|
||||||
uint8* output_data, CpuBackendContext* cpu_backend_context,
|
uint8* output_data, const CpuFlags& cpu_flags, int thread_start,
|
||||||
int thread_start, int thread_end, int thread_dim) {
|
int thread_end, int thread_dim) {
|
||||||
return DepthwiseConvWithRounding<DepthwiseConvOutputRounding::kUpward>(
|
return DepthwiseConvWithRounding<DepthwiseConvOutputRounding::kUpward>(
|
||||||
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
|
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
|
||||||
bias_data, output_shape, output_data, cpu_backend_context, thread_start,
|
bias_data, output_shape, output_data, cpu_flags, thread_start, thread_end,
|
||||||
thread_end, thread_dim);
|
thread_dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DepthwiseConv(const DepthwiseParams& params,
|
void DepthwiseConv(const DepthwiseParams& params,
|
||||||
@ -2080,7 +2075,7 @@ void DepthwiseConv(const DepthwiseParams& params,
|
|||||||
const RuntimeShape& filter_shape, const uint8* filter_data,
|
const RuntimeShape& filter_shape, const uint8* filter_data,
|
||||||
const RuntimeShape& bias_shape, const int32* bias_data,
|
const RuntimeShape& bias_shape, const int32* bias_data,
|
||||||
const RuntimeShape& output_shape, uint8* output_data,
|
const RuntimeShape& output_shape, uint8* output_data,
|
||||||
CpuBackendContext* cpu_backend_context);
|
const CpuFlags& cpu_flags);
|
||||||
|
|
||||||
} // namespace optimized_ops
|
} // namespace optimized_ops
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "public/gemmlowp.h"
|
#include "public/gemmlowp.h"
|
||||||
#include "tensorflow/lite/kernels/internal/common.h"
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
|
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
|
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
|
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
|
||||||
@ -165,7 +166,7 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
|
|||||||
DepthwiseConvImpl(op_params, DimsToShape(input_dims), input_data,
|
DepthwiseConvImpl(op_params, DimsToShape(input_dims), input_data,
|
||||||
DimsToShape(filter_dims), filter_data,
|
DimsToShape(filter_dims), filter_data,
|
||||||
DimsToShape(bias_dims), bias_data, output_shape,
|
DimsToShape(bias_dims), bias_data, output_shape,
|
||||||
output_data, nullptr, /*thread_start=*/0,
|
output_data, CpuFlags(), /*thread_start=*/0,
|
||||||
/*thread_end=*/output_height, /*thread_dim=*/1);
|
/*thread_end=*/output_height, /*thread_dim=*/1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -598,7 +599,8 @@ inline void DepthwiseConv(
|
|||||||
const float* bias_data, const RuntimeShape& output_shape,
|
const float* bias_data, const RuntimeShape& output_shape,
|
||||||
float* output_data) {
|
float* output_data) {
|
||||||
DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
|
DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
|
||||||
bias_shape, bias_data, output_shape, output_data, nullptr,
|
bias_shape, bias_data, output_shape, output_data,
|
||||||
|
CpuFlags(),
|
||||||
/*thread_start=*/0,
|
/*thread_start=*/0,
|
||||||
/*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
|
/*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user