From 4f910ac64bc80e430ca2c936de88f107e098cf4e Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Mon, 22 Jul 2019 19:29:01 -0700 Subject: [PATCH] Make TfLiteInternalBackendContext as a interface-only abstract class. PiperOrigin-RevId: 259455436 --- .../lite/external_cpu_backend_context.cc | 2 +- .../lite/external_cpu_backend_context.h | 20 +++----------- .../lite/kernels/cpu_backend_context.cc | 4 +-- tensorflow/lite/kernels/cpu_backend_context.h | 27 +++++++++++-------- .../lite/kernels/cpu_backend_gemm_test.cc | 2 +- .../lite/kernels/cpu_backend_support.cc | 3 +-- .../kernels/cpu_backend_threadpool_test.cc | 4 +-- .../internal/depthwiseconv_quantized_test.cc | 2 +- 8 files changed, 27 insertions(+), 37 deletions(-) diff --git a/tensorflow/lite/external_cpu_backend_context.cc b/tensorflow/lite/external_cpu_backend_context.cc index 2be35c8baf7..df1fc01b8b9 100644 --- a/tensorflow/lite/external_cpu_backend_context.cc +++ b/tensorflow/lite/external_cpu_backend_context.cc @@ -22,7 +22,7 @@ TfLiteStatus RefreshExternalCpuBackendContext(TfLiteContext* context) { context->GetExternalContext(context, kTfLiteCpuBackendContext)); if (external_context && external_context->internal_backend_context() && context->recommended_num_threads != -1) { - external_context->internal_backend_context()->set_max_num_threads( + external_context->internal_backend_context()->SetMaxNumThreads( context->recommended_num_threads); } return kTfLiteOk; diff --git a/tensorflow/lite/external_cpu_backend_context.h b/tensorflow/lite/external_cpu_backend_context.h index 0d8763532c7..8d5125dec1f 100644 --- a/tensorflow/lite/external_cpu_backend_context.h +++ b/tensorflow/lite/external_cpu_backend_context.h @@ -27,27 +27,13 @@ namespace tflite { // generally a collection of utilities (i.e. a thread pool etc.) for TF Lite to // use certain keneral libraries, such as Gemmlowp, RUY, etc., to implement TF // Lite operators. -// TODO(b/130950871): Make this class as a interface-only abstract class. 
class TfLiteInternalBackendContext { public: virtual ~TfLiteInternalBackendContext() {} - int max_num_threads() const { return max_num_threads_; } - - virtual void set_max_num_threads(int max_num_threads) { - max_num_threads_ = max_num_threads; - } - - protected: - TfLiteInternalBackendContext() {} - - // The maximum number of threads used for parallelizing TfLite computation. - int max_num_threads_; - - private: - TfLiteInternalBackendContext(const TfLiteInternalBackendContext&) = delete; - TfLiteInternalBackendContext& operator=(const TfLiteInternalBackendContext&) = - delete; + // Set the maximum number of threads that could be used for parallelizing + // TfLite computation. + virtual void SetMaxNumThreads(int max_num_threads) = 0; }; // This TfLiteExternalContext-derived class is the default diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc index f9a1ee0a86b..63f12208630 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.cc +++ b/tensorflow/lite/kernels/cpu_backend_context.cc @@ -24,12 +24,12 @@ CpuBackendContext::CpuBackendContext() : TfLiteInternalBackendContext(), ruy_context_(new ruy::Context), gemmlowp_context_(new gemmlowp::GemmContext) { - set_max_num_threads(1); + SetMaxNumThreads(1); } CpuBackendContext::~CpuBackendContext() {} -void CpuBackendContext::set_max_num_threads(int max_num_threads) { +void CpuBackendContext::SetMaxNumThreads(int max_num_threads) { max_num_threads_ = max_num_threads; ruy_context_->max_num_threads = max_num_threads; gemmlowp_context_->set_max_num_threads(max_num_threads); diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h index 00b12d8ba54..a55a951ac99 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.h +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -35,17 +35,11 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { return gemmlowp_context_.get(); } - // Sets the 
maximum-number-of-threads-to-use parameter. - // This is only a means of passing around this information. - // cpu_backend_threadpool::Execute creates as many threads as it's - // asked to, regardless of this. Typically a call site would query - // cpu_backend_context->max_num_threads() and used that to determine - // the number of tasks to create and to give to - // cpu_backend_threadpool::Execute. - // - // This value also gets propagated to back-ends, where it plays the same - // information-only role. - void set_max_num_threads(int max_num_threads) override; + // Sets the maximum-number-of-threads-to-use parameter, only as a means of + // passing around this information. + void SetMaxNumThreads(int max_num_threads) override; + + int max_num_threads() const { return max_num_threads_; } private: // To enable a smooth transition from the current direct usage @@ -57,6 +51,17 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { const std::unique_ptr ruy_context_; const std::unique_ptr gemmlowp_context_; + // The maximum number of threads used for parallelizing TfLite ops. However, + // cpu_backend_threadpool::Execute creates as many threads as it's + // asked to, regardless of this. Typically a call site would query + // cpu_backend_context->max_num_threads() and use that to determine + // the number of tasks to create and to give to + // cpu_backend_threadpool::Execute. + // + // This value also gets propagated to back-ends, where it plays the same + // information-only role. 
+ int max_num_threads_; + CpuBackendContext(const CpuBackendContext&) = delete; }; diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc index c193d1b60cc..fe2792b88cd 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc +++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc @@ -363,7 +363,7 @@ void TestSomeGemm(int rows, int depth, int cols, const std::vector& golden) { CpuBackendContext cpu_backend_context; std::default_random_engine random_engine; - cpu_backend_context.set_max_num_threads(1 + (random_engine() % 8)); + cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8)); const bool use_golden = !golden.empty(); diff --git a/tensorflow/lite/kernels/cpu_backend_support.cc b/tensorflow/lite/kernels/cpu_backend_support.cc index 64a41b2e1ec..ab47d5b7e99 100644 --- a/tensorflow/lite/kernels/cpu_backend_support.cc +++ b/tensorflow/lite/kernels/cpu_backend_support.cc @@ -46,8 +46,7 @@ CpuBackendContext* GetFromContext(TfLiteContext* context) { // that's wrapped inside ExternalCpuBackendContext. cpu_backend_context = new CpuBackendContext(); if (context->recommended_num_threads != -1) { - cpu_backend_context->set_max_num_threads( - context->recommended_num_threads); + cpu_backend_context->SetMaxNumThreads(context->recommended_num_threads); } external_context->set_internal_backend_context( std::unique_ptr(cpu_backend_context)); diff --git a/tensorflow/lite/kernels/cpu_backend_threadpool_test.cc b/tensorflow/lite/kernels/cpu_backend_threadpool_test.cc index 45208a383c5..5089323070a 100644 --- a/tensorflow/lite/kernels/cpu_backend_threadpool_test.cc +++ b/tensorflow/lite/kernels/cpu_backend_threadpool_test.cc @@ -61,10 +61,10 @@ void TestGenerateArrayOfIncrementingInts(int num_threads, int size) { ASSERT_EQ(num_threads, tasks.size()); CpuBackendContext context; - // This set_max_num_threads is only to satisfy an assertion in Execute. 
+ // This SetMaxNumThreads is only to satisfy an assertion in Execute. // What actually determines the number of threads used is the parameter // passed to Execute, since Execute does 1:1 mapping of tasks to threads. - context.set_max_num_threads(num_threads); + context.SetMaxNumThreads(num_threads); // Execute tasks on the threadpool. cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), &context); diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index fd5b89eaf73..1c3d0e9ad62 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -292,7 +292,7 @@ inline void DispatchDepthwiseConv( << " input_offset = " << params.input_offset; CpuBackendContext backend_context; - backend_context.set_max_num_threads(test_param.num_threads); + backend_context.SetMaxNumThreads(test_param.num_threads); optimized_ops::DepthwiseConv( params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data, &backend_context);