Add function to check for deprecated paths via weak symbols.

As a first step, define the function and link with CpuBackendContext.

PiperOrigin-RevId: 347446604
Change-Id: Ib7fb6820daa18194dd54081d233c8ae960bda5f1
This commit is contained in:
T.J. Alumbaugh 2020-12-14 12:42:42 -08:00 committed by TensorFlower Gardener
parent 47defbfd18
commit cae2f149e1
7 changed files with 87 additions and 38 deletions

View File

@ -320,6 +320,18 @@ cc_library(
}),
)
# Provide a library for clients to link to if they need to stay on deprecated
# arithmetic backends. Include as a dependency of cpu_backend_gemm to start.
# TODO(b/168923364): Move to dependent targets.
cc_library(
name = "deprecated_backends",
srcs = [
"deprecated_backends.cc",
],
compatible_with = get_compatible_with_portable(),
alwayslink = 1,
)
cc_library(
name = "cpu_backend_context",
srcs = [
@ -337,6 +349,7 @@ cc_library(
"//conditions:default": ["-DTFLITE_HAVE_CPUINFO"],
}),
deps = [
":deprecated_backends", # TODO(b/168923364): Move to dependent targets.
":tflite_with_ruy",
":op_macros",
# For now this unconditionally depends on both ruy and gemmlowp.
@ -345,6 +358,7 @@ cc_library(
"@ruy//ruy:context",
"@gemmlowp",
"//tensorflow/lite/c:common",
"//tensorflow/lite:macros",
"//tensorflow/lite:external_cpu_backend_context",
"//tensorflow/lite/kernels/internal:compatibility",
] + select({

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "public/gemmlowp.h"
#include "ruy/context.h" // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
@ -35,7 +36,13 @@ const int kDefaultNumThreadpoolThreads = 1;
namespace tflite {
#ifdef TFLITE_HAVE_CPUINFO
// Use weak symbols if possible to dispatch to deprecated paths.
#if TFLITE_HAS_ATTRIBUTE_WEAK && !defined(__APPLE__)
extern TFLITE_ATTRIBUTE_WEAK bool UseGemmlowpOnX86();
#endif // defined(TFLITE_HAS_ATTRIBUTE_WEAK) && !(__APPLE__)
// TODO(b/138922878) Enable when Ruy builds on Apple.
#if defined(TFLITE_HAVE_CPUINFO) && !defined(__APPLE__)
CpuBackendContext::CpuInfo::~CpuInfo() {
if (init_status_ == InitStatus::kInitialized) {
cpuinfo_deinitialize();
@ -144,4 +151,15 @@ bool CpuBackendContext::HasAvxOrAbove() {
return cpuinfo_.Avx() || cpuinfo_.Avx2Fma() || cpuinfo_.Avx512();
}
bool CpuBackendContext::PreferGemmlowpOnX86() {
bool use_gemmlowp_on_x86 = false;
#if defined(TFLITE_X86_PLATFORM) && TFLITE_HAS_ATTRIBUTE_WEAK && \
!defined(__APPLE__)
if (::tflite::UseGemmlowpOnX86 != nullptr) {
use_gemmlowp_on_x86 = ::tflite::UseGemmlowpOnX86();
}
#endif // TFLITE_X86_PLATFORM && TFLITE_HAS_ATTRIBUTE_WEAK && !(__APPLE__)
return use_gemmlowp_on_x86 || !HasAvxOrAbove();
}
} // namespace tflite

View File

@ -16,6 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
#if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || \
defined(_M_X64))
#define TFLITE_X86_PLATFORM
#endif
#include <memory>
#include "public/gemmlowp.h"
@ -52,6 +57,10 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
bool HasAvxOrAbove();
// Gemmlowp on x86 is a deprecated path but some clients may still use
// this path based on link time dependencies.
bool PreferGemmlowpOnX86();
private:
// Copy the wrapper class for cpuinfo from Ruy.
class CpuInfo final {

View File

@ -50,14 +50,7 @@ namespace cpu_backend_gemm {
// ENABLED && (AVX
// or above available)
#if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || \
defined(_M_X64))
#define TFLITE_X86_PLATFORM
#endif
// TODO(b/168923364) Set TFLITE_X86_RUY_ENABLED default 'on' when ready.
#if defined(TFLITE_X86_PLATFORM) && defined(TFLITE_X86_RUY_ENABLED)
#if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM)
/* GEMM dispatch implementation for x86.
*/
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
@ -72,12 +65,10 @@ template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
DstScalar, quantization_flavor> {};
#endif
#if !defined(TFLITE_WITH_RUY) && !defined(TFLITE_X86_RUY_ENABLED)
#if !defined(TFLITE_WITH_RUY)
/* Specializations using gemmlowp */
template <typename SrcScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
@ -114,7 +105,9 @@ template <>
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
: detail::GemmImplUsingEigen {};
#endif // not TFLITE_WITH_RUY && not TFLITE_X86_RUY_ENABLED
#endif // not TFLITE_WITH_RUY
#endif // not TFLITE_WITH_RUY and TFLITE_X86_PLATFORM
/* Public entry point */

View File

@ -297,7 +297,7 @@ void PerformGemmThenCompareResultsThenAgainWithClamping(
// done so far. Until that is done, the best that we can do is to search for
// a good exponent value by trial-and-error. This is expensive, as each try
// requires computing a whole GEMM. This is thus probably a major contribution
// to the overall latency of this tesat. To partially mitigate that,
// to the overall latency of this test. To partially mitigate that,
// we use a bisection to reduce the required number of tries.
//
// This function is recursive. The bisect_min and bisect_max arguments

View File

@ -41,25 +41,27 @@ struct GemmImplX86 {
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
CpuBackendContext* context) {
// Run-time dispatch to Ruy for platforms with AVX or above.
if (context->HasAvxOrAbove()) {
detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>::Run(lhs_params, lhs_data,
rhs_params, rhs_data,
dst_params, dst_data,
params, context);
} else {
// Dispatch to gemmlowp for SSE.
// TODO(b/168923364) Ruy is preferred on x86, but check if the deprecated
// path is enabled.
if (context->PreferGemmlowpOnX86()) {
// Dispatch to gemmlowp.
detail::GemmImplUsingGemmlowp<
LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params, context);
return;
}
// Run-time dispatch to Ruy for platforms with AVX or above.
detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>::Run(lhs_params, lhs_data,
rhs_params, rhs_data,
dst_params, dst_data,
params, context);
}
};
// For float, again prefer Ruy in all cases, but defer to eigen if no flavor of
// AVX is present.
// For float, defer to eigen for now.
template <>
struct GemmImplX86<float, float, float, float,
QuantizationFlavor::kFloatingPoint> {
@ -69,19 +71,8 @@ struct GemmImplX86<float, float, float, float,
const GemmParams<float, float,
QuantizationFlavor::kFloatingPoint>& params,
CpuBackendContext* context) {
// Run-time dispatch to Ruy for platforms with AVX or above.
if (context->HasAvxOrAbove()) {
detail::GemmImplUsingRuy<
float, float, float, float,
QuantizationFlavor::kFloatingPoint>::Run(lhs_params, lhs_data,
rhs_params, rhs_data,
dst_params, dst_data, params,
context);
} else {
// Dispatch to gemmlowp for SSE.
GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params, context);
}
GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params, context);
}
};

View File

@ -0,0 +1,24 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
namespace tflite {
// Include this target as a dependency in order to define this function for
// CpuBackendContext. Its use is to control execution of deprecated paths
// by providing a symbol definition for otherwise "weak" symbol
// declarations in CpuBackendContext.
extern bool UseGemmlowpOnX86() { return true; }
} // namespace tflite