diff --git a/tensorflow/lite/external_cpu_backend_context.h b/tensorflow/lite/external_cpu_backend_context.h index c667057a48c..662734c9cd5 100644 --- a/tensorflow/lite/external_cpu_backend_context.h +++ b/tensorflow/lite/external_cpu_backend_context.h @@ -35,6 +35,8 @@ class TfLiteInternalBackendContext { // TfLite computation. virtual void SetMaxNumThreads(int max_num_threads) = 0; + // A context may internally cache prepacked versions of constant tensors for + // faster computation. This function will clear any caches on the context. virtual void ClearCaches() = 0; }; diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index edd6d034a11..aad79ffbc89 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -10,24 +10,27 @@ package( licenses = ["notice"], # Apache 2.0 ) -# Enables usage of ruy in TFLite kernels. +# Enables usage of ruy exclusively as the GEMM backend in TFLite kernels. +# This will cause TFLite to build with ruy only, providing a smaller binary. # WARNING: This build flag is experimental and subject to change. config_setting( - name = "tflite_with_ruy_explicit_true", - define_values = {"tflite_with_ruy": "true"}, + name = "tflite_with_ruy_only_explicit_true", + define_values = {"TFLITE_WITH_RUY_ONLY": "true"}, ) -# Disables usage of ruy in TFLite kernels. +# Disables usage of ruy as the exclusive GEMM backend in TFLite kernels. +# TFLite will be built with ruy and other GEMM libraries. Ruy will not be +# the default GEMM option at runtime. # WARNING: This build flag is experimental and subject to change. config_setting( - name = "tflite_with_ruy_explicit_false", - define_values = {"tflite_with_ruy": "false"}, + name = "tflite_with_ruy_only_explicit_false", + define_values = {"TFLITE_WITH_RUY_ONLY": "false"}, ) ###### Beginning of config_setting's to match aarch64 ###### # # We need to identify the aarch64 instruction set to decide whether to enable -# tflite_with_ruy by default. This is surprisingly hard to do because select() +# TFLITE_WITH_RUY_ONLY by default. This is surprisingly hard to do because select() # can only consume config_setting's, these config_settings are not centralized, # and the "cpu" value which they define are free-form strings and there is no # standardization of the strings that we need to match for the aarch64 architecture. @@ -229,45 +232,45 @@ cc_test( ) cc_library( - name = "tflite_with_ruy_enabled", + name = "tflite_with_ruy_only_enabled", build_for_embedded = True, - defines = ["TFLITE_WITH_RUY"], + defines = ["TFLITE_WITH_RUY_ONLY"], visibility = ["//visibility:private"], ) cc_library( - name = "tflite_with_ruy_and_caching_enabled", + name = "tflite_with_ruy_only_and_caching_enabled", defines = [ - "TFLITE_WITH_RUY", + "TFLITE_WITH_RUY_ONLY", "TFLITE_WITH_RUY_GEMV", ], visibility = ["//visibility:private"], ) cc_library( - name = "tflite_with_ruy_default", + name = "tflite_with_ruy_only_default", build_for_embedded = True, select_deps = { - ":chromiumos_arm64": [":tflite_with_ruy_enabled"], - ":cpu_aarch64": [":tflite_with_ruy_enabled"], - ":cpu_arm64": [":tflite_with_ruy_enabled"], - ":cpu_arm64e": [":tflite_with_ruy_enabled"], - ":cpu_ios_arm64": [":tflite_with_ruy_enabled"], - ":cpu_ios_arm64e": [":tflite_with_ruy_enabled"], - ":cpu_arm64_v8a": [":tflite_with_ruy_enabled"], - "//tensorflow:android_arm": ["tflite_with_ruy_enabled"], + ":chromiumos_arm64": [":tflite_with_ruy_only_enabled"], + ":cpu_aarch64": [":tflite_with_ruy_only_enabled"], + ":cpu_arm64": [":tflite_with_ruy_only_enabled"], + ":cpu_arm64e": [":tflite_with_ruy_only_enabled"], + ":cpu_ios_arm64": [":tflite_with_ruy_only_enabled"], + ":cpu_ios_arm64e": [":tflite_with_ruy_only_enabled"], + ":cpu_arm64_v8a": [":tflite_with_ruy_only_enabled"], + "//tensorflow:android_arm": ["tflite_with_ruy_only_enabled"], "//conditions:default": [], }, visibility = ["//visibility:private"], ) cc_library( - name = "tflite_with_ruy", + name = "tflite_with_ruy_only", build_for_embedded = True, select_deps = { - ":tflite_with_ruy_explicit_true": [":tflite_with_ruy_enabled"], - ":tflite_with_ruy_explicit_false": [], - "//conditions:default": [":tflite_with_ruy_default"], + ":tflite_with_ruy_only_explicit_true": [":tflite_with_ruy_only_enabled"], + ":tflite_with_ruy_only_explicit_false": [], + "//conditions:default": [":tflite_with_ruy_only_default"], }, ) @@ -281,7 +284,7 @@ cc_library( ], copts = tflite_copts(), deps = [ - ":tflite_with_ruy", + ":tflite_with_ruy_only", ":op_macros", # For now this unconditionally depends on both ruy and gemmlowp. # See the comment inside class CpuBackendContext on the @@ -300,11 +303,11 @@ cc_library( copts = tflite_copts(), deps = [ ":cpu_backend_context", - ":tflite_with_ruy", + ":tflite_with_ruy_only", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:types", # For now this unconditionally depends on both ruy and gemmlowp. - # We only need to depend on gemmlowp when tflite_with_ruy + # We only need to depend on gemmlowp when tflite_with_ruy_only # is false, but putting these dependencies in a select() seems to # defeat copybara's rewriting rules. "@ruy//ruy:context", @@ -338,19 +341,19 @@ cc_library( ], copts = tflite_copts(), deps = [ - ":tflite_with_ruy", + ":tflite_with_ruy_only", "//tensorflow/lite/kernels/internal:common", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:types", ":cpu_backend_context", ":cpu_backend_threadpool", - # Depend on ruy regardless of `tflite_with_ruy`. See the comment in + # Depend on ruy regardless of `tflite_with_ruy_only`. See the comment in # cpu_backend_gemm.h about why ruy is the generic path. "@ruy//ruy", "@ruy//ruy:matrix", "@ruy//ruy:path", "@ruy//ruy/profiler:instrumentation", - # We only need to depend on gemmlowp and Eigen when tflite_with_ruy + # We only need to depend on gemmlowp and Eigen when tflite_with_ruy_only # is false, but putting these dependencies in a select() seems to # defeat copybara's rewriting rules. "@gemmlowp", @@ -580,7 +583,7 @@ cc_library( ], copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, visibility = ["//visibility:private"], - deps = BUILTIN_KERNEL_DEPS + ["@farmhash_archive//:farmhash"] + [":tflite_with_ruy_and_caching_enabled"], + deps = BUILTIN_KERNEL_DEPS + ["@farmhash_archive//:farmhash"] + [":tflite_with_ruy_only_and_caching_enabled"], ) cc_library( diff --git a/tensorflow/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc index 21ee5f806a8..1d610b2e068 100644 --- a/tensorflow/lite/kernels/conv.cc +++ b/tensorflow/lite/kernels/conv.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/eigen_support.h" // b/131835803 forces us to include multithreaded_conv.h before optimized_ops.h -#ifndef TFLITE_WITH_RUY +#ifndef TFLITE_WITH_RUY_ONLY #include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h" #endif #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" @@ -768,8 +768,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, break; } case kMultithreadOptimized: { -#ifdef TFLITE_WITH_RUY - // See Register_CONV_2D: we should never be here when tflite_with_ruy +#ifdef TFLITE_WITH_RUY_ONLY + // See Register_CONV_2D: we should never be here when TFLITE_WITH_RUY_ONLY // was enabled. We #if out this code in order to get the corresponding // binary size benefits. TFLITE_DCHECK(false); @@ -1054,8 +1054,8 @@ TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT() { TfLiteRegistration* Register_CONV_2D() { #if defined TFLITE_USE_APPLE_ACCELERATE_FOR_CONV return Register_CONVOLUTION_CBLAS_OPT(); -#elif defined TFLITE_WITH_RUY - // tflite_with_ruy optimizes the generic kernel type. +#elif defined TFLITE_WITH_RUY_ONLY + // TFLITE_WITH_RUY_ONLY optimizes the generic kernel type. return Register_CONVOLUTION_GENERIC_OPT(); #else return Register_CONVOLUTION_MULTITHREADED_OPT(); @@ -1066,8 +1066,8 @@ TfLiteRegistration* Register_CONV_2D() { // models only need the UINT8 type. TFLite's op registration mechanism doesn't // yet allow for more nuanced registration mechanisms. TfLiteRegistration* Register_CONV_2D_UINT8() { -#if defined TFLITE_WITH_RUY - // tflite_with_ruy optimizes the generic kernel type. +#if defined TFLITE_WITH_RUY_ONLY + // TFLITE_WITH_RUY_ONLY optimizes the generic kernel type. return Register_CONVOLUTION_GENERIC_OPT_UINT8(); #else return Register_CONV_2D(); diff --git a/tensorflow/lite/kernels/conv_test.cc b/tensorflow/lite/kernels/conv_test.cc index e6484c4a4c6..ef1d5366255 100644 --- a/tensorflow/lite/kernels/conv_test.cc +++ b/tensorflow/lite/kernels/conv_test.cc @@ -141,7 +141,7 @@ class ConvolutionOpModel : public BaseConvolutionOpModel { const auto kKernelMap = new std::map({ {"Reference", ops::builtin::Register_CONVOLUTION_REF()}, {"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()}, -#ifndef TFLITE_WITH_RUY +#ifndef TFLITE_WITH_RUY_ONLY {"MultithreadedOptimized", ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()}, #endif diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc index d6de9bf8d61..7a16bed0ead 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.cc +++ b/tensorflow/lite/kernels/cpu_backend_context.cc @@ -55,6 +55,12 @@ CpuBackendContext::CpuBackendContext() ruy_context_(new ruy::Context), gemmlowp_context_(new gemmlowp::GemmContext) { SetMaxNumThreads(kDefaultNumThreadpoolThreads); +// TODO(b/148289189) Remove when clients have transitioned to runtime flag. +#ifdef TFLITE_WITH_RUY_GEMV + SetUseCaching(true); +#else + SetUseCaching(false); +#endif } CpuBackendContext::~CpuBackendContext() {} @@ -67,4 +73,6 @@ void CpuBackendContext::SetMaxNumThreads(int max_num_threads) { gemmlowp_context_->set_max_num_threads(target_num_threads); } +void CpuBackendContext::SetUseCaching(bool flag) { use_caching_ = flag; } + } // namespace tflite diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h index 46abcd5e90f..b4973feb56f 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.h +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -43,6 +43,10 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { int max_num_threads() const { return max_num_threads_; } + void SetUseCaching(bool flag); + + bool use_caching() const { return use_caching_; } + void ClearCaches() override { ruy_context_->ClearPrepackedCache(); } private: @@ -51,7 +55,7 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { // (see :cpu_backend_gemm), for now a CpuBackendContext always // stores both a gemmlowp context and a ruy context. // TODO(b/131416458): Once call sites all go through abstractions, - // elide what can be elided based on TFLITE_WITH_RUY. + // elide what can be elided based on TFLITE_WITH_RUY_ONLY. const std::unique_ptr ruy_context_; const std::unique_ptr gemmlowp_context_; @@ -65,6 +69,12 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { // This value also gets propagated to back-ends, where it plays the same // information-only role. int max_num_threads_; + // For matrix muliplications with constants parameters (i.e. weights), we can + // sometimes provide speedups by caching the "prepacked" data, for some + // additional memory cost. This flag permits the user to route all + // CpuBackendGem operations to a library that permits such an optimization + // (currently the Ruy library only). + bool use_caching_; CpuBackendContext(const CpuBackendContext&) = delete; }; diff --git a/tensorflow/lite/kernels/cpu_backend_gemm.h b/tensorflow/lite/kernels/cpu_backend_gemm.h index 16ccc14557f..8e324c8b515 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm.h @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" -#ifndef TFLITE_WITH_RUY +#ifndef TFLITE_WITH_RUY_ONLY #include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h" #endif @@ -41,7 +41,7 @@ template {}; -#ifndef TFLITE_WITH_RUY +#ifndef TFLITE_WITH_RUY_ONLY /* Specializations using gemmlowp */ @@ -81,7 +81,7 @@ template <> struct GemmImpl : detail::GemmImplUsingEigen {}; -#endif // not TFLITE_WITH_RUY +#endif // not TFLITE_WITH_RUY_ONLY /* Public entry point */ @@ -94,12 +94,17 @@ void Gemm(const MatrixParams& lhs_params, const LhsScalar* lhs_data, CpuBackendContext* context) { ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm"); ValidateParams(lhs_params, rhs_params, dst_params, params); - bool do_custom_gemv = dst_params.cols == 1; -#ifdef TFLITE_WITH_RUY_GEMV - // Prefer a Ruy GEMM to Custom GEMV unless we are doing float math. - // TODO(b/148692500): Add float GEMV kernels to Ruy. - do_custom_gemv = do_custom_gemv && std::is_floating_point::value; -#endif + if (context->use_caching()) { + // Dispatch to backend that supports caching of prepacked weights + // matrices. + detail::GemmImplUsingRuy::Run(lhs_params, lhs_data, + rhs_params, rhs_data, + dst_params, dst_data, + params, context); + return; + } + const bool do_custom_gemv = (dst_params.cols == 1); if (do_custom_gemv) { // GEMV case: try a custom fast GEMV path. if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data, diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h index f85a1715af2..1c3c0ca39c4 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h @@ -586,10 +586,10 @@ struct CustomGemvImpl #include @@ -188,6 +188,6 @@ struct GemmImplUsingGemmlowp void MakeRuyMatrix(const MatrixParams& params, DataPointer data_ptr, - ruy::Matrix* dst) { + ruy::Matrix* dst, bool use_caching = false) { ruy::Order ruy_order = params.order == Order::kColMajor ? ruy::Order::kColMajor : ruy::Order::kRowMajor; @@ -52,9 +52,9 @@ void MakeRuyMatrix(const MatrixParams& params, DataPointer data_ptr, // It does care whether we assign to it a Scalar* or a const Scalar*. dst->set_data(data_ptr); dst->set_zero_point(params.zero_point); -#ifdef TFLITE_WITH_RUY_GEMV - dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy)); -#endif + if (use_caching) { + dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy)); + } } template @@ -88,8 +88,8 @@ struct GemmImplUsingRuy { ruy::Matrix ruy_lhs; ruy::Matrix ruy_rhs; ruy::Matrix ruy_dst; - MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs); - MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs); + MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs, context->use_caching()); + MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs, context->use_caching()); MakeRuyMatrix(dst_params, dst_data, &ruy_dst); ruy::MulParams ruy_mul_params; diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc index 110eb3a07ef..20334947dde 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc +++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc @@ -363,7 +363,8 @@ void TestSomeGemm(int rows, int depth, int cols, CpuBackendContext cpu_backend_context; std::default_random_engine random_engine; cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8)); - + bool use_caching = static_cast(random_engine() % 2); + cpu_backend_context.SetUseCaching(use_caching); const bool use_golden = !golden.empty(); std::vector lhs_data; diff --git a/tensorflow/lite/kernels/cpu_backend_threadpool.h b/tensorflow/lite/kernels/cpu_backend_threadpool.h index 39eafd51d6a..60a5ebfde29 100644 --- a/tensorflow/lite/kernels/cpu_backend_threadpool.h +++ b/tensorflow/lite/kernels/cpu_backend_threadpool.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#ifdef TFLITE_WITH_RUY +#ifdef TFLITE_WITH_RUY_ONLY #include "ruy/context.h" // from @ruy #include "ruy/thread_pool.h" // from @ruy #else @@ -29,7 +29,7 @@ limitations under the License. namespace tflite { namespace cpu_backend_threadpool { -#ifdef TFLITE_WITH_RUY +#ifdef TFLITE_WITH_RUY_ONLY using Task = ruy::Task; @@ -41,7 +41,7 @@ void Execute(int tasks_count, TaskType* tasks, tasks_count, tasks); } -#else // not TFLITE_WITH_RUY +#else // not TFLITE_WITH_RUY_ONLY using Task = gemmlowp::Task; diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h index 7d8838a076e..0e13222b28a 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h @@ -132,7 +132,7 @@ inline void DepthwiseConv(const DepthwiseParams& params, int thread_count = HowManyConvThreads(output_shape, filter_shape); const int max_threads = cpu_backend_context->max_num_threads(); thread_count = std::max(1, std::min(thread_count, max_threads)); -#ifndef TFLITE_WITH_RUY +#ifndef TFLITE_WITH_RUY_ONLY // Cap the number of threads to 2 for float path to avoid regression in // performance (b/132294857). if (std::is_floating_point::value) { diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index f6cb71749f8..c94e1dba7fe 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -144,13 +144,13 @@ cc_library( "//tensorflow/lite:string_util", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/profiling:profile_summary_formatter", "//tensorflow/lite/profiling:profiler", "//tensorflow/lite/tools:logging", "//tensorflow/lite/tools/delegates:delegate_provider_hdr", "//tensorflow/lite/tools/delegates:tflite_execution_providers", - "//tensorflow/lite/tools/evaluation:utils", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@ruy//ruy/profiler", diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc index 4c6fb0eb86e..2a858e7a326 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc @@ -34,6 +34,7 @@ BenchmarkParams BenchmarkModel::DefaultParams() { params.AddParam("max_secs", BenchmarkParam::Create(150.0f)); params.AddParam("run_delay", BenchmarkParam::Create(-1.0f)); params.AddParam("num_threads", BenchmarkParam::Create(1)); + params.AddParam("use_caching", BenchmarkParam::Create(false)); params.AddParam("benchmark_name", BenchmarkParam::Create("")); params.AddParam("output_prefix", BenchmarkParam::Create("")); params.AddParam("warmup_runs", BenchmarkParam::Create(1)); @@ -82,6 +83,11 @@ std::vector BenchmarkModel::GetFlags() { "the end of the run but will not start the next run."), CreateFlag("run_delay", ¶ms_, "delay between runs in seconds"), CreateFlag("num_threads", ¶ms_, "number of threads"), + CreateFlag( + "use_caching", ¶ms_, + "Enable caching of prepacked weights matrices in matrix " + "multiplication routines. Currently implies the use of the Ruy " + "library."), CreateFlag("benchmark_name", ¶ms_, "benchmark name"), CreateFlag("output_prefix", ¶ms_, "benchmark output prefix"), @@ -108,6 +114,8 @@ void BenchmarkModel::LogParams() { << params_.Get("run_delay") << "]"; TFLITE_LOG(INFO) << "Num threads: [" << params_.Get("num_threads") << "]"; + TFLITE_LOG(INFO) << "Use caching: [" << params_.Get("use_caching") + << "]"; TFLITE_LOG(INFO) << "Benchmark name: [" << params_.Get("benchmark_name") << "]"; TFLITE_LOG(INFO) << "Output prefix: [" diff --git a/tensorflow/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc index 33ccacc0451..37eddf4faf7 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc @@ -53,6 +53,7 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs, params.AddParam("max_secs", BenchmarkParam::Create(max_secs)); params.AddParam("run_delay", BenchmarkParam::Create(-1.0f)); params.AddParam("num_threads", BenchmarkParam::Create(1)); + params.AddParam("use_caching", BenchmarkParam::Create(false)); params.AddParam("benchmark_name", BenchmarkParam::Create("")); params.AddParam("output_prefix", BenchmarkParam::Create("")); params.AddParam("warmup_runs", BenchmarkParam::Create(1)); @@ -397,6 +398,14 @@ TEST(BenchmarkTest, RunWithWrongFlags) { EXPECT_EQ(kTfLiteError, status); } +TEST(BenchmarkTest, RunWithUseCaching) { + ASSERT_THAT(g_fp32_model_path, testing::NotNull()); + TestBenchmark benchmark(CreateFp32Params()); + ScopedCommandlineArgs scoped_argv({"--use_caching=false"}); + auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv()); + EXPECT_EQ(kTfLiteOk, status); +} + class MaxDurationWorksTestListener : public BenchmarkListener { void OnBenchmarkEnd(const BenchmarkResults& results) override { const int64_t num_actual_runs = results.inference_time_us().count(); diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 969713cce73..9114910ad73 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "ruy/profiler/profiler.h" // from @ruy #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/op_resolver.h" @@ -600,11 +601,24 @@ TfLiteStatus BenchmarkTfLiteModel::ResetInputsAndOutputs() { TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() { auto resolver = GetOpResolver(); const int32_t num_threads = params_.Get("num_threads"); + const bool use_caching = params_.Get("use_caching"); tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads); if (!interpreter_) { TFLITE_LOG(ERROR) << "Failed to initialize the interpreter"; return kTfLiteError; } + // Manually enable caching behavior in TF Lite interpreter. + if (use_caching) { + external_context_.reset(new tflite::ExternalCpuBackendContext()); + std::unique_ptr cpu_backend_context( + new tflite::CpuBackendContext()); + cpu_backend_context->SetUseCaching(true); + external_context_->set_internal_backend_context( + std::move(cpu_backend_context)); + interpreter_->SetExternalContext(kTfLiteCpuBackendContext, + external_context_.get()); + } + return kTfLiteOk; } diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h index cc87743b531..e3307601d73 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h @@ -85,6 +85,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel { std::unique_ptr model_; std::unique_ptr interpreter_; + std::unique_ptr external_context_; private: // Implement type erasure with unique_ptr with custom deleter. diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index 7d55370818c..e3776f8e056 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -187,7 +187,7 @@ ifeq ($(TARGET_ARCH),aarch64) BUILD_WITH_RUY=true endif ifeq ($(BUILD_WITH_RUY),true) - CXXFLAGS += -DTFLITE_WITH_RUY + CXXFLAGS += -DTFLITE_WITH_RUY_ONLY endif BUILD_WITH_RUY_PROFILER ?= false