Permit runtime opt-in for caching of prepacked matrices in GEMMs.

PiperOrigin-RevId: 314832291
Change-Id: Id0362e33e4d65908a4899d54989658fcfbf7f04c
This commit is contained in:
T.J. Alumbaugh 2020-06-04 17:04:26 -07:00 committed by TensorFlower Gardener
parent 5515595afa
commit 2635f0495e
21 changed files with 132 additions and 71 deletions

View File

@ -35,6 +35,8 @@ class TfLiteInternalBackendContext {
// TfLite computation. // TfLite computation.
virtual void SetMaxNumThreads(int max_num_threads) = 0; virtual void SetMaxNumThreads(int max_num_threads) = 0;
// A context may internally cache prepacked versions of constant tensors for
// faster computation. This function will clear any caches on the context.
virtual void ClearCaches() = 0; virtual void ClearCaches() = 0;
}; };

View File

@ -10,24 +10,27 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
# Enables usage of ruy in TFLite kernels. # Enables usage of ruy exclusively as the GEMM backend in TFLite kernels.
# This will cause TFLite to build with ruy only, providing a smaller binary.
# WARNING: This build flag is experimental and subject to change. # WARNING: This build flag is experimental and subject to change.
config_setting( config_setting(
name = "tflite_with_ruy_explicit_true", name = "tflite_with_ruy_only_explicit_true",
define_values = {"tflite_with_ruy": "true"}, define_values = {"TFLITE_WITH_RUY_ONLY": "true"},
) )
# Disables usage of ruy in TFLite kernels. # Disables usage of ruy as the exclusive GEMM backend in TFLite kernels.
# TFLite will be built with ruy and other GEMM libraries. Ruy will not be
# the default GEMM option at runtime.
# WARNING: This build flag is experimental and subject to change. # WARNING: This build flag is experimental and subject to change.
config_setting( config_setting(
name = "tflite_with_ruy_explicit_false", name = "tflite_with_ruy_only_explicit_false",
define_values = {"tflite_with_ruy": "false"}, define_values = {"TFLITE_WITH_RUY_ONLY": "false"},
) )
###### Beginning of config_setting's to match aarch64 ###### ###### Beginning of config_setting's to match aarch64 ######
# #
# We need to identify the aarch64 instruction set to decide whether to enable # We need to identify the aarch64 instruction set to decide whether to enable
# tflite_with_ruy by default. This is surprisingly hard to do because select() # TFLITE_WITH_RUY_ONLY by default. This is surprisingly hard to do because select()
# can only consume config_setting's, these config_settings are not centralized, # can only consume config_setting's, these config_settings are not centralized,
# and the "cpu" value which they define are free-form strings and there is no # and the "cpu" value which they define are free-form strings and there is no
# standardization of the strings that we need to match for the aarch64 architecture. # standardization of the strings that we need to match for the aarch64 architecture.
@ -229,45 +232,45 @@ cc_test(
) )
cc_library( cc_library(
name = "tflite_with_ruy_enabled", name = "tflite_with_ruy_only_enabled",
build_for_embedded = True, build_for_embedded = True,
defines = ["TFLITE_WITH_RUY"], defines = ["TFLITE_WITH_RUY_ONLY"],
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
) )
cc_library( cc_library(
name = "tflite_with_ruy_and_caching_enabled", name = "tflite_with_ruy_only_and_caching_enabled",
defines = [ defines = [
"TFLITE_WITH_RUY", "TFLITE_WITH_RUY_ONLY",
"TFLITE_WITH_RUY_GEMV", "TFLITE_WITH_RUY_GEMV",
], ],
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
) )
cc_library( cc_library(
name = "tflite_with_ruy_default", name = "tflite_with_ruy_only_default",
build_for_embedded = True, build_for_embedded = True,
select_deps = { select_deps = {
":chromiumos_arm64": [":tflite_with_ruy_enabled"], ":chromiumos_arm64": [":tflite_with_ruy_only_enabled"],
":cpu_aarch64": [":tflite_with_ruy_enabled"], ":cpu_aarch64": [":tflite_with_ruy_only_enabled"],
":cpu_arm64": [":tflite_with_ruy_enabled"], ":cpu_arm64": [":tflite_with_ruy_only_enabled"],
":cpu_arm64e": [":tflite_with_ruy_enabled"], ":cpu_arm64e": [":tflite_with_ruy_only_enabled"],
":cpu_ios_arm64": [":tflite_with_ruy_enabled"], ":cpu_ios_arm64": [":tflite_with_ruy_only_enabled"],
":cpu_ios_arm64e": [":tflite_with_ruy_enabled"], ":cpu_ios_arm64e": [":tflite_with_ruy_only_enabled"],
":cpu_arm64_v8a": [":tflite_with_ruy_enabled"], ":cpu_arm64_v8a": [":tflite_with_ruy_only_enabled"],
"//tensorflow:android_arm": ["tflite_with_ruy_enabled"], "//tensorflow:android_arm": ["tflite_with_ruy_only_enabled"],
"//conditions:default": [], "//conditions:default": [],
}, },
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
) )
cc_library( cc_library(
name = "tflite_with_ruy", name = "tflite_with_ruy_only",
build_for_embedded = True, build_for_embedded = True,
select_deps = { select_deps = {
":tflite_with_ruy_explicit_true": [":tflite_with_ruy_enabled"], ":tflite_with_ruy_only_explicit_true": [":tflite_with_ruy_only_enabled"],
":tflite_with_ruy_explicit_false": [], ":tflite_with_ruy_only_explicit_false": [],
"//conditions:default": [":tflite_with_ruy_default"], "//conditions:default": [":tflite_with_ruy_only_default"],
}, },
) )
@ -281,7 +284,7 @@ cc_library(
], ],
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
":tflite_with_ruy", ":tflite_with_ruy_only",
":op_macros", ":op_macros",
# For now this unconditionally depends on both ruy and gemmlowp. # For now this unconditionally depends on both ruy and gemmlowp.
# See the comment inside class CpuBackendContext on the # See the comment inside class CpuBackendContext on the
@ -300,11 +303,11 @@ cc_library(
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
":cpu_backend_context", ":cpu_backend_context",
":tflite_with_ruy", ":tflite_with_ruy_only",
"//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/kernels/internal:types",
# For now this unconditionally depends on both ruy and gemmlowp. # For now this unconditionally depends on both ruy and gemmlowp.
# We only need to depend on gemmlowp when tflite_with_ruy # We only need to depend on gemmlowp when tflite_with_ruy_only
# is false, but putting these dependencies in a select() seems to # is false, but putting these dependencies in a select() seems to
# defeat copybara's rewriting rules. # defeat copybara's rewriting rules.
"@ruy//ruy:context", "@ruy//ruy:context",
@ -338,19 +341,19 @@ cc_library(
], ],
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
":tflite_with_ruy", ":tflite_with_ruy_only",
"//tensorflow/lite/kernels/internal:common", "//tensorflow/lite/kernels/internal:common",
"//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/kernels/internal:types",
":cpu_backend_context", ":cpu_backend_context",
":cpu_backend_threadpool", ":cpu_backend_threadpool",
# Depend on ruy regardless of `tflite_with_ruy`. See the comment in # Depend on ruy regardless of `tflite_with_ruy_only`. See the comment in
# cpu_backend_gemm.h about why ruy is the generic path. # cpu_backend_gemm.h about why ruy is the generic path.
"@ruy//ruy", "@ruy//ruy",
"@ruy//ruy:matrix", "@ruy//ruy:matrix",
"@ruy//ruy:path", "@ruy//ruy:path",
"@ruy//ruy/profiler:instrumentation", "@ruy//ruy/profiler:instrumentation",
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy # We only need to depend on gemmlowp and Eigen when tflite_with_ruy_only
# is false, but putting these dependencies in a select() seems to # is false, but putting these dependencies in a select() seems to
# defeat copybara's rewriting rules. # defeat copybara's rewriting rules.
"@gemmlowp", "@gemmlowp",
@ -580,7 +583,7 @@ cc_library(
], ],
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
deps = BUILTIN_KERNEL_DEPS + ["@farmhash_archive//:farmhash"] + [":tflite_with_ruy_and_caching_enabled"], deps = BUILTIN_KERNEL_DEPS + ["@farmhash_archive//:farmhash"] + [":tflite_with_ruy_only_and_caching_enabled"],
) )
cc_library( cc_library(

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/eigen_support.h" #include "tensorflow/lite/kernels/eigen_support.h"
// b/131835803 forces us to include multithreaded_conv.h before optimized_ops.h // b/131835803 forces us to include multithreaded_conv.h before optimized_ops.h
#ifndef TFLITE_WITH_RUY #ifndef TFLITE_WITH_RUY_ONLY
#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h" #include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
#endif #endif
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
@ -768,8 +768,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
break; break;
} }
case kMultithreadOptimized: { case kMultithreadOptimized: {
#ifdef TFLITE_WITH_RUY #ifdef TFLITE_WITH_RUY_ONLY
// See Register_CONV_2D: we should never be here when tflite_with_ruy // See Register_CONV_2D: we should never be here when TFLITE_WITH_RUY_ONLY
// was enabled. We #if out this code in order to get the corresponding // was enabled. We #if out this code in order to get the corresponding
// binary size benefits. // binary size benefits.
TFLITE_DCHECK(false); TFLITE_DCHECK(false);
@ -1054,8 +1054,8 @@ TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT() {
TfLiteRegistration* Register_CONV_2D() { TfLiteRegistration* Register_CONV_2D() {
#if defined TFLITE_USE_APPLE_ACCELERATE_FOR_CONV #if defined TFLITE_USE_APPLE_ACCELERATE_FOR_CONV
return Register_CONVOLUTION_CBLAS_OPT(); return Register_CONVOLUTION_CBLAS_OPT();
#elif defined TFLITE_WITH_RUY #elif defined TFLITE_WITH_RUY_ONLY
// tflite_with_ruy optimizes the generic kernel type. // TFLITE_WITH_RUY_ONLY optimizes the generic kernel type.
return Register_CONVOLUTION_GENERIC_OPT(); return Register_CONVOLUTION_GENERIC_OPT();
#else #else
return Register_CONVOLUTION_MULTITHREADED_OPT(); return Register_CONVOLUTION_MULTITHREADED_OPT();
@ -1066,8 +1066,8 @@ TfLiteRegistration* Register_CONV_2D() {
// models only need the UINT8 type. TFLite's op registration mechanism doesn't // models only need the UINT8 type. TFLite's op registration mechanism doesn't
// yet allow for more nuanced registration mechanisms. // yet allow for more nuanced registration mechanisms.
TfLiteRegistration* Register_CONV_2D_UINT8() { TfLiteRegistration* Register_CONV_2D_UINT8() {
#if defined TFLITE_WITH_RUY #if defined TFLITE_WITH_RUY_ONLY
// tflite_with_ruy optimizes the generic kernel type. // TFLITE_WITH_RUY_ONLY optimizes the generic kernel type.
return Register_CONVOLUTION_GENERIC_OPT_UINT8(); return Register_CONVOLUTION_GENERIC_OPT_UINT8();
#else #else
return Register_CONV_2D(); return Register_CONV_2D();

View File

@ -141,7 +141,7 @@ class ConvolutionOpModel : public BaseConvolutionOpModel<float> {
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({ const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
{"Reference", ops::builtin::Register_CONVOLUTION_REF()}, {"Reference", ops::builtin::Register_CONVOLUTION_REF()},
{"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()}, {"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()},
#ifndef TFLITE_WITH_RUY #ifndef TFLITE_WITH_RUY_ONLY
{"MultithreadedOptimized", {"MultithreadedOptimized",
ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()}, ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()},
#endif #endif

View File

@ -55,6 +55,12 @@ CpuBackendContext::CpuBackendContext()
ruy_context_(new ruy::Context), ruy_context_(new ruy::Context),
gemmlowp_context_(new gemmlowp::GemmContext) { gemmlowp_context_(new gemmlowp::GemmContext) {
SetMaxNumThreads(kDefaultNumThreadpoolThreads); SetMaxNumThreads(kDefaultNumThreadpoolThreads);
// TODO(b/148289189) Remove when clients have transitioned to runtime flag.
#ifdef TFLITE_WITH_RUY_GEMV
SetUseCaching(true);
#else
SetUseCaching(false);
#endif
} }
CpuBackendContext::~CpuBackendContext() {} CpuBackendContext::~CpuBackendContext() {}
@ -67,4 +73,6 @@ void CpuBackendContext::SetMaxNumThreads(int max_num_threads) {
gemmlowp_context_->set_max_num_threads(target_num_threads); gemmlowp_context_->set_max_num_threads(target_num_threads);
} }
void CpuBackendContext::SetUseCaching(bool flag) { use_caching_ = flag; }
} // namespace tflite } // namespace tflite

View File

@ -43,6 +43,10 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
int max_num_threads() const { return max_num_threads_; } int max_num_threads() const { return max_num_threads_; }
void SetUseCaching(bool flag);
bool use_caching() const { return use_caching_; }
void ClearCaches() override { ruy_context_->ClearPrepackedCache(); } void ClearCaches() override { ruy_context_->ClearPrepackedCache(); }
private: private:
@ -51,7 +55,7 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
// (see :cpu_backend_gemm), for now a CpuBackendContext always // (see :cpu_backend_gemm), for now a CpuBackendContext always
// stores both a gemmlowp context and a ruy context. // stores both a gemmlowp context and a ruy context.
// TODO(b/131416458): Once call sites all go through abstractions, // TODO(b/131416458): Once call sites all go through abstractions,
// elide what can be elided based on TFLITE_WITH_RUY. // elide what can be elided based on TFLITE_WITH_RUY_ONLY.
const std::unique_ptr<ruy::Context> ruy_context_; const std::unique_ptr<ruy::Context> ruy_context_;
const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_; const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
@ -65,6 +69,12 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
// This value also gets propagated to back-ends, where it plays the same // This value also gets propagated to back-ends, where it plays the same
// information-only role. // information-only role.
int max_num_threads_; int max_num_threads_;
// For matrix muliplications with constants parameters (i.e. weights), we can
// sometimes provide speedups by caching the "prepacked" data, for some
// additional memory cost. This flag permits the user to route all
// CpuBackendGem operations to a library that permits such an optimization
// (currently the Ruy library only).
bool use_caching_;
CpuBackendContext(const CpuBackendContext&) = delete; CpuBackendContext(const CpuBackendContext&) = delete;
}; };

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
#ifndef TFLITE_WITH_RUY #ifndef TFLITE_WITH_RUY_ONLY
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
#endif #endif
@ -41,7 +41,7 @@ template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
DstScalar, quantization_flavor> {}; DstScalar, quantization_flavor> {};
#ifndef TFLITE_WITH_RUY #ifndef TFLITE_WITH_RUY_ONLY
/* Specializations using gemmlowp */ /* Specializations using gemmlowp */
@ -81,7 +81,7 @@ template <>
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint> struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
: detail::GemmImplUsingEigen {}; : detail::GemmImplUsingEigen {};
#endif // not TFLITE_WITH_RUY #endif // not TFLITE_WITH_RUY_ONLY
/* Public entry point */ /* Public entry point */
@ -94,12 +94,17 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
CpuBackendContext* context) { CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm"); ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
ValidateParams(lhs_params, rhs_params, dst_params, params); ValidateParams(lhs_params, rhs_params, dst_params, params);
bool do_custom_gemv = dst_params.cols == 1; if (context->use_caching()) {
#ifdef TFLITE_WITH_RUY_GEMV // Dispatch to backend that supports caching of prepacked weights
// Prefer a Ruy GEMM to Custom GEMV unless we are doing float math. // matrices.
// TODO(b/148692500): Add float GEMV kernels to Ruy. detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
do_custom_gemv = do_custom_gemv && std::is_floating_point<DstScalar>::value; quantization_flavor>::Run(lhs_params, lhs_data,
#endif rhs_params, rhs_data,
dst_params, dst_data,
params, context);
return;
}
const bool do_custom_gemv = (dst_params.cols == 1);
if (do_custom_gemv) { if (do_custom_gemv) {
// GEMV case: try a custom fast GEMV path. // GEMV case: try a custom fast GEMV path.
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data, if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,

View File

@ -586,10 +586,10 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
// The float specialization below is unconditionally faster than ruy // The float specialization below is unconditionally faster than ruy
// because ruy does not currently have any Gemv path. // because ruy does not currently have any Gemv path.
// But it is not unconditionally faster than Eigen, which is what is used // But it is not unconditionally faster than Eigen, which is what is used
// unless TFLITE_WITH_RUY is defined. Indeed, Eigen has decently efficient // unless TFLITE_WITH_RUY_ONLY is defined. Indeed, Eigen has decently efficient
// Gemv paths, and they may use AVX instructions, while the present // Gemv paths, and they may use AVX instructions, while the present
// NEON intrinsics code maps at best to SSE4 on x86. // NEON intrinsics code maps at best to SSE4 on x86.
#ifdef TFLITE_WITH_RUY #ifdef TFLITE_WITH_RUY_ONLY
// We want to use fused multiply-add when it's available (that is, on A64 // We want to use fused multiply-add when it's available (that is, on A64
// unconditionally and on A32 with VFPv4) because it's often faster, and // unconditionally and on A32 with VFPv4) because it's often faster, and
@ -773,7 +773,7 @@ struct CustomGemvImpl<float, float, float, float,
} }
}; };
#endif // TFLITE_WITH_RUY #endif // TFLITE_WITH_RUY_ONLY
#endif // USE_NEON #endif // USE_NEON

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TFLITE_WITH_RUY #ifndef TFLITE_WITH_RUY_ONLY
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
@ -78,4 +78,4 @@ void GemmImplUsingEigen::Run(
} // namespace cpu_backend_gemm } // namespace cpu_backend_gemm
} // namespace tflite } // namespace tflite
#endif // not TFLITE_WITH_RUY #endif // not TFLITE_WITH_RUY_ONLY

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_ #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_ #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_
#ifndef TFLITE_WITH_RUY #ifndef TFLITE_WITH_RUY_ONLY
#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
@ -37,6 +37,6 @@ struct GemmImplUsingEigen {
} // namespace cpu_backend_gemm } // namespace cpu_backend_gemm
} // namespace tflite } // namespace tflite
#endif // not TFLITE_WITH_RUY #endif // not TFLITE_WITH_RUY_ONLY
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_ #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_ #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_ #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
#ifndef TFLITE_WITH_RUY #ifndef TFLITE_WITH_RUY_ONLY
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
@ -188,6 +188,6 @@ struct GemmImplUsingGemmlowp<LhsScalar, RhsScalar, AccumScalar, DstScalar,
} // namespace cpu_backend_gemm } // namespace cpu_backend_gemm
} // namespace tflite } // namespace tflite
#endif // not TFLITE_WITH_RUY #endif // not TFLITE_WITH_RUY_ONLY
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_ #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_

View File

@ -42,7 +42,7 @@ inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
template <typename Scalar, typename DataPointer> template <typename Scalar, typename DataPointer>
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr, void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
ruy::Matrix<Scalar>* dst) { ruy::Matrix<Scalar>* dst, bool use_caching = false) {
ruy::Order ruy_order = params.order == Order::kColMajor ruy::Order ruy_order = params.order == Order::kColMajor
? ruy::Order::kColMajor ? ruy::Order::kColMajor
: ruy::Order::kRowMajor; : ruy::Order::kRowMajor;
@ -52,9 +52,9 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
// It does care whether we assign to it a Scalar* or a const Scalar*. // It does care whether we assign to it a Scalar* or a const Scalar*.
dst->set_data(data_ptr); dst->set_data(data_ptr);
dst->set_zero_point(params.zero_point); dst->set_zero_point(params.zero_point);
#ifdef TFLITE_WITH_RUY_GEMV if (use_caching) {
dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy)); dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
#endif }
} }
template <typename GemmParamsType, typename RuySpecType> template <typename GemmParamsType, typename RuySpecType>
@ -88,8 +88,8 @@ struct GemmImplUsingRuy {
ruy::Matrix<LhsScalar> ruy_lhs; ruy::Matrix<LhsScalar> ruy_lhs;
ruy::Matrix<RhsScalar> ruy_rhs; ruy::Matrix<RhsScalar> ruy_rhs;
ruy::Matrix<DstScalar> ruy_dst; ruy::Matrix<DstScalar> ruy_dst;
MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs); MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs, context->use_caching());
MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs); MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs, context->use_caching());
MakeRuyMatrix(dst_params, dst_data, &ruy_dst); MakeRuyMatrix(dst_params, dst_data, &ruy_dst);
ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params; ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params;

