Refactor WaitForVariableChange: abstract away the atomic operations from it

by putting them behind a `condition` function that it waits for, rename it
`WaitUntil`, move it to its own library with own unit-test.

A design benefit from this change is that now the atomic operations are
grouped on the caller's side, which makes it easier to verify that they
are correct (i.e. that 'release' stores match 'acquire' loads).

Another design benefit is that this WaitUntil function now handles all of the spin-waiting that we do in ruy, including the one in BlockingCounter which was using separate code. That code had its own challenge as, not depending on a condition variable, it was occasionally reverting to just waiting for fixed durations to give other threads a chance to get scheduled, which could harm performance. That concern is removed by this change making it use WaitUntil. Making WaitUntil take a std::function for the wait condition was needed to make it general enough for this purpose: the original usage in ThreadPool needs to wait until a value is *different* from a given one, but the usage in BlockingCounter needs to wait until a value is *equal* to a given one.

An overload of WaitUntil allows controlling the spin_duration, in preparation for future changes exposing this setting in the ruy API itself.

PiperOrigin-RevId: 254184689
This commit is contained in:
Benoit Jacob 2019-06-20 06:29:23 -07:00 committed by TensorFlower Gardener
parent 9f06292332
commit 9385325594
7 changed files with 303 additions and 127 deletions

View File

@ -28,6 +28,22 @@ cc_library(
hdrs = ["time.h"],
)
cc_library(
name = "wait",
srcs = ["wait.cc"],
hdrs = ["wait.h"],
deps = [":time"],
)
cc_test(
name = "wait_test",
srcs = ["wait_test.cc"],
deps = [
":wait",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "spec",
hdrs = ["spec.h"],
@ -119,7 +135,7 @@ cc_library(
],
deps = [
":check_macros",
":time",
":wait",
],
)
@ -135,7 +151,7 @@ cc_library(
deps = [
":blocking_counter",
":check_macros",
":time",
":wait",
],
)

View File

@ -15,62 +15,38 @@ limitations under the License.
#include "tensorflow/lite/experimental/ruy/blocking_counter.h"
#include <chrono>
#include <thread>
#include <condition_variable> // NOLINT(build/c++11)
#include <mutex> // NOLINT(build/c++11)
#include "tensorflow/lite/experimental/ruy/check_macros.h"
#include "tensorflow/lite/experimental/ruy/time.h"
#include "tensorflow/lite/experimental/ruy/wait.h"
namespace ruy {
static constexpr double kBlockingCounterMaxBusyWaitSeconds = 2e-3;
void BlockingCounter::Reset(std::size_t initial_count) {
std::size_t old_count_value = count_.load(std::memory_order_relaxed);
void BlockingCounter::Reset(int initial_count) {
int old_count_value = count_.load(std::memory_order_relaxed);
RUY_DCHECK_EQ(old_count_value, 0);
(void)old_count_value;
count_.store(initial_count, std::memory_order_release);
}
bool BlockingCounter::DecrementCount() {
std::size_t old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel);
int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel);
RUY_DCHECK_GT(old_count_value, 0);
std::size_t count_value = old_count_value - 1;
return count_value == 0;
int count_value = old_count_value - 1;
bool hit_zero = (count_value == 0);
if (hit_zero) {
std::lock_guard<std::mutex> lock(count_mutex_);
count_cond_.notify_all();
}
return hit_zero;
}
void BlockingCounter::Wait() {
// Busy-wait until the count value is 0.
const Duration wait_duration =
DurationFromSeconds(kBlockingCounterMaxBusyWaitSeconds);
TimePoint wait_start = Clock::now();
while (count_.load(std::memory_order_acquire)) {
if (Clock::now() - wait_start > wait_duration) {
// If we are unlucky, the blocking thread (that calls DecrementCount)
// and the blocked thread (here, calling Wait) may be scheduled on
// the same CPU, so the busy-waiting of the present thread may prevent
// the blocking thread from resuming and unblocking.
// If we are even unluckier, the priorities of the present thread
// might be higher than that of the blocking thread, so just yielding
// wouldn't allow the blocking thread to resume. So we sleep for
// a substantial amount of time in that case. Notice that we only
// do so after having busy-waited for kBlockingCounterMaxBusyWaitSeconds,
// which is typically >= 1 millisecond, so sleeping 1 more millisecond
// isn't terrible at that point.
//
// How this is mitigated in practice:
// In practice, it is well known that the application should be
// conservative in choosing how many threads to tell gemmlowp to use,
// as it's hard to know how many CPU cores it will get to run on,
// on typical mobile devices.
// It seems impossible for gemmlowp to make this choice automatically,
// which is why gemmlowp's default is to use only 1 thread, and
// applications may override that if they know that they can count on
// using more than that.
std::this_thread::sleep_for(DurationFromSeconds(1e-3));
wait_start = Clock::now();
}
}
const auto& condition = [this]() {
return count_.load(std::memory_order_acquire) == 0;
};
WaitUntil(condition, &count_cond_, &count_mutex_);
}
} // namespace ruy

