Add function to check for deprecated paths via weak symbols.
As a first step, define the function and link with CpuBackendContext. PiperOrigin-RevId: 347446604 Change-Id: Ib7fb6820daa18194dd54081d233c8ae960bda5f1
This commit is contained in:
parent
47defbfd18
commit
cae2f149e1
@ -320,6 +320,18 @@ cc_library(
|
||||
}),
|
||||
)
|
||||
|
||||
# Provide a library for clients to link to if they need to stay on deprecated
|
||||
# arithmetic backends. Include as a dependency of cpu_backend_gemm to start.
|
||||
# TODO(b/168923364): Move to dependent targets.
|
||||
cc_library(
|
||||
name = "deprecated_backends",
|
||||
srcs = [
|
||||
"deprecated_backends.cc",
|
||||
],
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_backend_context",
|
||||
srcs = [
|
||||
@ -337,6 +349,7 @@ cc_library(
|
||||
"//conditions:default": ["-DTFLITE_HAVE_CPUINFO"],
|
||||
}),
|
||||
deps = [
|
||||
":deprecated_backends", # TODO(b/168923364): Move to dependent targets.
|
||||
":tflite_with_ruy",
|
||||
":op_macros",
|
||||
# For now this unconditionally depends on both ruy and gemmlowp.
|
||||
@ -345,6 +358,7 @@ cc_library(
|
||||
"@ruy//ruy:context",
|
||||
"@gemmlowp",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite:macros",
|
||||
"//tensorflow/lite:external_cpu_backend_context",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
] + select({
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "public/gemmlowp.h"
|
||||
#include "ruy/context.h" // from @ruy
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/macros.h"
|
||||
#include "tensorflow/lite/external_cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
@ -35,7 +36,13 @@ const int kDefaultNumThreadpoolThreads = 1;
|
||||
|
||||
namespace tflite {
|
||||
|
||||
#ifdef TFLITE_HAVE_CPUINFO
|
||||
// Use weak symbols if possible to dispatch to deprecated paths.
|
||||
#if TFLITE_HAS_ATTRIBUTE_WEAK && !defined(__APPLE__)
|
||||
extern TFLITE_ATTRIBUTE_WEAK bool UseGemmlowpOnX86();
|
||||
#endif // defined(TFLITE_HAS_ATTRIBUTE_WEAK) && !(__APPLE__)
|
||||
|
||||
// TODO(b/138922878) Enable when Ruy builds on Apple.
|
||||
#if defined(TFLITE_HAVE_CPUINFO) && !defined(__APPLE__)
|
||||
CpuBackendContext::CpuInfo::~CpuInfo() {
|
||||
if (init_status_ == InitStatus::kInitialized) {
|
||||
cpuinfo_deinitialize();
|
||||
@ -144,4 +151,15 @@ bool CpuBackendContext::HasAvxOrAbove() {
|
||||
return cpuinfo_.Avx() || cpuinfo_.Avx2Fma() || cpuinfo_.Avx512();
|
||||
}
|
||||
|
||||
bool CpuBackendContext::PreferGemmlowpOnX86() {
|
||||
bool use_gemmlowp_on_x86 = false;
|
||||
#if defined(TFLITE_X86_PLATFORM) && TFLITE_HAS_ATTRIBUTE_WEAK && \
|
||||
!defined(__APPLE__)
|
||||
if (::tflite::UseGemmlowpOnX86 != nullptr) {
|
||||
use_gemmlowp_on_x86 = ::tflite::UseGemmlowpOnX86();
|
||||
}
|
||||
#endif // TFLITE_X86_PLATFORM && TFLITE_HAS_ATTRIBUTE_WEAK && !(__APPLE__)
|
||||
return use_gemmlowp_on_x86 || !HasAvxOrAbove();
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -16,6 +16,11 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
|
||||
|
||||
#if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || \
|
||||
defined(_M_X64))
|
||||
#define TFLITE_X86_PLATFORM
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "public/gemmlowp.h"
|
||||
@ -52,6 +57,10 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
|
||||
|
||||
bool HasAvxOrAbove();
|
||||
|
||||
// Gemmlowp on x86 is a deprecated path but some clients may still use
|
||||
// this path based on link time dependencies.
|
||||
bool PreferGemmlowpOnX86();
|
||||
|
||||
private:
|
||||
// Copy the wrapper class for cpuinfo from Ruy.
|
||||
class CpuInfo final {
|
||||
|
@ -50,14 +50,7 @@ namespace cpu_backend_gemm {
|
||||
// ENABLED && (AVX
|
||||
// or above available)
|
||||
|
||||
|
||||
#if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || \
|
||||
defined(_M_X64))
|
||||
#define TFLITE_X86_PLATFORM
|
||||
#endif
|
||||
|
||||
// TODO(b/168923364) Set TFLITE_X86_RUY_ENABLED default 'on' when ready.
|
||||
#if defined(TFLITE_X86_PLATFORM) && defined(TFLITE_X86_RUY_ENABLED)
|
||||
#if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM)
|
||||
/* GEMM dispatch implementation for x86.
|
||||
*/
|
||||
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
|
||||
@ -72,12 +65,10 @@ template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
|
||||
typename DstScalar, QuantizationFlavor quantization_flavor>
|
||||
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
|
||||
DstScalar, quantization_flavor> {};
|
||||
#endif
|
||||
|
||||
#if !defined(TFLITE_WITH_RUY) && !defined(TFLITE_X86_RUY_ENABLED)
|
||||
#if !defined(TFLITE_WITH_RUY)
|
||||
|
||||
/* Specializations using gemmlowp */
|
||||
|
||||
template <typename SrcScalar, typename DstScalar,
|
||||
QuantizationFlavor quantization_flavor>
|
||||
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
|
||||
@ -114,7 +105,9 @@ template <>
|
||||
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
|
||||
: detail::GemmImplUsingEigen {};
|
||||
|
||||
#endif // not TFLITE_WITH_RUY && not TFLITE_X86_RUY_ENABLED
|
||||
#endif // not TFLITE_WITH_RUY
|
||||
|
||||
#endif // not TFLITE_WITH_RUY and TFLITE_X86_PLATFORM
|
||||
|
||||
/* Public entry point */
|
||||
|
||||
|
@ -297,7 +297,7 @@ void PerformGemmThenCompareResultsThenAgainWithClamping(
|
||||
// done so far. Until that is done, the best that we can do is to search for
|
||||
// a good exponent value by trial-and-error. This is expensive, as each try
|
||||
// requires computing a whole GEMM. This is thus probably a major contribution
|
||||
// to the overall latency of this tesat. To partially mitigate that,
|
||||
// to the overall latency of this test. To partially mitigate that,
|
||||
// we use a bisection to reduce the required number of tries.
|
||||
//
|
||||
// This function is recursive. The bisect_min and bisect_max arguments
|
||||
|
@ -41,25 +41,27 @@ struct GemmImplX86 {
|
||||
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
|
||||
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
|
||||
CpuBackendContext* context) {
|
||||
// Run-time dispatch to Ruy for platforms with AVX or above.
|
||||
if (context->HasAvxOrAbove()) {
|
||||
detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||
quantization_flavor>::Run(lhs_params, lhs_data,
|
||||
rhs_params, rhs_data,
|
||||
dst_params, dst_data,
|
||||
params, context);
|
||||
} else {
|
||||
// Dispatch to gemmlowp for SSE.
|
||||
// TODO(b/168923364) Ruy is preferred on x86, but check if the deprecated
|
||||
// path is enabled.
|
||||
if (context->PreferGemmlowpOnX86()) {
|
||||
// Dispatch to gemmlowp.
|
||||
detail::GemmImplUsingGemmlowp<
|
||||
LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||
dst_params, dst_data, params, context);
|
||||
|
||||
return;
|
||||
}
|
||||
// Run-time dispatch to Ruy for platforms with AVX or above.
|
||||
detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||
quantization_flavor>::Run(lhs_params, lhs_data,
|
||||
rhs_params, rhs_data,
|
||||
dst_params, dst_data,
|
||||
params, context);
|
||||
}
|
||||
};
|
||||
|
||||
// For float, again prefer Ruy in all cases, but defer to eigen if no flavor of
|
||||
// AVX is present.
|
||||
// For float, defer to eigen for now.
|
||||
template <>
|
||||
struct GemmImplX86<float, float, float, float,
|
||||
QuantizationFlavor::kFloatingPoint> {
|
||||
@ -69,19 +71,8 @@ struct GemmImplX86<float, float, float, float,
|
||||
const GemmParams<float, float,
|
||||
QuantizationFlavor::kFloatingPoint>& params,
|
||||
CpuBackendContext* context) {
|
||||
// Run-time dispatch to Ruy for platforms with AVX or above.
|
||||
if (context->HasAvxOrAbove()) {
|
||||
detail::GemmImplUsingRuy<
|
||||
float, float, float, float,
|
||||
QuantizationFlavor::kFloatingPoint>::Run(lhs_params, lhs_data,
|
||||
rhs_params, rhs_data,
|
||||
dst_params, dst_data, params,
|
||||
context);
|
||||
} else {
|
||||
// Dispatch to gemmlowp for SSE.
|
||||
GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||
dst_params, dst_data, params, context);
|
||||
}
|
||||
GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||
dst_params, dst_data, params, context);
|
||||
}
|
||||
};
|
||||
|
||||
|
24
tensorflow/lite/kernels/deprecated_backends.cc
Normal file
24
tensorflow/lite/kernels/deprecated_backends.cc
Normal file
@ -0,0 +1,24 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Include this target as a dependency in order to define this function for
|
||||
// CpuBackendContext. Its use is to control execution of deprecated paths
|
||||
// by providing a symbol definition for otherwise "weak" symbol
|
||||
// declarations in CpuBackendContext.
|
||||
extern bool UseGemmlowpOnX86() { return true; }
|
||||
|
||||
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user