View File

@ -363,7 +363,8 @@ void TestSomeGemm(int rows, int depth, int cols,
CpuBackendContext cpu_backend_context; CpuBackendContext cpu_backend_context;
std::default_random_engine random_engine; std::default_random_engine random_engine;
cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8)); cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
bool use_caching = static_cast<bool>(random_engine() % 2);
cpu_backend_context.SetUseCaching(use_caching);
const bool use_golden = !golden.empty(); const bool use_golden = !golden.empty();
std::vector<LhsScalar> lhs_data; std::vector<LhsScalar> lhs_data;

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/compatibility.h"
#ifdef TFLITE_WITH_RUY #ifdef TFLITE_WITH_RUY_ONLY
#include "ruy/context.h" // from @ruy #include "ruy/context.h" // from @ruy
#include "ruy/thread_pool.h" // from @ruy #include "ruy/thread_pool.h" // from @ruy
#else #else
@ -29,7 +29,7 @@ limitations under the License.
namespace tflite { namespace tflite {
namespace cpu_backend_threadpool { namespace cpu_backend_threadpool {
#ifdef TFLITE_WITH_RUY #ifdef TFLITE_WITH_RUY_ONLY
using Task = ruy::Task; using Task = ruy::Task;
@ -41,7 +41,7 @@ void Execute(int tasks_count, TaskType* tasks,
tasks_count, tasks); tasks_count, tasks);
} }
#else // not TFLITE_WITH_RUY #else // not TFLITE_WITH_RUY_ONLY
using Task = gemmlowp::Task; using Task = gemmlowp::Task;

View File

@ -132,7 +132,7 @@ inline void DepthwiseConv(const DepthwiseParams& params,
int thread_count = HowManyConvThreads(output_shape, filter_shape); int thread_count = HowManyConvThreads(output_shape, filter_shape);
const int max_threads = cpu_backend_context->max_num_threads(); const int max_threads = cpu_backend_context->max_num_threads();
thread_count = std::max(1, std::min(thread_count, max_threads)); thread_count = std::max(1, std::min(thread_count, max_threads));
#ifndef TFLITE_WITH_RUY #ifndef TFLITE_WITH_RUY_ONLY
// Cap the number of threads to 2 for float path to avoid regression in // Cap the number of threads to 2 for float path to avoid regression in
// performance (b/132294857). // performance (b/132294857).
if (std::is_floating_point<T>::value) { if (std::is_floating_point<T>::value) {

View File

@ -144,13 +144,13 @@ cc_library(
"//tensorflow/lite:string_util", "//tensorflow/lite:string_util",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:cpu_backend_context",
"//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/profiling:platform_profiler",
"//tensorflow/lite/profiling:profile_summary_formatter", "//tensorflow/lite/profiling:profile_summary_formatter",
"//tensorflow/lite/profiling:profiler", "//tensorflow/lite/profiling:profiler",
"//tensorflow/lite/tools:logging", "//tensorflow/lite/tools:logging",
"//tensorflow/lite/tools/delegates:delegate_provider_hdr", "//tensorflow/lite/tools/delegates:delegate_provider_hdr",
"//tensorflow/lite/tools/delegates:tflite_execution_providers", "//tensorflow/lite/tools/delegates:tflite_execution_providers",
"//tensorflow/lite/tools/evaluation:utils",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@ruy//ruy/profiler", "@ruy//ruy/profiler",

View File

@ -34,6 +34,7 @@ BenchmarkParams BenchmarkModel::DefaultParams() {
params.AddParam("max_secs", BenchmarkParam::Create<float>(150.0f)); params.AddParam("max_secs", BenchmarkParam::Create<float>(150.0f));
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f)); params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1)); params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
params.AddParam("use_caching", BenchmarkParam::Create<bool>(false));
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>("")); params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>("")); params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1)); params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
@ -82,6 +83,11 @@ std::vector<Flag> BenchmarkModel::GetFlags() {
"the end of the run but will not start the next run."), "the end of the run but will not start the next run."),
CreateFlag<float>("run_delay", &params_, "delay between runs in seconds"), CreateFlag<float>("run_delay", &params_, "delay between runs in seconds"),
CreateFlag<int32_t>("num_threads", &params_, "number of threads"), CreateFlag<int32_t>("num_threads", &params_, "number of threads"),
CreateFlag<bool>(
"use_caching", &params_,
"Enable caching of prepacked weights matrices in matrix "
"multiplication routines. Currently implies the use of the Ruy "
"library."),
CreateFlag<std::string>("benchmark_name", &params_, "benchmark name"), CreateFlag<std::string>("benchmark_name", &params_, "benchmark name"),
CreateFlag<std::string>("output_prefix", &params_, CreateFlag<std::string>("output_prefix", &params_,
"benchmark output prefix"), "benchmark output prefix"),
@ -108,6 +114,8 @@ void BenchmarkModel::LogParams() {
<< params_.Get<float>("run_delay") << "]"; << params_.Get<float>("run_delay") << "]";
TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads") TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
<< "]"; << "]";
TFLITE_LOG(INFO) << "Use caching: [" << params_.Get<bool>("use_caching")
<< "]";
TFLITE_LOG(INFO) << "Benchmark name: [" TFLITE_LOG(INFO) << "Benchmark name: ["
<< params_.Get<std::string>("benchmark_name") << "]"; << params_.Get<std::string>("benchmark_name") << "]";
TFLITE_LOG(INFO) << "Output prefix: [" TFLITE_LOG(INFO) << "Output prefix: ["

View File

@ -53,6 +53,7 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
params.AddParam("max_secs", BenchmarkParam::Create<float>(max_secs)); params.AddParam("max_secs", BenchmarkParam::Create<float>(max_secs));
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f)); params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1)); params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
params.AddParam("use_caching", BenchmarkParam::Create<bool>(false));
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>("")); params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>("")); params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1)); params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
@ -397,6 +398,14 @@ TEST(BenchmarkTest, RunWithWrongFlags) {
EXPECT_EQ(kTfLiteError, status); EXPECT_EQ(kTfLiteError, status);
} }
TEST(BenchmarkTest, RunWithUseCaching) {
ASSERT_THAT(g_fp32_model_path, testing::NotNull());
TestBenchmark benchmark(CreateFp32Params());
ScopedCommandlineArgs scoped_argv({"--use_caching=false"});
auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv());
EXPECT_EQ(kTfLiteOk, status);
}
class MaxDurationWorksTestListener : public BenchmarkListener { class MaxDurationWorksTestListener : public BenchmarkListener {
void OnBenchmarkEnd(const BenchmarkResults& results) override { void OnBenchmarkEnd(const BenchmarkResults& results) override {
const int64_t num_actual_runs = results.inference_time_us().count(); const int64_t num_actual_runs = results.inference_time_us().count();

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "ruy/profiler/profiler.h" // from @ruy #include "ruy/profiler/profiler.h" // from @ruy
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h" #include "tensorflow/lite/op_resolver.h"
@ -600,11 +601,24 @@ TfLiteStatus BenchmarkTfLiteModel::ResetInputsAndOutputs() {
TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() { TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() {
auto resolver = GetOpResolver(); auto resolver = GetOpResolver();
const int32_t num_threads = params_.Get<int32_t>("num_threads"); const int32_t num_threads = params_.Get<int32_t>("num_threads");
const bool use_caching = params_.Get<bool>("use_caching");
tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads); tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads);
if (!interpreter_) { if (!interpreter_) {
TFLITE_LOG(ERROR) << "Failed to initialize the interpreter"; TFLITE_LOG(ERROR) << "Failed to initialize the interpreter";
return kTfLiteError; return kTfLiteError;
} }
// Manually enable caching behavior in TF Lite interpreter.
if (use_caching) {
external_context_.reset(new tflite::ExternalCpuBackendContext());
std::unique_ptr<tflite::CpuBackendContext> cpu_backend_context(
new tflite::CpuBackendContext());
cpu_backend_context->SetUseCaching(true);
external_context_->set_internal_backend_context(
std::move(cpu_backend_context));
interpreter_->SetExternalContext(kTfLiteCpuBackendContext,
external_context_.get());
}
return kTfLiteOk; return kTfLiteOk;
} }

View File

@ -85,6 +85,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
std::unique_ptr<tflite::FlatBufferModel> model_; std::unique_ptr<tflite::FlatBufferModel> model_;
std::unique_ptr<tflite::Interpreter> interpreter_; std::unique_ptr<tflite::Interpreter> interpreter_;
std::unique_ptr<tflite::ExternalCpuBackendContext> external_context_;
private: private:
// Implement type erasure with unique_ptr with custom deleter. // Implement type erasure with unique_ptr with custom deleter.

View File

@ -187,7 +187,7 @@ ifeq ($(TARGET_ARCH),aarch64)
BUILD_WITH_RUY=true BUILD_WITH_RUY=true
endif endif
ifeq ($(BUILD_WITH_RUY),true) ifeq ($(BUILD_WITH_RUY),true)
CXXFLAGS += -DTFLITE_WITH_RUY CXXFLAGS += -DTFLITE_WITH_RUY_ONLY
endif endif
BUILD_WITH_RUY_PROFILER ?= false BUILD_WITH_RUY_PROFILER ?= false