Make TfLiteInternalBackendContext an interface-only abstract class.

PiperOrigin-RevId: 259455436
Chao Mei 2019-07-22 19:29:01 -07:00 committed by TensorFlower Gardener
parent 95bcd434d0
commit 4f910ac64b
8 changed files with 27 additions and 37 deletions

View File

@@ -22,7 +22,7 @@ TfLiteStatus RefreshExternalCpuBackendContext(TfLiteContext* context) {
       context->GetExternalContext(context, kTfLiteCpuBackendContext));
   if (external_context && external_context->internal_backend_context() &&
       context->recommended_num_threads != -1) {
-    external_context->internal_backend_context()->set_max_num_threads(
+    external_context->internal_backend_context()->SetMaxNumThreads(
         context->recommended_num_threads);
   }
   return kTfLiteOk;
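
For orientation, this refresh hook is what carries the interpreter's thread-count hint into the backend. Below is a minimal sketch of the wiring, assuming the public Interpreter::SetExternalContext / SetNumThreads APIs of this era; ConfigureThreads is a hypothetical helper, not part of this commit.

#include <memory>

#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"

// Hypothetical helper: share a CPU backend context with an interpreter and
// set the thread-count hint that the refresh hook above propagates.
void ConfigureThreads(tflite::Interpreter* interpreter,
                      tflite::ExternalCpuBackendContext* external_context) {
  external_context->set_internal_backend_context(
      std::unique_ptr<tflite::TfLiteInternalBackendContext>(
          new tflite::CpuBackendContext()));
  interpreter->SetExternalContext(kTfLiteCpuBackendContext, external_context);
  // Records recommended_num_threads and triggers the registered Refresh
  // callback, which forwards the value via SetMaxNumThreads.
  interpreter->SetNumThreads(4);
}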

View File

@@ -27,27 +27,13 @@ namespace tflite {
 // generally a collection of utilities (i.e. a thread pool etc.) for TF Lite to
 // use certain kernel libraries, such as Gemmlowp, RUY, etc., to implement TF
 // Lite operators.
-// TODO(b/130950871): Make this class an interface-only abstract class.
 class TfLiteInternalBackendContext {
  public:
   virtual ~TfLiteInternalBackendContext() {}
 
-  int max_num_threads() const { return max_num_threads_; }
-  virtual void set_max_num_threads(int max_num_threads) {
-    max_num_threads_ = max_num_threads;
-  }
-
- protected:
-  TfLiteInternalBackendContext() {}
-
-  // The maximum number of threads used for parallelizing TfLite computation.
-  int max_num_threads_;
-
- private:
-  TfLiteInternalBackendContext(const TfLiteInternalBackendContext&) = delete;
-  TfLiteInternalBackendContext& operator=(const TfLiteInternalBackendContext&) =
-      delete;
+  // Set the maximum number of threads that could be used for parallelizing
+  // TfLite computation.
+  virtual void SetMaxNumThreads(int max_num_threads) = 0;
 };
 
 // This TfLiteExternalContext-derived class is the default
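
To make the shape of the change concrete, here is a minimal sketch, not part of the commit, of a backend context implementing the now interface-only base class; ToyBackendContext is hypothetical, while the real CpuBackendContext subclass is updated below.

#include "tensorflow/lite/external_cpu_backend_context.h"

// Hypothetical derived context: with the base class reduced to an interface,
// all state (such as the thread-count hint) now lives in subclasses.
class ToyBackendContext : public tflite::TfLiteInternalBackendContext {
 public:
  void SetMaxNumThreads(int max_num_threads) override {
    max_num_threads_ = max_num_threads;
    // A real backend would also forward the hint to its kernel libraries,
    // as CpuBackendContext does with ruy and gemmlowp below.
  }
  int max_num_threads() const { return max_num_threads_; }

 private:
  int max_num_threads_ = 1;
};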

View File

@@ -24,12 +24,12 @@ CpuBackendContext::CpuBackendContext()
     : TfLiteInternalBackendContext(),
       ruy_context_(new ruy::Context),
       gemmlowp_context_(new gemmlowp::GemmContext) {
-  set_max_num_threads(1);
+  SetMaxNumThreads(1);
 }
 
 CpuBackendContext::~CpuBackendContext() {}
 
-void CpuBackendContext::set_max_num_threads(int max_num_threads) {
+void CpuBackendContext::SetMaxNumThreads(int max_num_threads) {
   max_num_threads_ = max_num_threads;
   ruy_context_->max_num_threads = max_num_threads;
   gemmlowp_context_->set_max_num_threads(max_num_threads);

View File

@@ -35,17 +35,11 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
     return gemmlowp_context_.get();
   }
 
-  // Sets the maximum-number-of-threads-to-use parameter.
-  // This is only a means of passing around this information.
-  // cpu_backend_threadpool::Execute creates as many threads as it's
-  // asked to, regardless of this. Typically a call site would query
-  // cpu_backend_context->max_num_threads() and use that to determine
-  // the number of tasks to create and to give to
-  // cpu_backend_threadpool::Execute.
-  //
-  // This value also gets propagated to back-ends, where it plays the same
-  // information-only role.
-  void set_max_num_threads(int max_num_threads) override;
+  // Sets the maximum-number-of-threads-to-use parameter, only as a means of
+  // passing around this information.
+  void SetMaxNumThreads(int max_num_threads) override;
+
+  int max_num_threads() const { return max_num_threads_; }
 
  private:
   // To enable a smooth transition from the current direct usage
@@ -57,6 +51,17 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
   const std::unique_ptr<ruy::Context> ruy_context_;
   const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
 
+  // The maximum number of threads used for parallelizing TfLite ops. However,
+  // cpu_backend_threadpool::Execute creates as many threads as it's
+  // asked to, regardless of this. Typically a call site would query
+  // cpu_backend_context->max_num_threads() and use that to determine
+  // the number of tasks to create and to give to
+  // cpu_backend_threadpool::Execute.
+  //
+  // This value also gets propagated to back-ends, where it plays the same
+  // information-only role.
+  int max_num_threads_;
+
   CpuBackendContext(const CpuBackendContext&) = delete;
 };
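
The comment above is worth an example: max_num_threads() is only a sizing hint, while cpu_backend_threadpool::Execute spawns exactly as many threads as the tasks it is given (and, per the test further down, asserts the task count does not exceed the hint). A sketch of that call-site pattern follows; FillSliceTask and FillInParallel are illustrative, not TF Lite APIs, and size is assumed positive.

#include <algorithm>
#include <vector>

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"

// Illustrative task: fills data[begin, end) with a constant value.
struct FillSliceTask : tflite::cpu_backend_threadpool::Task {
  FillSliceTask(float* data, int begin, int end, float value)
      : data(data), begin(begin), end(end), value(value) {}
  void Run() override { std::fill(data + begin, data + end, value); }
  float* data;
  int begin;
  int end;
  float value;
};

void FillInParallel(float* data, int size, float value,
                    tflite::CpuBackendContext* context) {
  // Use the hint only to decide how many tasks to create; this also keeps
  // the task count within the assertion inside Execute.
  const int num_tasks = std::min(size, context->max_num_threads());
  std::vector<FillSliceTask> tasks;
  tasks.reserve(num_tasks);
  int begin = 0;
  for (int i = 0; i < num_tasks; ++i) {
    // Split the remaining range evenly; the last task ends exactly at size.
    const int end = begin + (size - begin) / (num_tasks - i);
    tasks.emplace_back(data, begin, end, value);
    begin = end;
  }
  // Execute does a 1:1 mapping of tasks to threads, regardless of the hint.
  tflite::cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), context);
}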

View File

@@ -363,7 +363,7 @@ void TestSomeGemm(int rows, int depth, int cols,
                   const std::vector<DstScalar>& golden) {
   CpuBackendContext cpu_backend_context;
   std::default_random_engine random_engine;
-  cpu_backend_context.set_max_num_threads(1 + (random_engine() % 8));
+  cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
 
   const bool use_golden = !golden.empty();

View File

@@ -46,8 +46,7 @@ CpuBackendContext* GetFromContext(TfLiteContext* context) {
     // that's wrapped inside ExternalCpuBackendContext.
     cpu_backend_context = new CpuBackendContext();
     if (context->recommended_num_threads != -1) {
-      cpu_backend_context->set_max_num_threads(
-          context->recommended_num_threads);
+      cpu_backend_context->SetMaxNumThreads(context->recommended_num_threads);
     }
     external_context->set_internal_backend_context(
         std::unique_ptr<TfLiteInternalBackendContext>(cpu_backend_context));

View File

@@ -61,10 +61,10 @@ void TestGenerateArrayOfIncrementingInts(int num_threads, int size) {
   ASSERT_EQ(num_threads, tasks.size());
 
   CpuBackendContext context;
-  // This set_max_num_threads is only to satisfy an assertion in Execute.
+  // This SetMaxNumThreads is only to satisfy an assertion in Execute.
   // What actually determines the number of threads used is the parameter
   // passed to Execute, since Execute does 1:1 mapping of tasks to threads.
-  context.set_max_num_threads(num_threads);
+  context.SetMaxNumThreads(num_threads);
 
   // Execute tasks on the threadpool.
   cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), &context);

View File

@@ -292,7 +292,7 @@ inline void DispatchDepthwiseConv(
         << " input_offset = " << params.input_offset;
 
     CpuBackendContext backend_context;
-    backend_context.set_max_num_threads(test_param.num_threads);
+    backend_context.SetMaxNumThreads(test_param.num_threads);
     optimized_ops::DepthwiseConv<uint8, int32>(
         params, input_shape, input_data, filter_shape, filter_data, bias_shape,
         bias_data, output_shape, output_data, &backend_context);