Permit runtime opt-in for caching of prepacked matrices in GEMMs.
PiperOrigin-RevId: 314832291 Change-Id: Id0362e33e4d65908a4899d54989658fcfbf7f04c
This commit is contained in:
parent
5515595afa
commit
2635f0495e
|
@ -35,6 +35,8 @@ class TfLiteInternalBackendContext {
|
||||||
// TfLite computation.
|
// TfLite computation.
|
||||||
virtual void SetMaxNumThreads(int max_num_threads) = 0;
|
virtual void SetMaxNumThreads(int max_num_threads) = 0;
|
||||||
|
|
||||||
|
// A context may internally cache prepacked versions of constant tensors for
|
||||||
|
// faster computation. This function will clear any caches on the context.
|
||||||
virtual void ClearCaches() = 0;
|
virtual void ClearCaches() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -10,24 +10,27 @@ package(
|
||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enables usage of ruy in TFLite kernels.
|
# Enables usage of ruy exclusively as the GEMM backend in TFLite kernels.
|
||||||
|
# This will cause TFLite to build with ruy only, providing a smaller binary.
|
||||||
# WARNING: This build flag is experimental and subject to change.
|
# WARNING: This build flag is experimental and subject to change.
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "tflite_with_ruy_explicit_true",
|
name = "tflite_with_ruy_only_explicit_true",
|
||||||
define_values = {"tflite_with_ruy": "true"},
|
define_values = {"TFLITE_WITH_RUY_ONLY": "true"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Disables usage of ruy in TFLite kernels.
|
# Disables usage of ruy as the exclusive GEMM backend in TFLite kernels.
|
||||||
|
# TFLite will be built with ruy and other GEMM libraries. Ruy will not be
|
||||||
|
# the default GEMM option at runtime.
|
||||||
# WARNING: This build flag is experimental and subject to change.
|
# WARNING: This build flag is experimental and subject to change.
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "tflite_with_ruy_explicit_false",
|
name = "tflite_with_ruy_only_explicit_false",
|
||||||
define_values = {"tflite_with_ruy": "false"},
|
define_values = {"TFLITE_WITH_RUY_ONLY": "false"},
|
||||||
)
|
)
|
||||||
|
|
||||||
###### Beginning of config_setting's to match aarch64 ######
|
###### Beginning of config_setting's to match aarch64 ######
|
||||||
#
|
#
|
||||||
# We need to identify the aarch64 instruction set to decide whether to enable
|
# We need to identify the aarch64 instruction set to decide whether to enable
|
||||||
# tflite_with_ruy by default. This is surprisingly hard to do because select()
|
# TFLITE_WITH_RUY_ONLY by default. This is surprisingly hard to do because select()
|
||||||
# can only consume config_setting's, these config_settings are not centralized,
|
# can only consume config_setting's, these config_settings are not centralized,
|
||||||
# and the "cpu" value which they define are free-form strings and there is no
|
# and the "cpu" value which they define are free-form strings and there is no
|
||||||
# standardization of the strings that we need to match for the aarch64 architecture.
|
# standardization of the strings that we need to match for the aarch64 architecture.
|
||||||
|
@ -229,45 +232,45 @@ cc_test(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tflite_with_ruy_enabled",
|
name = "tflite_with_ruy_only_enabled",
|
||||||
build_for_embedded = True,
|
build_for_embedded = True,
|
||||||
defines = ["TFLITE_WITH_RUY"],
|
defines = ["TFLITE_WITH_RUY_ONLY"],
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tflite_with_ruy_and_caching_enabled",
|
name = "tflite_with_ruy_only_and_caching_enabled",
|
||||||
defines = [
|
defines = [
|
||||||
"TFLITE_WITH_RUY",
|
"TFLITE_WITH_RUY_ONLY",
|
||||||
"TFLITE_WITH_RUY_GEMV",
|
"TFLITE_WITH_RUY_GEMV",
|
||||||
],
|
],
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tflite_with_ruy_default",
|
name = "tflite_with_ruy_only_default",
|
||||||
build_for_embedded = True,
|
build_for_embedded = True,
|
||||||
select_deps = {
|
select_deps = {
|
||||||
":chromiumos_arm64": [":tflite_with_ruy_enabled"],
|
":chromiumos_arm64": [":tflite_with_ruy_only_enabled"],
|
||||||
":cpu_aarch64": [":tflite_with_ruy_enabled"],
|
":cpu_aarch64": [":tflite_with_ruy_only_enabled"],
|
||||||
":cpu_arm64": [":tflite_with_ruy_enabled"],
|
":cpu_arm64": [":tflite_with_ruy_only_enabled"],
|
||||||
":cpu_arm64e": [":tflite_with_ruy_enabled"],
|
":cpu_arm64e": [":tflite_with_ruy_only_enabled"],
|
||||||
":cpu_ios_arm64": [":tflite_with_ruy_enabled"],
|
":cpu_ios_arm64": [":tflite_with_ruy_only_enabled"],
|
||||||
":cpu_ios_arm64e": [":tflite_with_ruy_enabled"],
|
":cpu_ios_arm64e": [":tflite_with_ruy_only_enabled"],
|
||||||
":cpu_arm64_v8a": [":tflite_with_ruy_enabled"],
|
":cpu_arm64_v8a": [":tflite_with_ruy_only_enabled"],
|
||||||
"//tensorflow:android_arm": ["tflite_with_ruy_enabled"],
|
"//tensorflow:android_arm": ["tflite_with_ruy_only_enabled"],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
},
|
},
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tflite_with_ruy",
|
name = "tflite_with_ruy_only",
|
||||||
build_for_embedded = True,
|
build_for_embedded = True,
|
||||||
select_deps = {
|
select_deps = {
|
||||||
":tflite_with_ruy_explicit_true": [":tflite_with_ruy_enabled"],
|
":tflite_with_ruy_only_explicit_true": [":tflite_with_ruy_only_enabled"],
|
||||||
":tflite_with_ruy_explicit_false": [],
|
":tflite_with_ruy_only_explicit_false": [],
|
||||||
"//conditions:default": [":tflite_with_ruy_default"],
|
"//conditions:default": [":tflite_with_ruy_only_default"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -281,7 +284,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
":tflite_with_ruy",
|
":tflite_with_ruy_only",
|
||||||
":op_macros",
|
":op_macros",
|
||||||
# For now this unconditionally depends on both ruy and gemmlowp.
|
# For now this unconditionally depends on both ruy and gemmlowp.
|
||||||
# See the comment inside class CpuBackendContext on the
|
# See the comment inside class CpuBackendContext on the
|
||||||
|
@ -300,11 +303,11 @@ cc_library(
|
||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
":cpu_backend_context",
|
":cpu_backend_context",
|
||||||
":tflite_with_ruy",
|
":tflite_with_ruy_only",
|
||||||
"//tensorflow/lite/kernels/internal:compatibility",
|
"//tensorflow/lite/kernels/internal:compatibility",
|
||||||
"//tensorflow/lite/kernels/internal:types",
|
"//tensorflow/lite/kernels/internal:types",
|
||||||
# For now this unconditionally depends on both ruy and gemmlowp.
|
# For now this unconditionally depends on both ruy and gemmlowp.
|
||||||
# We only need to depend on gemmlowp when tflite_with_ruy
|
# We only need to depend on gemmlowp when tflite_with_ruy_only
|
||||||
# is false, but putting these dependencies in a select() seems to
|
# is false, but putting these dependencies in a select() seems to
|
||||||
# defeat copybara's rewriting rules.
|
# defeat copybara's rewriting rules.
|
||||||
"@ruy//ruy:context",
|
"@ruy//ruy:context",
|
||||||
|
@ -338,19 +341,19 @@ cc_library(
|
||||||
],
|
],
|
||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
":tflite_with_ruy",
|
":tflite_with_ruy_only",
|
||||||
"//tensorflow/lite/kernels/internal:common",
|
"//tensorflow/lite/kernels/internal:common",
|
||||||
"//tensorflow/lite/kernels/internal:compatibility",
|
"//tensorflow/lite/kernels/internal:compatibility",
|
||||||
"//tensorflow/lite/kernels/internal:types",
|
"//tensorflow/lite/kernels/internal:types",
|
||||||
":cpu_backend_context",
|
":cpu_backend_context",
|
||||||
":cpu_backend_threadpool",
|
":cpu_backend_threadpool",
|
||||||
# Depend on ruy regardless of `tflite_with_ruy`. See the comment in
|
# Depend on ruy regardless of `tflite_with_ruy_only`. See the comment in
|
||||||
# cpu_backend_gemm.h about why ruy is the generic path.
|
# cpu_backend_gemm.h about why ruy is the generic path.
|
||||||
"@ruy//ruy",
|
"@ruy//ruy",
|
||||||
"@ruy//ruy:matrix",
|
"@ruy//ruy:matrix",
|
||||||
"@ruy//ruy:path",
|
"@ruy//ruy:path",
|
||||||
"@ruy//ruy/profiler:instrumentation",
|
"@ruy//ruy/profiler:instrumentation",
|
||||||
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
|
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy_only
|
||||||
# is false, but putting these dependencies in a select() seems to
|
# is false, but putting these dependencies in a select() seems to
|
||||||
# defeat copybara's rewriting rules.
|
# defeat copybara's rewriting rules.
|
||||||
"@gemmlowp",
|
"@gemmlowp",
|
||||||
|
@ -580,7 +583,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
|
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = BUILTIN_KERNEL_DEPS + ["@farmhash_archive//:farmhash"] + [":tflite_with_ruy_and_caching_enabled"],
|
deps = BUILTIN_KERNEL_DEPS + ["@farmhash_archive//:farmhash"] + [":tflite_with_ruy_only_and_caching_enabled"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
#include "tensorflow/lite/kernels/eigen_support.h"
|
#include "tensorflow/lite/kernels/eigen_support.h"
|
||||||
// b/131835803 forces us to include multithreaded_conv.h before optimized_ops.h
|
// b/131835803 forces us to include multithreaded_conv.h before optimized_ops.h
|
||||||
#ifndef TFLITE_WITH_RUY
|
#ifndef TFLITE_WITH_RUY_ONLY
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
|
#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
|
||||||
#endif
|
#endif
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||||
|
@ -768,8 +768,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case kMultithreadOptimized: {
|
case kMultithreadOptimized: {
|
||||||
#ifdef TFLITE_WITH_RUY
|
#ifdef TFLITE_WITH_RUY_ONLY
|
||||||
// See Register_CONV_2D: we should never be here when tflite_with_ruy
|
// See Register_CONV_2D: we should never be here when TFLITE_WITH_RUY_ONLY
|
||||||
// was enabled. We #if out this code in order to get the corresponding
|
// was enabled. We #if out this code in order to get the corresponding
|
||||||
// binary size benefits.
|
// binary size benefits.
|
||||||
TFLITE_DCHECK(false);
|
TFLITE_DCHECK(false);
|
||||||
|
@ -1054,8 +1054,8 @@ TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT() {
|
||||||
TfLiteRegistration* Register_CONV_2D() {
|
TfLiteRegistration* Register_CONV_2D() {
|
||||||
#if defined TFLITE_USE_APPLE_ACCELERATE_FOR_CONV
|
#if defined TFLITE_USE_APPLE_ACCELERATE_FOR_CONV
|
||||||
return Register_CONVOLUTION_CBLAS_OPT();
|
return Register_CONVOLUTION_CBLAS_OPT();
|
||||||
#elif defined TFLITE_WITH_RUY
|
#elif defined TFLITE_WITH_RUY_ONLY
|
||||||
// tflite_with_ruy optimizes the generic kernel type.
|
// TFLITE_WITH_RUY_ONLY optimizes the generic kernel type.
|
||||||
return Register_CONVOLUTION_GENERIC_OPT();
|
return Register_CONVOLUTION_GENERIC_OPT();
|
||||||
#else
|
#else
|
||||||
return Register_CONVOLUTION_MULTITHREADED_OPT();
|
return Register_CONVOLUTION_MULTITHREADED_OPT();
|
||||||
|
@ -1066,8 +1066,8 @@ TfLiteRegistration* Register_CONV_2D() {
|
||||||
// models only need the UINT8 type. TFLite's op registration mechanism doesn't
|
// models only need the UINT8 type. TFLite's op registration mechanism doesn't
|
||||||
// yet allow for more nuanced registration mechanisms.
|
// yet allow for more nuanced registration mechanisms.
|
||||||
TfLiteRegistration* Register_CONV_2D_UINT8() {
|
TfLiteRegistration* Register_CONV_2D_UINT8() {
|
||||||
#if defined TFLITE_WITH_RUY
|
#if defined TFLITE_WITH_RUY_ONLY
|
||||||
// tflite_with_ruy optimizes the generic kernel type.
|
// TFLITE_WITH_RUY_ONLY optimizes the generic kernel type.
|
||||||
return Register_CONVOLUTION_GENERIC_OPT_UINT8();
|
return Register_CONVOLUTION_GENERIC_OPT_UINT8();
|
||||||
#else
|
#else
|
||||||
return Register_CONV_2D();
|
return Register_CONV_2D();
|
||||||
|
|
|
@ -141,7 +141,7 @@ class ConvolutionOpModel : public BaseConvolutionOpModel<float> {
|
||||||
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
|
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
|
||||||
{"Reference", ops::builtin::Register_CONVOLUTION_REF()},
|
{"Reference", ops::builtin::Register_CONVOLUTION_REF()},
|
||||||
{"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()},
|
{"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()},
|
||||||
#ifndef TFLITE_WITH_RUY
|
#ifndef TFLITE_WITH_RUY_ONLY
|
||||||
{"MultithreadedOptimized",
|
{"MultithreadedOptimized",
|
||||||
ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()},
|
ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()},
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -55,6 +55,12 @@ CpuBackendContext::CpuBackendContext()
|
||||||
ruy_context_(new ruy::Context),
|
ruy_context_(new ruy::Context),
|
||||||
gemmlowp_context_(new gemmlowp::GemmContext) {
|
gemmlowp_context_(new gemmlowp::GemmContext) {
|
||||||
SetMaxNumThreads(kDefaultNumThreadpoolThreads);
|
SetMaxNumThreads(kDefaultNumThreadpoolThreads);
|
||||||
|
// TODO(b/148289189) Remove when clients have transitioned to runtime flag.
|
||||||
|
#ifdef TFLITE_WITH_RUY_GEMV
|
||||||
|
SetUseCaching(true);
|
||||||
|
#else
|
||||||
|
SetUseCaching(false);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
CpuBackendContext::~CpuBackendContext() {}
|
CpuBackendContext::~CpuBackendContext() {}
|
||||||
|
@ -67,4 +73,6 @@ void CpuBackendContext::SetMaxNumThreads(int max_num_threads) {
|
||||||
gemmlowp_context_->set_max_num_threads(target_num_threads);
|
gemmlowp_context_->set_max_num_threads(target_num_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CpuBackendContext::SetUseCaching(bool flag) { use_caching_ = flag; }
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
|
@ -43,6 +43,10 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
|
||||||
|
|
||||||
int max_num_threads() const { return max_num_threads_; }
|
int max_num_threads() const { return max_num_threads_; }
|
||||||
|
|
||||||
|
void SetUseCaching(bool flag);
|
||||||
|
|
||||||
|
bool use_caching() const { return use_caching_; }
|
||||||
|
|
||||||
void ClearCaches() override { ruy_context_->ClearPrepackedCache(); }
|
void ClearCaches() override { ruy_context_->ClearPrepackedCache(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -51,7 +55,7 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
|
||||||
// (see :cpu_backend_gemm), for now a CpuBackendContext always
|
// (see :cpu_backend_gemm), for now a CpuBackendContext always
|
||||||
// stores both a gemmlowp context and a ruy context.
|
// stores both a gemmlowp context and a ruy context.
|
||||||
// TODO(b/131416458): Once call sites all go through abstractions,
|
// TODO(b/131416458): Once call sites all go through abstractions,
|
||||||
// elide what can be elided based on TFLITE_WITH_RUY.
|
// elide what can be elided based on TFLITE_WITH_RUY_ONLY.
|
||||||
const std::unique_ptr<ruy::Context> ruy_context_;
|
const std::unique_ptr<ruy::Context> ruy_context_;
|
||||||
const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
|
const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
|
||||||
|
|
||||||
|
@ -65,6 +69,12 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
|
||||||
// This value also gets propagated to back-ends, where it plays the same
|
// This value also gets propagated to back-ends, where it plays the same
|
||||||
// information-only role.
|
// information-only role.
|
||||||
int max_num_threads_;
|
int max_num_threads_;
|
||||||
|
// For matrix muliplications with constants parameters (i.e. weights), we can
|
||||||
|
// sometimes provide speedups by caching the "prepacked" data, for some
|
||||||
|
// additional memory cost. This flag permits the user to route all
|
||||||
|
// CpuBackendGem operations to a library that permits such an optimization
|
||||||
|
// (currently the Ruy library only).
|
||||||
|
bool use_caching_;
|
||||||
|
|
||||||
CpuBackendContext(const CpuBackendContext&) = delete;
|
CpuBackendContext(const CpuBackendContext&) = delete;
|
||||||
};
|
};
|
||||||
|
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
|
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
|
||||||
|
|
||||||
#ifndef TFLITE_WITH_RUY
|
#ifndef TFLITE_WITH_RUY_ONLY
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
|
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
|
#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -41,7 +41,7 @@ template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
|
||||||
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
|
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
|
||||||
DstScalar, quantization_flavor> {};
|
DstScalar, quantization_flavor> {};
|
||||||
|
|
||||||
#ifndef TFLITE_WITH_RUY
|
#ifndef TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
/* Specializations using gemmlowp */
|
/* Specializations using gemmlowp */
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ template <>
|
||||||
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
|
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
|
||||||
: detail::GemmImplUsingEigen {};
|
: detail::GemmImplUsingEigen {};
|
||||||
|
|
||||||
#endif // not TFLITE_WITH_RUY
|
#endif // not TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
/* Public entry point */
|
/* Public entry point */
|
||||||
|
|
||||||
|
@ -94,12 +94,17 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
|
||||||
CpuBackendContext* context) {
|
CpuBackendContext* context) {
|
||||||
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
|
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
|
||||||
ValidateParams(lhs_params, rhs_params, dst_params, params);
|
ValidateParams(lhs_params, rhs_params, dst_params, params);
|
||||||
bool do_custom_gemv = dst_params.cols == 1;
|
if (context->use_caching()) {
|
||||||
#ifdef TFLITE_WITH_RUY_GEMV
|
// Dispatch to backend that supports caching of prepacked weights
|
||||||
// Prefer a Ruy GEMM to Custom GEMV unless we are doing float math.
|
// matrices.
|
||||||
// TODO(b/148692500): Add float GEMV kernels to Ruy.
|
detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||||
do_custom_gemv = do_custom_gemv && std::is_floating_point<DstScalar>::value;
|
quantization_flavor>::Run(lhs_params, lhs_data,
|
||||||
#endif
|
rhs_params, rhs_data,
|
||||||
|
dst_params, dst_data,
|
||||||
|
params, context);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const bool do_custom_gemv = (dst_params.cols == 1);
|
||||||
if (do_custom_gemv) {
|
if (do_custom_gemv) {
|
||||||
// GEMV case: try a custom fast GEMV path.
|
// GEMV case: try a custom fast GEMV path.
|
||||||
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
|
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||||
|
|
|
@ -586,10 +586,10 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
|
||||||
// The float specialization below is unconditionally faster than ruy
|
// The float specialization below is unconditionally faster than ruy
|
||||||
// because ruy does not currently have any Gemv path.
|
// because ruy does not currently have any Gemv path.
|
||||||
// But it is not unconditionally faster than Eigen, which is what is used
|
// But it is not unconditionally faster than Eigen, which is what is used
|
||||||
// unless TFLITE_WITH_RUY is defined. Indeed, Eigen has decently efficient
|
// unless TFLITE_WITH_RUY_ONLY is defined. Indeed, Eigen has decently efficient
|
||||||
// Gemv paths, and they may use AVX instructions, while the present
|
// Gemv paths, and they may use AVX instructions, while the present
|
||||||
// NEON intrinsics code maps at best to SSE4 on x86.
|
// NEON intrinsics code maps at best to SSE4 on x86.
|
||||||
#ifdef TFLITE_WITH_RUY
|
#ifdef TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
// We want to use fused multiply-add when it's available (that is, on A64
|
// We want to use fused multiply-add when it's available (that is, on A64
|
||||||
// unconditionally and on A32 with VFPv4) because it's often faster, and
|
// unconditionally and on A32 with VFPv4) because it's often faster, and
|
||||||
|
@ -773,7 +773,7 @@ struct CustomGemvImpl<float, float, float, float,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // TFLITE_WITH_RUY
|
#endif // TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
#endif // USE_NEON
|
#endif // USE_NEON
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TFLITE_WITH_RUY
|
#ifndef TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
|
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
|
||||||
|
|
||||||
|
@ -78,4 +78,4 @@ void GemmImplUsingEigen::Run(
|
||||||
} // namespace cpu_backend_gemm
|
} // namespace cpu_backend_gemm
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // not TFLITE_WITH_RUY
|
#endif // not TFLITE_WITH_RUY_ONLY
|
||||||
|
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_
|
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_
|
||||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_
|
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_
|
||||||
|
|
||||||
#ifndef TFLITE_WITH_RUY
|
#ifndef TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
||||||
|
@ -37,6 +37,6 @@ struct GemmImplUsingEigen {
|
||||||
} // namespace cpu_backend_gemm
|
} // namespace cpu_backend_gemm
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // not TFLITE_WITH_RUY
|
#endif // not TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_
|
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_EIGEN_H_
|
||||||
|
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
|
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
|
||||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
|
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
|
||||||
|
|
||||||
#ifndef TFLITE_WITH_RUY
|
#ifndef TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
@ -188,6 +188,6 @@ struct GemmImplUsingGemmlowp<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||||
} // namespace cpu_backend_gemm
|
} // namespace cpu_backend_gemm
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // not TFLITE_WITH_RUY
|
#endif // not TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
|
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
|
||||||
|
|
|
@ -42,7 +42,7 @@ inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
|
||||||
|
|
||||||
template <typename Scalar, typename DataPointer>
|
template <typename Scalar, typename DataPointer>
|
||||||
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
||||||
ruy::Matrix<Scalar>* dst) {
|
ruy::Matrix<Scalar>* dst, bool use_caching = false) {
|
||||||
ruy::Order ruy_order = params.order == Order::kColMajor
|
ruy::Order ruy_order = params.order == Order::kColMajor
|
||||||
? ruy::Order::kColMajor
|
? ruy::Order::kColMajor
|
||||||
: ruy::Order::kRowMajor;
|
: ruy::Order::kRowMajor;
|
||||||
|
@ -52,9 +52,9 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
||||||
// It does care whether we assign to it a Scalar* or a const Scalar*.
|
// It does care whether we assign to it a Scalar* or a const Scalar*.
|
||||||
dst->set_data(data_ptr);
|
dst->set_data(data_ptr);
|
||||||
dst->set_zero_point(params.zero_point);
|
dst->set_zero_point(params.zero_point);
|
||||||
#ifdef TFLITE_WITH_RUY_GEMV
|
if (use_caching) {
|
||||||
dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
|
dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
|
||||||
#endif
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GemmParamsType, typename RuySpecType>
|
template <typename GemmParamsType, typename RuySpecType>
|
||||||
|
@ -88,8 +88,8 @@ struct GemmImplUsingRuy {
|
||||||
ruy::Matrix<LhsScalar> ruy_lhs;
|
ruy::Matrix<LhsScalar> ruy_lhs;
|
||||||
ruy::Matrix<RhsScalar> ruy_rhs;
|
ruy::Matrix<RhsScalar> ruy_rhs;
|
||||||
ruy::Matrix<DstScalar> ruy_dst;
|
ruy::Matrix<DstScalar> ruy_dst;
|
||||||
MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs);
|
MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs, context->use_caching());
|
||||||
MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs);
|
MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs, context->use_caching());
|
||||||
MakeRuyMatrix(dst_params, dst_data, &ruy_dst);
|
MakeRuyMatrix(dst_params, dst_data, &ruy_dst);
|
||||||
|
|
||||||
ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params;
|
ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params;
|
||||||
|
|
|
@ -363,7 +363,8 @@ void TestSomeGemm(int rows, int depth, int cols,
|
||||||
CpuBackendContext cpu_backend_context;
|
CpuBackendContext cpu_backend_context;
|
||||||
std::default_random_engine random_engine;
|
std::default_random_engine random_engine;
|
||||||
cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
|
cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
|
||||||
|
bool use_caching = static_cast<bool>(random_engine() % 2);
|
||||||
|
cpu_backend_context.SetUseCaching(use_caching);
|
||||||
const bool use_golden = !golden.empty();
|
const bool use_golden = !golden.empty();
|
||||||
|
|
||||||
std::vector<LhsScalar> lhs_data;
|
std::vector<LhsScalar> lhs_data;
|
||||||
|
|
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||||
|
|
||||||
#ifdef TFLITE_WITH_RUY
|
#ifdef TFLITE_WITH_RUY_ONLY
|
||||||
#include "ruy/context.h" // from @ruy
|
#include "ruy/context.h" // from @ruy
|
||||||
#include "ruy/thread_pool.h" // from @ruy
|
#include "ruy/thread_pool.h" // from @ruy
|
||||||
#else
|
#else
|
||||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace cpu_backend_threadpool {
|
namespace cpu_backend_threadpool {
|
||||||
|
|
||||||
#ifdef TFLITE_WITH_RUY
|
#ifdef TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
using Task = ruy::Task;
|
using Task = ruy::Task;
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ void Execute(int tasks_count, TaskType* tasks,
|
||||||
tasks_count, tasks);
|
tasks_count, tasks);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // not TFLITE_WITH_RUY
|
#else // not TFLITE_WITH_RUY_ONLY
|
||||||
|
|
||||||
using Task = gemmlowp::Task;
|
using Task = gemmlowp::Task;
|
||||||
|
|
||||||
|
|
|
@ -132,7 +132,7 @@ inline void DepthwiseConv(const DepthwiseParams& params,
|
||||||
int thread_count = HowManyConvThreads(output_shape, filter_shape);
|
int thread_count = HowManyConvThreads(output_shape, filter_shape);
|
||||||
const int max_threads = cpu_backend_context->max_num_threads();
|
const int max_threads = cpu_backend_context->max_num_threads();
|
||||||
thread_count = std::max(1, std::min(thread_count, max_threads));
|
thread_count = std::max(1, std::min(thread_count, max_threads));
|
||||||
#ifndef TFLITE_WITH_RUY
|
#ifndef TFLITE_WITH_RUY_ONLY
|
||||||
// Cap the number of threads to 2 for float path to avoid regression in
|
// Cap the number of threads to 2 for float path to avoid regression in
|
||||||
// performance (b/132294857).
|
// performance (b/132294857).
|
||||||
if (std::is_floating_point<T>::value) {
|
if (std::is_floating_point<T>::value) {
|
||||||
|
|
|
@ -144,13 +144,13 @@ cc_library(
|
||||||
"//tensorflow/lite:string_util",
|
"//tensorflow/lite:string_util",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
"//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"//tensorflow/lite/kernels:cpu_backend_context",
|
||||||
"//tensorflow/lite/profiling:platform_profiler",
|
"//tensorflow/lite/profiling:platform_profiler",
|
||||||
"//tensorflow/lite/profiling:profile_summary_formatter",
|
"//tensorflow/lite/profiling:profile_summary_formatter",
|
||||||
"//tensorflow/lite/profiling:profiler",
|
"//tensorflow/lite/profiling:profiler",
|
||||||
"//tensorflow/lite/tools:logging",
|
"//tensorflow/lite/tools:logging",
|
||||||
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
|
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
|
||||||
"//tensorflow/lite/tools/delegates:tflite_execution_providers",
|
"//tensorflow/lite/tools/delegates:tflite_execution_providers",
|
||||||
"//tensorflow/lite/tools/evaluation:utils",
|
|
||||||
"@com_google_absl//absl/base:core_headers",
|
"@com_google_absl//absl/base:core_headers",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@ruy//ruy/profiler",
|
"@ruy//ruy/profiler",
|
||||||
|
|
|
@ -34,6 +34,7 @@ BenchmarkParams BenchmarkModel::DefaultParams() {
|
||||||
params.AddParam("max_secs", BenchmarkParam::Create<float>(150.0f));
|
params.AddParam("max_secs", BenchmarkParam::Create<float>(150.0f));
|
||||||
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
|
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
|
||||||
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
|
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
|
||||||
|
params.AddParam("use_caching", BenchmarkParam::Create<bool>(false));
|
||||||
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
|
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
|
||||||
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
|
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
|
||||||
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
|
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
|
||||||
|
@ -82,6 +83,11 @@ std::vector<Flag> BenchmarkModel::GetFlags() {
|
||||||
"the end of the run but will not start the next run."),
|
"the end of the run but will not start the next run."),
|
||||||
CreateFlag<float>("run_delay", ¶ms_, "delay between runs in seconds"),
|
CreateFlag<float>("run_delay", ¶ms_, "delay between runs in seconds"),
|
||||||
CreateFlag<int32_t>("num_threads", ¶ms_, "number of threads"),
|
CreateFlag<int32_t>("num_threads", ¶ms_, "number of threads"),
|
||||||
|
CreateFlag<bool>(
|
||||||
|
"use_caching", ¶ms_,
|
||||||
|
"Enable caching of prepacked weights matrices in matrix "
|
||||||
|
"multiplication routines. Currently implies the use of the Ruy "
|
||||||
|
"library."),
|
||||||
CreateFlag<std::string>("benchmark_name", ¶ms_, "benchmark name"),
|
CreateFlag<std::string>("benchmark_name", ¶ms_, "benchmark name"),
|
||||||
CreateFlag<std::string>("output_prefix", ¶ms_,
|
CreateFlag<std::string>("output_prefix", ¶ms_,
|
||||||
"benchmark output prefix"),
|
"benchmark output prefix"),
|
||||||
|
@ -108,6 +114,8 @@ void BenchmarkModel::LogParams() {
|
||||||
<< params_.Get<float>("run_delay") << "]";
|
<< params_.Get<float>("run_delay") << "]";
|
||||||
TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
|
TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
|
||||||
<< "]";
|
<< "]";
|
||||||
|
TFLITE_LOG(INFO) << "Use caching: [" << params_.Get<bool>("use_caching")
|
||||||
|
<< "]";
|
||||||
TFLITE_LOG(INFO) << "Benchmark name: ["
|
TFLITE_LOG(INFO) << "Benchmark name: ["
|
||||||
<< params_.Get<std::string>("benchmark_name") << "]";
|
<< params_.Get<std::string>("benchmark_name") << "]";
|
||||||
TFLITE_LOG(INFO) << "Output prefix: ["
|
TFLITE_LOG(INFO) << "Output prefix: ["
|
||||||
|
|
|
@ -53,6 +53,7 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
|
||||||
params.AddParam("max_secs", BenchmarkParam::Create<float>(max_secs));
|
params.AddParam("max_secs", BenchmarkParam::Create<float>(max_secs));
|
||||||
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
|
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
|
||||||
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
|
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
|
||||||
|
params.AddParam("use_caching", BenchmarkParam::Create<bool>(false));
|
||||||
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
|
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
|
||||||
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
|
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
|
||||||
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
|
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
|
||||||
|
@ -397,6 +398,14 @@ TEST(BenchmarkTest, RunWithWrongFlags) {
|
||||||
EXPECT_EQ(kTfLiteError, status);
|
EXPECT_EQ(kTfLiteError, status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(BenchmarkTest, RunWithUseCaching) {
|
||||||
|
ASSERT_THAT(g_fp32_model_path, testing::NotNull());
|
||||||
|
TestBenchmark benchmark(CreateFp32Params());
|
||||||
|
ScopedCommandlineArgs scoped_argv({"--use_caching=false"});
|
||||||
|
auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv());
|
||||||
|
EXPECT_EQ(kTfLiteOk, status);
|
||||||
|
}
|
||||||
|
|
||||||
class MaxDurationWorksTestListener : public BenchmarkListener {
|
class MaxDurationWorksTestListener : public BenchmarkListener {
|
||||||
void OnBenchmarkEnd(const BenchmarkResults& results) override {
|
void OnBenchmarkEnd(const BenchmarkResults& results) override {
|
||||||
const int64_t num_actual_runs = results.inference_time_us().count();
|
const int64_t num_actual_runs = results.inference_time_us().count();
|
||||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
#include "ruy/profiler/profiler.h" // from @ruy
|
#include "ruy/profiler/profiler.h" // from @ruy
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/op_resolver.h"
|
#include "tensorflow/lite/op_resolver.h"
|
||||||
|
@ -600,11 +601,24 @@ TfLiteStatus BenchmarkTfLiteModel::ResetInputsAndOutputs() {
|
||||||
TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() {
|
TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() {
|
||||||
auto resolver = GetOpResolver();
|
auto resolver = GetOpResolver();
|
||||||
const int32_t num_threads = params_.Get<int32_t>("num_threads");
|
const int32_t num_threads = params_.Get<int32_t>("num_threads");
|
||||||
|
const bool use_caching = params_.Get<bool>("use_caching");
|
||||||
tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads);
|
tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads);
|
||||||
if (!interpreter_) {
|
if (!interpreter_) {
|
||||||
TFLITE_LOG(ERROR) << "Failed to initialize the interpreter";
|
TFLITE_LOG(ERROR) << "Failed to initialize the interpreter";
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
// Manually enable caching behavior in TF Lite interpreter.
|
||||||
|
if (use_caching) {
|
||||||
|
external_context_.reset(new tflite::ExternalCpuBackendContext());
|
||||||
|
std::unique_ptr<tflite::CpuBackendContext> cpu_backend_context(
|
||||||
|
new tflite::CpuBackendContext());
|
||||||
|
cpu_backend_context->SetUseCaching(true);
|
||||||
|
external_context_->set_internal_backend_context(
|
||||||
|
std::move(cpu_backend_context));
|
||||||
|
interpreter_->SetExternalContext(kTfLiteCpuBackendContext,
|
||||||
|
external_context_.get());
|
||||||
|
}
|
||||||
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -85,6 +85,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
|
||||||
|
|
||||||
std::unique_ptr<tflite::FlatBufferModel> model_;
|
std::unique_ptr<tflite::FlatBufferModel> model_;
|
||||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||||
|
std::unique_ptr<tflite::ExternalCpuBackendContext> external_context_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Implement type erasure with unique_ptr with custom deleter.
|
// Implement type erasure with unique_ptr with custom deleter.
|
||||||
|
|
|
@ -187,7 +187,7 @@ ifeq ($(TARGET_ARCH),aarch64)
|
||||||
BUILD_WITH_RUY=true
|
BUILD_WITH_RUY=true
|
||||||
endif
|
endif
|
||||||
ifeq ($(BUILD_WITH_RUY),true)
|
ifeq ($(BUILD_WITH_RUY),true)
|
||||||
CXXFLAGS += -DTFLITE_WITH_RUY
|
CXXFLAGS += -DTFLITE_WITH_RUY_ONLY
|
||||||
endif
|
endif
|
||||||
|
|
||||||
BUILD_WITH_RUY_PROFILER ?= false
|
BUILD_WITH_RUY_PROFILER ?= false
|
||||||
|
|
Loading…
Reference in New Issue