View File

@ -17,7 +17,8 @@ limitations under the License.
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCKING_COUNTER_H_
#include <atomic>
#include <cstdint>
#include <condition_variable> // NOLINT(build/c++11)
#include <mutex> // NOLINT(build/c++11)
namespace ruy {
@ -35,7 +36,7 @@ class BlockingCounter {
// Sets/resets the counter; initial_count is the number of
// decrementing events that the Wait() call will be waiting for.
void Reset(std::size_t initial_count);
void Reset(int initial_count);
// Decrements the counter; if the counter hits zero, signals
// the threads that were waiting for that, and returns true.
@ -48,7 +49,12 @@ class BlockingCounter {
void Wait();
private:
std::atomic<std::size_t> count_;
std::atomic<int> count_;
// The condition variable and mutex allowing to passively wait for count_
// to reach the value zero, in the case of longer waits.
std::condition_variable count_cond_;
std::mutex count_mutex_;
};
} // namespace ruy

View File

@ -23,85 +23,10 @@ limitations under the License.
#include "tensorflow/lite/experimental/ruy/blocking_counter.h"
#include "tensorflow/lite/experimental/ruy/check_macros.h"
#include "tensorflow/lite/experimental/ruy/time.h"
#include "tensorflow/lite/experimental/ruy/wait.h"
namespace ruy {
// This value was empirically derived on an end-to-end application benchmark.
// That this value means that we may be sleeping substantially longer
// than a scheduler timeslice's duration is not necessarily surprising. The
// idea is to pick up quickly new work after having finished the previous
// workload. When it's new work within the same GEMM as the previous work, the
// time interval that we might be busy-waiting is very small, so for that
// purpose it would be more than enough to sleep for 1 ms.
// That is all what we would observe on a GEMM benchmark. However, in a real
// application, after having finished a GEMM, we might do unrelated work for
// a little while, then start on a new GEMM. Think of a neural network
// application performing inference, where many but not all layers are
// implemented by a GEMM. In such cases, our worker threads might be idle for
// longer periods of time before having work again. If we let them passively
// wait, on a mobile device, the CPU scheduler might aggressively clock down
// or even turn off the CPU cores that they were running on. That would result
// in a long delay the next time these need to be turned back on for the next
// GEMM. So we need to strike a balance that reflects typical time intervals
// between consecutive GEMM invokations, not just intra-GEMM considerations.
// Of course, we need to balance keeping CPUs spinning longer to resume work
// faster, versus passively waiting to conserve power.
static constexpr double kThreadPoolMaxBusyWaitSeconds = 2e-3;
// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that that
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting for the given condvar, guarded
// by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting help ensuring that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
template <typename T>
T WaitForVariableChange(std::atomic<T>* var, T initial_value,
std::condition_variable* cond, std::mutex* mutex) {
// First, trivial case where the variable already changed value.
T new_value = var->load(std::memory_order_acquire);
if (new_value != initial_value) {
return new_value;
}
// Then try busy-waiting.
const Duration wait_duration =
DurationFromSeconds(kThreadPoolMaxBusyWaitSeconds);
const TimePoint wait_start = Clock::now();
while (Clock::now() - wait_start < wait_duration) {
new_value = var->load(std::memory_order_acquire);
if (new_value != initial_value) {
return new_value;
}
}
// Finally, do real passive waiting.
mutex->lock();
new_value = var->load(std::memory_order_acquire);
while (new_value == initial_value) {
std::unique_lock<std::mutex> lock(*mutex, std::adopt_lock);
cond->wait(lock);
lock.release();
new_value = var->load(std::memory_order_acquire);
}
mutex->unlock();
return new_value;
}
// A worker thread.
class Thread {
public:
@ -184,14 +109,15 @@ class Thread {
// Thread main loop
while (true) {
// Get a state to act on
// In the 'Ready' state, we have nothing to do but to wait until
// we switch to another state.
State state_to_act_upon = WaitForVariableChange(
&state_, State::Ready, &state_cond_, &state_mutex_);
const auto& condition = [this]() {
return state_.load(std::memory_order_acquire) != State::Ready;
};
WaitUntil(condition, &state_cond_, &state_mutex_);
// We now have a state to act on, so act.
switch (state_to_act_upon) {
// Act on new state.
switch (state_.load(std::memory_order_acquire)) {
case State::HasWork:
// Got work to do! So do it, and then revert to 'Ready' state.
ChangeState(State::Ready);

View File

@ -0,0 +1,82 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/ruy/wait.h"
#include <condition_variable> // NOLINT(build/c++11)
#include <functional>
#include <mutex> // NOLINT(build/c++11)
#include "tensorflow/lite/experimental/ruy/time.h"
namespace ruy {
void WaitUntil(const std::function<bool()>& condition,
const Duration& spin_duration, std::condition_variable* condvar,
std::mutex* mutex) {
// First, trivial case where the `condition` is already true;
if (condition()) {
return;
}
// Then try busy-waiting.
const TimePoint wait_start = Clock::now();
while (Clock::now() - wait_start < spin_duration) {
if (condition()) {
return;
}
}
// Finally, do real passive waiting.
//
// TODO(b/135624397): We really want wait_until(TimePoint::max()) but that
// runs into a libc++ bug at the moment, see b/135624397 and
// https://bugs.llvm.org/show_bug.cgi?id=21395#c5. We pick a duration large
// enough to appear infinite in practice and small enough to avoid such
// overflow bugs...
const Duration& timeout = DurationFromSeconds(1e6);
std::unique_lock<std::mutex> lock(*mutex);
condvar->wait_for(lock, timeout, condition);
}
void WaitUntil(const std::function<bool()>& condition,
std::condition_variable* condvar, std::mutex* mutex) {
// This value was empirically derived with some microbenchmark, we don't have
// high confidence in it.
//
// TODO(b/135595069): make this value configurable at runtime.
// I almost wanted to file another bug to ask for experimenting in a more
// principled way to tune this value better, but this would have to be tuned
// on real end-to-end applications and we'd expect different applications to
// require different tunings. So the more important point is the need for
// this to be controllable by the application.
//
// That this value means that we may be sleeping substantially longer
// than a scheduler timeslice's duration is not necessarily surprising. The
// idea is to pick up quickly new work after having finished the previous
// workload. When it's new work within the same GEMM as the previous work, the
// time interval that we might be busy-waiting is very small, so for that
// purpose it would be more than enough to sleep for 1 ms.
// That is all what we would observe on a GEMM benchmark. However, in a real
// application, after having finished a GEMM, we might do unrelated work for
// a little while, then start on a new GEMM. In that case the wait interval
// may be a little longer. There may also not be another GEMM for a long time,
// in which case we'll end up passively waiting below.
const double kMaxBusyWaitSeconds = 2e-3;
const Duration spin_duration = DurationFromSeconds(kMaxBusyWaitSeconds);
WaitUntil(condition, spin_duration, condvar, mutex);
}
} // namespace ruy

View File

@ -0,0 +1,74 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_
#include <condition_variable> // NOLINT(build/c++11)
#include <functional>
#include <mutex> // NOLINT(build/c++11)
#include "tensorflow/lite/experimental/ruy/time.h"
namespace ruy {
// Waits until some evaluation of `condition` has returned true.
//
// There is no guarantee that calling `condition` again after this function
// has returned would still return true. The only
// contract is that at some point during the execution of that function,
// `condition` has returned true.
//
// First does some spin-waiting for the specified `spin_duration`,
// then falls back to passive waiting for the given condvar, guarded
// by the given mutex. At this point it will try to acquire the mutex lock,
// around the waiting on the condition variable.
// Therefore, this function expects that the calling thread hasn't already
// locked the mutex before calling it.
// This function will always release the mutex lock before returning.
//
// The idea of doing some initial spin-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Spin-waiting help ensuring that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
// In situations where other threads might have more useful things to do with
// these CPU cores than our spin-waiting, it may be best to reduce the value
// of `spin_duration`. Setting it to zero disables the spin-waiting entirely.
//
// There is a risk that the std::function used here might use a heap allocation
// to store its context. The expected usage pattern is that these functions'
// contexts will consist of a single pointer value (typically capturing only
// [this]), and that in this case the std::function implementation will use
// inline storage, avoiding a heap allocation. However, we can't effectively
// guard that assumption, and that's not a big concern anyway because the
// latency of a small heap allocation is probably low compared to the intrinsic
// latency of what this WaitUntil function does.
void WaitUntil(const std::function<bool()>& condition,
const Duration& spin_duration, std::condition_variable* condvar,
std::mutex* mutex);
// Convenience overload using a default `spin_duration`.
// TODO(benoitjacob): let this be controlled from the ruy API.
void WaitUntil(const std::function<bool()>& condition,
std::condition_variable* condvar, std::mutex* mutex);
} // namespace ruy
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_

View File

@ -0,0 +1,96 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/ruy/wait.h"
#include <atomic>
#include <condition_variable> // NOLINT(build/c++11)
#include <mutex> // NOLINT(build/c++11)
#include <thread> // NOLINT(build/c++11)
#include <gtest/gtest.h>
namespace ruy {
namespace {
// Thread taking a `value` atomic counter and incrementing it until it equals
// `end_value`, then notifying the condition variable as long as
// `value == end_value`. If `end_value` is increased, it will then resume
// incrementing `value`, etc. Terminates if `end_value == -1`.
class ThreadCountingUpToValue {
public:
ThreadCountingUpToValue(const std::atomic<int>& end_value,
std::atomic<int>* value,
std::condition_variable* condvar, std::mutex* mutex)
: end_value_(end_value),
value_(value),
condvar_(condvar),
mutex_(mutex) {}
void operator()() {
while (end_value_.load() != -1) {
while (value_->fetch_add(1) < end_value_.load() - 1) {
}
while (value_->load() == end_value_.load()) {
std::lock_guard<std::mutex> lock(*mutex_);
condvar_->notify_all();
}
}
}
private:
const std::atomic<int>& end_value_;
std::atomic<int>* value_;
std::condition_variable* condvar_;
std::mutex* mutex_;
};
void WaitTest(const Duration& spin_duration) {
std::condition_variable condvar;
std::mutex mutex;
std::atomic<int> value(0);
std::atomic<int> end_value(0);
ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
std::thread thread(thread_callable);
for (int i = 1; i < 10; i++) {
end_value.store(1000 * i);
const auto& condition = [&value, &end_value]() {
return value.load() == end_value.load();
};
ruy::WaitUntil(condition, spin_duration, &condvar, &mutex);
EXPECT_EQ(value.load(), end_value.load());
}
end_value.store(-1);
thread.join();
}
TEST(WaitTest, WaitTestNoSpin) { WaitTest(DurationFromSeconds(0)); }
TEST(WaitTest, WaitTestSpinOneMicrosecond) {
WaitTest(DurationFromSeconds(1e-6));
}
TEST(WaitTest, WaitTestSpinOneMillisecond) {
WaitTest(DurationFromSeconds(1e-3));
}
TEST(WaitTest, WaitTestSpinOneSecond) { WaitTest(DurationFromSeconds(1)); }
} // namespace
} // namespace ruy
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}