Lazily create the Eigen thread pool for TFLite

Currently, the thread pool is created as soon as IncrementUsageCounter
is called. However, that may be called even if the ThreadPool goes unused.
Instead, create the ThreadPool lazily when the ThreadPoolDevice is queried.

PiperOrigin-RevId: 237338488
This commit is contained in:
Jared Duke 2019-03-07 15:29:59 -08:00 committed by TensorFlower Gardener
parent 90b561ef09
commit 57613bcf92
4 changed files with 67 additions and 17 deletions

View File

@ -364,6 +364,18 @@ TEST_P(ConvolutionOpTest, HandCalculatedFloat32) {
// | 187 | 234 | 261 | 121 |
EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357,
178, 187, 234, 261, 121}));
// Add an additional test for the multi-threaded case, ensuring stability
// under different thread counts.
if (GetParam() == "MultithreadedOptimized") {
for (int i = 1; i < 4; ++i) {
m.SetNumThreads(i);
m.Invoke();
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({105, 150, 183, 95, 235, 312, 357, 178, 187,
234, 261, 121}));
}
}
}
TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {

View File

@ -24,6 +24,10 @@ namespace tflite {
namespace eigen_support {
namespace {
// Fallback size of the Eigen thread pool. It is kept at 4 for legacy reasons
// and is only used when the context does not explicitly specify a thread
// count.
constexpr int kDefaultNumThreadpoolThreads = 4;
#ifndef EIGEN_DONT_ALIGN
// Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
// hardware architecture and build configurations.
@ -63,9 +67,45 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
std::unique_ptr<Eigen::ThreadPool> pool_;
};
// Utility class for lazily creating an Eigen thread pool/device only when used.
//
// The pool and device are built on the first GetThreadPoolDevice() call and
// are discarded whenever SetNumThreads() changes the effective thread count,
// so a device pointer obtained earlier must not be used after such a change.
class LazyEigenThreadPoolHolder {
 public:
  // Records the desired thread count without creating any threads.
  // |num_threads| follows the same convention as SetNumThreads().
  explicit LazyEigenThreadPoolHolder(int num_threads) {
    SetNumThreads(num_threads);
  }

  // Gets the ThreadPoolDevice, creating it (and its thread pool) if necessary.
  // The returned pointer is owned by this holder.
  const Eigen::ThreadPoolDevice* GetThreadPoolDevice() {
    if (!device_) {
      thread_pool_wrapper_.reset(new EigenThreadPoolWrapper(
          new Eigen::ThreadPool(target_num_threads_)));
      device_.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper_.get(),
                                                target_num_threads_));
    }
    return device_.get();
  }

  // Updates the thread count, invalidating the ThreadPoolDevice if necessary.
  //
  // A negative |num_threads| (conventionally -1) selects
  // kDefaultNumThreadpoolThreads. Other negative values are treated the same
  // way so that a nonsensical count is never forwarded to Eigen::ThreadPool.
  void SetNumThreads(int num_threads) {
    const int target_num_threads =
        num_threads >= 0 ? num_threads : kDefaultNumThreadpoolThreads;
    if (target_num_threads_ != target_num_threads) {
      target_num_threads_ = target_num_threads;
      // As the device references the thread pool wrapper, destroy it first.
      device_.reset();
      thread_pool_wrapper_.reset();
    }
  }

 private:
  int target_num_threads_ = kDefaultNumThreadpoolThreads;
  // Both thread_pool_wrapper_ and device_ are lazily created.
  //
  // Declaration order matters: members are destroyed in reverse declaration
  // order, so device_ (declared last) is torn down before the wrapper it
  // points to, matching the explicit reset order in SetNumThreads().
  std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper_;
  std::unique_ptr<Eigen::ThreadPoolDevice> device_;
};
// External-context record that carries the Eigen thread pool state together
// with a count of outstanding references.
struct RefCountedEigenContext : public TfLiteExternalContext {
// NOTE(review): this listing looks like a rendered diff. The next two members
// appear to be the eagerly created pool/device that the change replaces with
// thread_pool_holder; confirm against the actual file before relying on them.
std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
// Holder that creates the thread pool and device lazily, on first request.
std::unique_ptr<LazyEigenThreadPoolHolder> thread_pool_holder;
// Number of outstanding references to this context; starts at zero.
int num_references = 0;
};
@ -74,24 +114,12 @@ RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
context->GetExternalContext(context, kTfLiteEigenContext));
}
// Rebuilds the Eigen thread pool and device stored in |ptr|, sized from
// |context|'s recommended thread count. A recommendation of -1 means "not
// specified" and falls back to 4 threads.
void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) {
  const int requested_threads = context->recommended_num_threads;
  const int thread_count = (requested_threads == -1) ? 4 : requested_threads;

  // The device is built on top of the pool wrapper, so release the device
  // before the wrapper it points to is replaced.
  ptr->device.reset();
  ptr->thread_pool_wrapper.reset(
      new EigenThreadPoolWrapper(new Eigen::ThreadPool(thread_count)));
  ptr->device.reset(new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(),
                                                thread_count));
}
TfLiteStatus Refresh(TfLiteContext* context) {
SetEigenNbThreads(context->recommended_num_threads);
auto* ptr = GetEigenContext(context);
if (ptr != nullptr) {
InitDevice(context, ptr);
ptr->thread_pool_holder->SetNumThreads(context->recommended_num_threads);
}
return kTfLiteOk;
@ -108,8 +136,9 @@ void IncrementUsageCounter(TfLiteContext* context) {
ptr = new RefCountedEigenContext;
ptr->type = kTfLiteEigenContext;
ptr->Refresh = Refresh;
ptr->thread_pool_holder.reset(
new LazyEigenThreadPoolHolder(context->recommended_num_threads));
ptr->num_references = 0;
InitDevice(context, ptr);
context->SetExternalContext(context, kTfLiteEigenContext, ptr);
}
ptr->num_references++;
@ -134,7 +163,7 @@ const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
TF_LITE_FATAL(
"Call to GetFromContext() not preceded by IncrementUsageCounter()");
}
return ptr->device.get();
return ptr->thread_pool_holder->GetThreadPoolDevice();
}
} // namespace eigen_support

View File

@ -32,6 +32,11 @@ void IncrementUsageCounter(TfLiteContext* context);
// usages all temporary Eigen objects will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
// Fetch the ThreadPoolDevice associated with the provided context.
//
// Note: The caller must ensure that |IncrementUsageCounter()| has already been
// called. Moreover, it is *not* safe to cache the returned device; it may be
// invalidated if the context thread count changes.
const EigenForTFLite::ThreadPoolDevice* GetThreadPoolDevice(
TfLiteContext* context);

View File

@ -326,6 +326,10 @@ class SingleOpModel {
return result;
}
// Forwards |num_threads| to the wrapped interpreter's SetNumThreads().
void SetNumThreads(int num_threads) {
interpreter_->SetNumThreads(num_threads);
}
// Replaces resolver_ with |resolver|, taking ownership of it.
void SetResolver(std::unique_ptr<OpResolver> resolver) {
resolver_ = std::move(resolver);
}