Change existing call sites of the old, deprecated gemmlowp WorkersPool::Execute method, which is a footgun because it destroys the Task objects that it takes, to the new, more explicit name LegacyExecuteAndDestroyTasks, which has the same behavior.
PiperOrigin-RevId: 245999142
parent 7b57c5a02c
commit 6e28fe7894
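
For context on the footgun being renamed away from: the old gemmlowp WorkersPool::Execute both ran and destroyed the tasks passed to it, which is easy to miss at a call site. A hedged sketch of the hazard follows; the vector-of-pointers parameter and the MakeHeapAllocatedTasks helper are assumptions for illustration, not taken from gemmlowp's headers.

    // Hedged sketch only; exact gemmlowp signatures are not shown in this commit.
    std::vector<gemmlowp::Task*> tasks = MakeHeapAllocatedTasks();   // hypothetical helper
    workers_pool->LegacyExecuteAndDestroyTasks(tasks);  // runs the tasks, then deletes them
    // Using any element of `tasks` past this point would be a use-after-free.
    // The old name, Execute, had the same destroying behavior but did not advertise it.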
@@ -285,8 +285,6 @@ void TrMul(TrMulParams* params, Context* context) {
   allocator->Allocate(1, &atomic_n);
   TrMulTask* tasks;
   allocator->Allocate(thread_count, &tasks);
-  Task** tasks_ptrs;
-  allocator->Allocate(thread_count, &tasks_ptrs);
 
   // Initialize allocated data.
   for (int i = 0; i < num_blocks_of_rows; i++) {
@@ -298,8 +296,7 @@ void TrMul(TrMulParams* params, Context* context) {
   atomic_n->store(thread_count);
 
   for (int i = 0; i < thread_count; i++) {
-    tasks_ptrs[i] = static_cast<Task*>(tasks + i);
-    new (tasks_ptrs[i])
+    new (tasks + i)
         TrMulTask(params, block_map, atomic_n, i, lhs_packed, rhs_packed,
                   &context->per_thread_states[i]->tuning_resolver,
                   &context->per_thread_states[i]->allocator, trace);
@@ -309,7 +306,7 @@ void TrMul(TrMulParams* params, Context* context) {
   TraceRecordExecute(trace);
   TraceStartRecordingBlockAndThreadFields(block_map, thread_count, trace);
 
-  context->workers_pool.Execute(thread_count, tasks_ptrs);
+  context->workers_pool.Execute(thread_count, tasks);
 
   // Finish up.
   for (int i = 0; i < thread_count; i++) {
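
The TrMul hunks above construct each TrMulTask with placement new directly into allocator-provided storage and no longer need the parallel Task** array. A minimal standalone sketch of that construct/run/destroy pattern, assuming nothing about ruy beyond what the diff shows (ExampleTask and malloc stand in for TrMulTask and ruy's Allocator):

    #include <cstdlib>
    #include <new>

    struct ExampleTask {
      explicit ExampleTask(int id) : id(id) {}
      void Run() { /* per-task work would go here */ }
      int id;
    };

    int main() {
      const int thread_count = 4;
      // Raw storage standing in for the allocator used in TrMul.
      void* storage = std::malloc(thread_count * sizeof(ExampleTask));
      ExampleTask* tasks = static_cast<ExampleTask*>(storage);
      for (int i = 0; i < thread_count; i++) {
        new (tasks + i) ExampleTask(i);  // placement new: construct in place
      }
      for (int i = 0; i < thread_count; i++) {
        tasks[i].Run();
      }
      for (int i = 0; i < thread_count; i++) {
        tasks[i].~ExampleTask();  // manual destruction mirrors the "Finish up" loop
      }
      std::free(storage);
      return 0;
    }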
@@ -225,17 +225,17 @@ class Thread {
   BlockingCounter* const counter_to_decrement_when_ready_;
 };
 
-void ThreadPool::Execute(int task_count, Task** tasks_ptrs) {
+void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
   RUY_DCHECK_GE(task_count, 1);
   // Task #0 will be run on the current thread.
   CreateThreads(task_count - 1);
   counter_to_decrement_when_ready_.Reset(task_count - 1);
   for (int i = 1; i < task_count; i++) {
-    threads_[i - 1]->StartWork(tasks_ptrs[i]);
+    auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
+    threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
   }
   // Execute task #0 workload immediately on the current thread.
-  Task* last_task = tasks_ptrs[0];
-  last_task->Run();
+  (tasks + 0)->Run();
   // Wait for the threads submitted above to finish.
   counter_to_decrement_when_ready_.Wait();
 }
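
The stride parameter is what makes ExecuteImpl correct for an array of a derived task type walked through a Task* base pointer: plain `tasks + i` would advance by sizeof(Task) and land between elements whenever the concrete type is larger. A standalone sketch of the arithmetic, with illustrative types rather than ruy's:

    #include <cassert>
    #include <cstdint>

    struct BaseTask {
      virtual void Run() = 0;
      virtual ~BaseTask() = default;
    };

    struct BigTask : BaseTask {
      void Run() override {}
      char payload[64];  // makes sizeof(BigTask) > sizeof(BaseTask)
    };

    int main() {
      BigTask tasks[4];
      BaseTask* base = tasks;  // base pointer to the first element, as static_cast does
      for (int i = 0; i < 4; i++) {
        // Stepping by the concrete size recovers the i-th element.
        auto addr = reinterpret_cast<std::uintptr_t>(base) + i * sizeof(BigTask);
        // Holds here because with single inheritance the base subobject sits at the
        // start of each element, the same assumption the stride trick relies on.
        assert(reinterpret_cast<BaseTask*>(addr) == &tasks[i]);
      }
      return 0;
    }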
@@ -68,7 +68,13 @@ class ThreadPool {
   // want to run an unbounded number of tasks on a bounded number of threads,
   // then you need something higher-level than this ThreadPool, that can
   // be layered on top of it by appropriately subclassing Tasks.
-  void Execute(int task_count, Task** tasks_ptrs);
+  //
+  // TaskType must be a subclass of ruy::Task. That is implicitly guarded by
+  // the static_cast in this inline implementation.
+  template <typename TaskType>
+  void Execute(int task_count, TaskType* tasks) {
+    ExecuteImpl(task_count, sizeof(TaskType), static_cast<Task*>(tasks));
+  }
 
  private:
   // Ensures that the pool has at least the given count of threads.
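
With the templatized overload, a caller just builds a contiguous array of its concrete task type and passes the typed pointer; the wrapper supplies sizeof(TaskType) as the stride. A hypothetical caller sketch follows; TestTask, RunDoubling, and the include path are illustrative, not from ruy's sources:

    #include "thread_pool.h"  // assumed include path for ruy::ThreadPool / ruy::Task

    struct TestTask : ruy::Task {
      void Run() override { result = input * 2; }
      int input = 0;
      int result = 0;
    };

    void RunDoubling(ruy::ThreadPool* thread_pool) {
      constexpr int kTaskCount = 8;
      TestTask tasks[kTaskCount];
      for (int i = 0; i < kTaskCount; i++) {
        tasks[i].input = i;
      }
      // Blocks until all tasks have run; task #0 runs on the calling thread.
      thread_pool->Execute(kTaskCount, tasks);
    }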
@@ -76,6 +82,10 @@ class ThreadPool {
   // be ready.
   void CreateThreads(int threads_count);
 
+  // Non-templatized implementation of the public Execute method.
+  // See the inline implementation of Execute for how this is used.
+  void ExecuteImpl(int task_count, int stride, Task* tasks);
+
   // copy construction disallowed
   ThreadPool(const ThreadPool&) = delete;
 
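
A note on the split introduced here: all of the thread management stays in the non-templatized ExecuteImpl, compiled once in the .cc file, while each TaskType only instantiates the one-line inline forwarder. For a hypothetical MyTask, the instantiation is effectively just:

    // Roughly what Execute<MyTask> expands to (MyTask is illustrative):
    void Execute(int task_count, MyTask* tasks) {
      ExecuteImpl(task_count, sizeof(MyTask), static_cast<Task*>(tasks));
    }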