diff --git a/tensorflow/lite/experimental/ruy/thread_pool.cc b/tensorflow/lite/experimental/ruy/thread_pool.cc
index db69dc8cc94..89957181170 100644
--- a/tensorflow/lite/experimental/ruy/thread_pool.cc
+++ b/tensorflow/lite/experimental/ruy/thread_pool.cc
@@ -153,17 +153,23 @@ class Thread {
 void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
   RUY_DCHECK_GE(task_count, 1);
 
-  // Task #0 will be run on the current thread.
-  CreateThreads(task_count - 1);
-  counter_to_decrement_when_ready_.Reset(task_count - 1);
-  for (int i = 1; i < task_count; i++) {
-    auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
-    threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
+  if (task_count > 1) {
+    // Task #0 will be run on the current thread.
+    CreateThreads(task_count - 1);
+    counter_to_decrement_when_ready_.Reset(task_count - 1);
+    for (int i = 1; i < task_count; i++) {
+      auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
+      threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
+    }
   }
-  // Execute task #0 workload immediately on the current thread.
+
+  // Execute task #0 immediately on the current thread.
   (tasks + 0)->Run();
-  // Wait for the threads submitted above to finish.
-  counter_to_decrement_when_ready_.Wait();
+
+  if (task_count > 1) {
+    // Wait for the threads submitted above to finish.
+    counter_to_decrement_when_ready_.Wait();
+  }
 }
 
 // Ensures that the pool has at least the given count of threads.