Make TfLiteInternalBackendContext as a interface-only abstract class.
PiperOrigin-RevId: 259455436
This commit is contained in:
parent
95bcd434d0
commit
4f910ac64b
@ -22,7 +22,7 @@ TfLiteStatus RefreshExternalCpuBackendContext(TfLiteContext* context) {
|
||||
context->GetExternalContext(context, kTfLiteCpuBackendContext));
|
||||
if (external_context && external_context->internal_backend_context() &&
|
||||
context->recommended_num_threads != -1) {
|
||||
external_context->internal_backend_context()->set_max_num_threads(
|
||||
external_context->internal_backend_context()->SetMaxNumThreads(
|
||||
context->recommended_num_threads);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
@ -27,27 +27,13 @@ namespace tflite {
|
||||
// generally a collection of utilities (i.e. a thread pool etc.) for TF Lite to
|
||||
// use certain keneral libraries, such as Gemmlowp, RUY, etc., to implement TF
|
||||
// Lite operators.
|
||||
// TODO(b/130950871): Make this class as a interface-only abstract class.
|
||||
class TfLiteInternalBackendContext {
|
||||
public:
|
||||
virtual ~TfLiteInternalBackendContext() {}
|
||||
|
||||
int max_num_threads() const { return max_num_threads_; }
|
||||
|
||||
virtual void set_max_num_threads(int max_num_threads) {
|
||||
max_num_threads_ = max_num_threads;
|
||||
}
|
||||
|
||||
protected:
|
||||
TfLiteInternalBackendContext() {}
|
||||
|
||||
// The maximum number of threads used for parallelizing TfLite computation.
|
||||
int max_num_threads_;
|
||||
|
||||
private:
|
||||
TfLiteInternalBackendContext(const TfLiteInternalBackendContext&) = delete;
|
||||
TfLiteInternalBackendContext& operator=(const TfLiteInternalBackendContext&) =
|
||||
delete;
|
||||
// Set the maximum number of threads that could be used for parallelizing
|
||||
// TfLite computation.
|
||||
virtual void SetMaxNumThreads(int max_num_threads) = 0;
|
||||
};
|
||||
|
||||
// This TfLiteExternalContext-derived class is the default
|
||||
|
@ -24,12 +24,12 @@ CpuBackendContext::CpuBackendContext()
|
||||
: TfLiteInternalBackendContext(),
|
||||
ruy_context_(new ruy::Context),
|
||||
gemmlowp_context_(new gemmlowp::GemmContext) {
|
||||
set_max_num_threads(1);
|
||||
SetMaxNumThreads(1);
|
||||
}
|
||||
|
||||
CpuBackendContext::~CpuBackendContext() {}
|
||||
|
||||
void CpuBackendContext::set_max_num_threads(int max_num_threads) {
|
||||
void CpuBackendContext::SetMaxNumThreads(int max_num_threads) {
|
||||
max_num_threads_ = max_num_threads;
|
||||
ruy_context_->max_num_threads = max_num_threads;
|
||||
gemmlowp_context_->set_max_num_threads(max_num_threads);
|
||||
|
@ -35,17 +35,11 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
|
||||
return gemmlowp_context_.get();
|
||||
}
|
||||
|
||||
// Sets the maximum-number-of-threads-to-use parameter.
|
||||
// This is only a means of passing around this information.
|
||||
// cpu_backend_threadpool::Execute creates as many threads as it's
|
||||
// asked to, regardless of this. Typically a call site would query
|
||||
// cpu_backend_context->max_num_threads() and used that to determine
|
||||
// the number of tasks to create and to give to
|
||||
// cpu_backend_threadpool::Execute.
|
||||
//
|
||||
// This value also gets propagated to back-ends, where it plays the same
|
||||
// information-only role.
|
||||
void set_max_num_threads(int max_num_threads) override;
|
||||
// Sets the maximum-number-of-threads-to-use parameter, only as a means of
|
||||
// passing around this information.
|
||||
void SetMaxNumThreads(int max_num_threads) override;
|
||||
|
||||
int max_num_threads() const { return max_num_threads_; }
|
||||
|
||||
private:
|
||||
// To enable a smooth transition from the current direct usage
|
||||
@ -57,6 +51,17 @@ class CpuBackendContext final : public TfLiteInternalBackendContext {
|
||||
const std::unique_ptr<ruy::Context> ruy_context_;
|
||||
const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
|
||||
|
||||
// The maxinum of threads used for parallelizing TfLite ops. However,
|
||||
// cpu_backend_threadpool::Execute creates as many threads as it's
|
||||
// asked to, regardless of this. Typically a call site would query
|
||||
// cpu_backend_context->max_num_threads() and used that to determine
|
||||
// the number of tasks to create and to give to
|
||||
// cpu_backend_threadpool::Execute.
|
||||
//
|
||||
// This value also gets propagated to back-ends, where it plays the same
|
||||
// information-only role.
|
||||
int max_num_threads_;
|
||||
|
||||
CpuBackendContext(const CpuBackendContext&) = delete;
|
||||
};
|
||||
|
||||
|
@ -363,7 +363,7 @@ void TestSomeGemm(int rows, int depth, int cols,
|
||||
const std::vector<DstScalar>& golden) {
|
||||
CpuBackendContext cpu_backend_context;
|
||||
std::default_random_engine random_engine;
|
||||
cpu_backend_context.set_max_num_threads(1 + (random_engine() % 8));
|
||||
cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
|
||||
|
||||
const bool use_golden = !golden.empty();
|
||||
|
||||
|
@ -46,8 +46,7 @@ CpuBackendContext* GetFromContext(TfLiteContext* context) {
|
||||
// that's wrapped inside ExternalCpuBackendContext.
|
||||
cpu_backend_context = new CpuBackendContext();
|
||||
if (context->recommended_num_threads != -1) {
|
||||
cpu_backend_context->set_max_num_threads(
|
||||
context->recommended_num_threads);
|
||||
cpu_backend_context->SetMaxNumThreads(context->recommended_num_threads);
|
||||
}
|
||||
external_context->set_internal_backend_context(
|
||||
std::unique_ptr<TfLiteInternalBackendContext>(cpu_backend_context));
|
||||
|
@ -61,10 +61,10 @@ void TestGenerateArrayOfIncrementingInts(int num_threads, int size) {
|
||||
ASSERT_EQ(num_threads, tasks.size());
|
||||
|
||||
CpuBackendContext context;
|
||||
// This set_max_num_threads is only to satisfy an assertion in Execute.
|
||||
// This SetMaxNumThreads is only to satisfy an assertion in Execute.
|
||||
// What actually determines the number of threads used is the parameter
|
||||
// passed to Execute, since Execute does 1:1 mapping of tasks to threads.
|
||||
context.set_max_num_threads(num_threads);
|
||||
context.SetMaxNumThreads(num_threads);
|
||||
|
||||
// Execute tasks on the threadpool.
|
||||
cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), &context);
|
||||
|
@ -292,7 +292,7 @@ inline void DispatchDepthwiseConv(
|
||||
<< " input_offset = " << params.input_offset;
|
||||
|
||||
CpuBackendContext backend_context;
|
||||
backend_context.set_max_num_threads(test_param.num_threads);
|
||||
backend_context.SetMaxNumThreads(test_param.num_threads);
|
||||
optimized_ops::DepthwiseConv<uint8, int32>(
|
||||
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
|
||||
bias_data, output_shape, output_data, &backend_context);
|
||||
|
Loading…
Reference in New Issue
Block a user