diff --git a/tensorflow/lite/experimental/ruy/BUILD b/tensorflow/lite/experimental/ruy/BUILD
index 91ff7130163..aa621e3f53e 100644
--- a/tensorflow/lite/experimental/ruy/BUILD
+++ b/tensorflow/lite/experimental/ruy/BUILD
@@ -28,6 +28,22 @@ cc_library(
     hdrs = ["time.h"],
 )
 
+cc_library(
+    name = "wait",
+    srcs = ["wait.cc"],
+    hdrs = ["wait.h"],
+    deps = [":time"],
+)
+
+cc_test(
+    name = "wait_test",
+    srcs = ["wait_test.cc"],
+    deps = [
+        ":wait",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "spec",
     hdrs = ["spec.h"],
@@ -119,7 +135,7 @@ cc_library(
     ],
     deps = [
         ":check_macros",
-        ":time",
+        ":wait",
     ],
 )
 
@@ -135,7 +151,7 @@ cc_library(
     deps = [
         ":blocking_counter",
         ":check_macros",
-        ":time",
+        ":wait",
     ],
 )
 
diff --git a/tensorflow/lite/experimental/ruy/blocking_counter.cc b/tensorflow/lite/experimental/ruy/blocking_counter.cc
index 4519f490c4f..ac8a32803fd 100644
--- a/tensorflow/lite/experimental/ruy/blocking_counter.cc
+++ b/tensorflow/lite/experimental/ruy/blocking_counter.cc
@@ -15,62 +15,38 @@ limitations under the License.
 
 #include "tensorflow/lite/experimental/ruy/blocking_counter.h"
 
-#include <cstddef>
-#include <thread>
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <mutex>               // NOLINT(build/c++11)
 
 #include "tensorflow/lite/experimental/ruy/check_macros.h"
-#include "tensorflow/lite/experimental/ruy/time.h"
+#include "tensorflow/lite/experimental/ruy/wait.h"
 
 namespace ruy {
 
-static constexpr double kBlockingCounterMaxBusyWaitSeconds = 2e-3;
-
-void BlockingCounter::Reset(std::size_t initial_count) {
-  std::size_t old_count_value = count_.load(std::memory_order_relaxed);
+void BlockingCounter::Reset(int initial_count) {
+  int old_count_value = count_.load(std::memory_order_relaxed);
   RUY_DCHECK_EQ(old_count_value, 0);
   (void)old_count_value;
   count_.store(initial_count, std::memory_order_release);
 }
 
 bool BlockingCounter::DecrementCount() {
-  std::size_t old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel);
+  int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel);
   RUY_DCHECK_GT(old_count_value, 0);
-  std::size_t count_value = old_count_value - 1;
-  return count_value == 0;
+  int count_value = old_count_value - 1;
+  bool hit_zero = (count_value == 0);
+  if (hit_zero) {
+    std::lock_guard<std::mutex> lock(count_mutex_);
+    count_cond_.notify_all();
+  }
+  return hit_zero;
 }
 
 void BlockingCounter::Wait() {
-  // Busy-wait until the count value is 0.
-  const Duration wait_duration =
-      DurationFromSeconds(kBlockingCounterMaxBusyWaitSeconds);
-  TimePoint wait_start = Clock::now();
-  while (count_.load(std::memory_order_acquire)) {
-    if (Clock::now() - wait_start > wait_duration) {
-      // If we are unlucky, the blocking thread (that calls DecrementCount)
-      // and the blocked thread (here, calling Wait) may be scheduled on
-      // the same CPU, so the busy-waiting of the present thread may prevent
-      // the blocking thread from resuming and unblocking.
-      // If we are even unluckier, the priorities of the present thread
-      // might be higher than that of the blocking thread, so just yielding
-      // wouldn't allow the blocking thread to resume. So we sleep for
-      // a substantial amount of time in that case. Notice that we only
-      // do so after having busy-waited for kBlockingCounterMaxBusyWaitSeconds,
-      // which is typically >= 1 millisecond, so sleeping 1 more millisecond
-      // isn't terrible at that point.
-      //
-      // How this is mitigated in practice:
-      // In practice, it is well known that the application should be
-      // conservative in choosing how many threads to tell gemmlowp to use,
-      // as it's hard to know how many CPU cores it will get to run on,
-      // on typical mobile devices.
-      // It seems impossible for gemmlowp to make this choice automatically,
-      // which is why gemmlowp's default is to use only 1 thread, and
-      // applications may override that if they know that they can count on
-      // using more than that.
-      std::this_thread::sleep_for(DurationFromSeconds(1e-3));
-      wait_start = Clock::now();
-    }
-  }
+  const auto& condition = [this]() {
+    return count_.load(std::memory_order_acquire) == 0;
+  };
+  WaitUntil(condition, &count_cond_, &count_mutex_);
 }
 
 }  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/blocking_counter.h b/tensorflow/lite/experimental/ruy/blocking_counter.h
index a3c34509143..40f903ba1ab 100644
--- a/tensorflow/lite/experimental/ruy/blocking_counter.h
+++ b/tensorflow/lite/experimental/ruy/blocking_counter.h
@@ -17,7 +17,8 @@ limitations under the License.
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCKING_COUNTER_H_
 
 #include <atomic>
-#include <cstddef>
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <mutex>               // NOLINT(build/c++11)
 
 namespace ruy {
 
@@ -35,7 +36,7 @@ class BlockingCounter {
 
   // Sets/resets the counter; initial_count is the number of
   // decrementing events that the Wait() call will be waiting for.
-  void Reset(std::size_t initial_count);
+  void Reset(int initial_count);
 
   // Decrements the counter; if the counter hits zero, signals
   // the threads that were waiting for that, and returns true.
@@ -48,7 +49,12 @@ class BlockingCounter {
   void Wait();
 
  private:
-  std::atomic<std::size_t> count_;
+  std::atomic<int> count_;
+
+  // The condition variable and mutex used to passively wait for count_
+  // to reach zero, in the case of longer waits.
+  std::condition_variable count_cond_;
+  std::mutex count_mutex_;
 };
 
 }  // namespace ruy
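For reference, a minimal sketch of how the BlockingCounter API declared above is meant to be driven by a caller. The RunInParallel wrapper and its work parameter are illustrative only and not part of this change; the Reset()/DecrementCount()/Wait() calls are the actual API.

#include <functional>
#include <thread>  // NOLINT(build/c++11)
#include <vector>

#include "tensorflow/lite/experimental/ruy/blocking_counter.h"

// Hypothetical caller: fan `work` out to num_threads threads, then block
// until every worker has called DecrementCount(). Reset() must run before
// the workers start decrementing.
void RunInParallel(int num_threads, const std::function<void(int)>& work) {
  ruy::BlockingCounter counter;
  counter.Reset(num_threads);
  std::vector<std::thread> threads;
  for (int i = 0; i < num_threads; ++i) {
    threads.emplace_back([&counter, &work, i]() {
      work(i);
      counter.DecrementCount();  // Returns true only for the last finisher.
    });
  }
  counter.Wait();  // Spins briefly, then sleeps on the condition variable.
  for (auto& thread : threads) {
    thread.join();
  }
}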
diff --git a/tensorflow/lite/experimental/ruy/thread_pool.cc b/tensorflow/lite/experimental/ruy/thread_pool.cc
index de65a62bdf4..db69dc8cc94 100644
--- a/tensorflow/lite/experimental/ruy/thread_pool.cc
+++ b/tensorflow/lite/experimental/ruy/thread_pool.cc
@@ -23,85 +23,10 @@ limitations under the License.
 
 #include "tensorflow/lite/experimental/ruy/blocking_counter.h"
 #include "tensorflow/lite/experimental/ruy/check_macros.h"
-#include "tensorflow/lite/experimental/ruy/time.h"
+#include "tensorflow/lite/experimental/ruy/wait.h"
 
 namespace ruy {
 
-// This value was empirically derived on an end-to-end application benchmark.
-// That this value means that we may be sleeping substantially longer
-// than a scheduler timeslice's duration is not necessarily surprising. The
-// idea is to pick up quickly new work after having finished the previous
-// workload. When it's new work within the same GEMM as the previous work, the
-// time interval that we might be busy-waiting is very small, so for that
-// purpose it would be more than enough to sleep for 1 ms.
-// That is all what we would observe on a GEMM benchmark. However, in a real
-// application, after having finished a GEMM, we might do unrelated work for
-// a little while, then start on a new GEMM. Think of a neural network
-// application performing inference, where many but not all layers are
-// implemented by a GEMM. In such cases, our worker threads might be idle for
-// longer periods of time before having work again. If we let them passively
-// wait, on a mobile device, the CPU scheduler might aggressively clock down
-// or even turn off the CPU cores that they were running on. That would result
-// in a long delay the next time these need to be turned back on for the next
-// GEMM. So we need to strike a balance that reflects typical time intervals
-// between consecutive GEMM invokations, not just intra-GEMM considerations.
-// Of course, we need to balance keeping CPUs spinning longer to resume work
-// faster, versus passively waiting to conserve power.
-static constexpr double kThreadPoolMaxBusyWaitSeconds = 2e-3;
-
-// Waits until *var != initial_value.
-//
-// Returns the new value of *var. The guarantee here is that
-// the return value is different from initial_value, and that that
-// new value has been taken by *var at some point during the
-// execution of this function. There is no guarantee that this is
-// still the value of *var when this function returns, since *var is
-// not assumed to be guarded by any lock.
-//
-// First does some busy-waiting for a fixed number of no-op cycles,
-// then falls back to passive waiting for the given condvar, guarded
-// by the given mutex.
-//
-// The idea of doing some initial busy-waiting is to help get
-// better and more consistent multithreading benefits for small GEMM sizes.
-// Busy-waiting help ensuring that if we need to wake up soon after having
-// started waiting, then we can wake up quickly (as opposed to, say,
-// having to wait to be scheduled again by the OS). On the other hand,
-// we must still eventually revert to passive waiting for longer waits
-// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
-// so as to avoid permanently spinning.
-//
-template <typename T>
-T WaitForVariableChange(std::atomic<T>* var, T initial_value,
-                        std::condition_variable* cond, std::mutex* mutex) {
-  // First, trivial case where the variable already changed value.
-  T new_value = var->load(std::memory_order_acquire);
-  if (new_value != initial_value) {
-    return new_value;
-  }
-  // Then try busy-waiting.
-  const Duration wait_duration =
-      DurationFromSeconds(kThreadPoolMaxBusyWaitSeconds);
-  const TimePoint wait_start = Clock::now();
-  while (Clock::now() - wait_start < wait_duration) {
-    new_value = var->load(std::memory_order_acquire);
-    if (new_value != initial_value) {
-      return new_value;
-    }
-  }
-  // Finally, do real passive waiting.
-  mutex->lock();
-  new_value = var->load(std::memory_order_acquire);
-  while (new_value == initial_value) {
-    std::unique_lock<std::mutex> lock(*mutex, std::adopt_lock);
-    cond->wait(lock);
-    lock.release();
-    new_value = var->load(std::memory_order_acquire);
-  }
-  mutex->unlock();
-  return new_value;
-}
-
 // A worker thread.
 class Thread {
  public:
@@ -184,14 +109,15 @@ class Thread {
     // Thread main loop
     while (true) {
-      // Get a state to act on
       // In the 'Ready' state, we have nothing to do but to wait until
       // we switch to another state.
-      State state_to_act_upon = WaitForVariableChange(
-          &state_, State::Ready, &state_cond_, &state_mutex_);
+      const auto& condition = [this]() {
+        return state_.load(std::memory_order_acquire) != State::Ready;
+      };
+      WaitUntil(condition, &state_cond_, &state_mutex_);
 
-      // We now have a state to act on, so act.
-      switch (state_to_act_upon) {
+      // Act on new state.
+      switch (state_.load(std::memory_order_acquire)) {
         case State::HasWork:
           // Got work to do! So do it, and then revert to 'Ready' state.
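The thread pool no longer needs the return value of the removed WaitForVariableChange helper, since the worker simply re-loads state_ once WaitUntil returns. For comparison, a caller that still wants the old contract (return a value of *var observed to differ from initial_value) could layer it on top of the new primitive along these lines; WaitForChange is a hypothetical helper, not part of this change.

#include <atomic>
#include <condition_variable>  // NOLINT(build/c++11)
#include <mutex>  // NOLINT(build/c++11)

#include "tensorflow/lite/experimental/ruy/wait.h"

// Hypothetical equivalent of the removed helper, expressed on top of
// WaitUntil: the lambda records the value that satisfied the condition.
template <typename T>
T WaitForChange(std::atomic<T>* var, T initial_value,
                std::condition_variable* condvar, std::mutex* mutex) {
  T new_value = initial_value;
  ruy::WaitUntil(
      [&]() {
        new_value = var->load(std::memory_order_acquire);
        return new_value != initial_value;
      },
      condvar, mutex);
  return new_value;
}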
           ChangeState(State::Ready);
diff --git a/tensorflow/lite/experimental/ruy/wait.cc b/tensorflow/lite/experimental/ruy/wait.cc
new file mode 100644
index 00000000000..202a4158272
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/wait.cc
@@ -0,0 +1,82 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/ruy/wait.h"
+
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <functional>
+#include <mutex>  // NOLINT(build/c++11)
+
+#include "tensorflow/lite/experimental/ruy/time.h"
+
+namespace ruy {
+
+void WaitUntil(const std::function<bool()>& condition,
+               const Duration& spin_duration, std::condition_variable* condvar,
+               std::mutex* mutex) {
+  // First, trivial case where the `condition` is already true.
+  if (condition()) {
+    return;
+  }
+
+  // Then try busy-waiting.
+  const TimePoint wait_start = Clock::now();
+  while (Clock::now() - wait_start < spin_duration) {
+    if (condition()) {
+      return;
+    }
+  }
+
+  // Finally, do real passive waiting.
+  //
+  // TODO(b/135624397): We really want wait_until(TimePoint::max()) but that
+  // runs into a libc++ bug at the moment; see
+  // https://bugs.llvm.org/show_bug.cgi?id=21395#c5. We pick a duration large
+  // enough to appear infinite in practice and small enough to avoid such
+  // overflow bugs...
+  const Duration& timeout = DurationFromSeconds(1e6);
+  std::unique_lock<std::mutex> lock(*mutex);
+  condvar->wait_for(lock, timeout, condition);
+}
+
+void WaitUntil(const std::function<bool()>& condition,
+               std::condition_variable* condvar, std::mutex* mutex) {
+  // This value was empirically derived with some microbenchmark; we don't have
+  // high confidence in it.
+  //
+  // TODO(b/135595069): make this value configurable at runtime.
+  // I almost wanted to file another bug to ask for experimenting in a more
+  // principled way to tune this value better, but this would have to be tuned
+  // on real end-to-end applications and we'd expect different applications to
+  // require different tunings. So the more important point is the need for
+  // this to be controllable by the application.
+  //
+  // That this value means that we may be sleeping substantially longer
+  // than a scheduler timeslice's duration is not necessarily surprising. The
+  // idea is to quickly pick up new work after having finished the previous
+  // workload. When it's new work within the same GEMM as the previous work, the
+  // time interval that we might be busy-waiting is very small, so for that
+  // purpose it would be more than enough to sleep for 1 ms.
+  // That is all we would observe on a GEMM benchmark. However, in a real
+  // application, after having finished a GEMM, we might do unrelated work for
+  // a little while, then start on a new GEMM. In that case the wait interval
+  // may be a little longer. There may also not be another GEMM for a long time,
+  // in which case we'll end up passively waiting below.
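The passive phase above uses the predicate overload of std::condition_variable::wait_for, which re-checks the condition under the lock and is therefore robust against spurious wakeups. For reference, that overload behaves roughly like the loop below; this is a sketch of the standard-specified semantics, not ruy code.

#include <condition_variable>  // NOLINT(build/c++11)
#include <mutex>  // NOLINT(build/c++11)

#include "tensorflow/lite/experimental/ruy/time.h"

// Roughly what condvar->wait_for(lock, timeout, condition) does: keep
// sleeping on the condvar and re-checking the predicate until it returns
// true or the deadline passes.
template <typename Predicate>
bool WaitForWithPredicate(std::condition_variable* condvar,
                          std::unique_lock<std::mutex>* lock,
                          const ruy::Duration& timeout, Predicate condition) {
  const auto deadline = ruy::Clock::now() + timeout;
  while (!condition()) {
    if (condvar->wait_until(*lock, deadline) == std::cv_status::timeout) {
      return condition();
    }
  }
  return true;
}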
+  const double kMaxBusyWaitSeconds = 2e-3;
+  const Duration spin_duration = DurationFromSeconds(kMaxBusyWaitSeconds);
+  WaitUntil(condition, spin_duration, condvar, mutex);
+}
+
+}  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/wait.h b/tensorflow/lite/experimental/ruy/wait.h
new file mode 100644
index 00000000000..df4f3e32dba
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/wait.h
@@ -0,0 +1,74 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_
+
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <functional>
+#include <mutex>  // NOLINT(build/c++11)
+
+#include "tensorflow/lite/experimental/ruy/time.h"
+
+namespace ruy {
+
+// Waits until some evaluation of `condition` has returned true.
+//
+// There is no guarantee that calling `condition` again after this function
+// has returned would still return true. The only
+// contract is that at some point during the execution of this function,
+// `condition` has returned true.
+//
+// First does some spin-waiting for the specified `spin_duration`,
+// then falls back to passive waiting for the given condvar, guarded
+// by the given mutex. At that point it acquires the mutex lock around
+// the wait on the condition variable.
+// Therefore, this function expects that the calling thread hasn't already
+// locked the mutex before calling it.
+// This function will always release the mutex lock before returning.
+//
+// The idea of doing some initial spin-waiting is to help get
+// better and more consistent multithreading benefits for small GEMM sizes.
+// Spin-waiting helps ensure that if we need to wake up soon after having
+// started waiting, then we can wake up quickly (as opposed to, say,
+// having to wait to be scheduled again by the OS). On the other hand,
+// we must still eventually revert to passive waiting for longer waits
+// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
+// so as to avoid permanently spinning.
+//
+// In situations where other threads might have more useful things to do with
+// these CPU cores than our spin-waiting, it may be best to reduce the value
+// of `spin_duration`. Setting it to zero disables the spin-waiting entirely.
+//
+// There is a risk that the std::function used here might use a heap allocation
+// to store its context. The expected usage pattern is that these functions'
+// contexts will consist of a single pointer value (typically capturing only
+// [this]), and that in this case the std::function implementation will use
+// inline storage, avoiding a heap allocation. However, we can't effectively
+// enforce that assumption, and that's not a big concern anyway because the
+// latency of a small heap allocation is probably low compared to the intrinsic
+// latency of what this WaitUntil function does.
+void WaitUntil(const std::function<bool()>& condition,
+               const Duration& spin_duration, std::condition_variable* condvar,
+               std::mutex* mutex);
+
+// Convenience overload using a default `spin_duration`.
+// TODO(benoitjacob): let this be controlled from the ruy API.
+void WaitUntil(const std::function<bool()>& condition,
+               std::condition_variable* condvar, std::mutex* mutex);
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_
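A minimal, self-contained sketch of how a caller drives WaitUntil (the flag/producer setup is illustrative, not part of this change): the waiter passes a cheap condition lambda, and the thread that makes the condition true notifies the condition variable under the mutex, mirroring what BlockingCounter::DecrementCount() does above.

#include <atomic>
#include <condition_variable>  // NOLINT(build/c++11)
#include <mutex>  // NOLINT(build/c++11)
#include <thread>  // NOLINT(build/c++11)

#include "tensorflow/lite/experimental/ruy/wait.h"

int main() {
  std::atomic<bool> flag(false);
  std::condition_variable condvar;
  std::mutex mutex;

  // Producer: make the condition true, then notify under the mutex so a
  // waiter blocked on the condition variable cannot miss the wakeup.
  std::thread producer([&]() {
    flag.store(true, std::memory_order_release);
    std::lock_guard<std::mutex> lock(mutex);
    condvar.notify_all();
  });

  // Default overload: spin briefly, then sleep on the condvar.
  ruy::WaitUntil([&flag]() { return flag.load(std::memory_order_acquire); },
                 &condvar, &mutex);

  // Explicit spin_duration overload; a zero duration disables spinning.
  ruy::WaitUntil([&flag]() { return flag.load(std::memory_order_acquire); },
                 ruy::DurationFromSeconds(0), &condvar, &mutex);

  producer.join();
  return 0;
}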
diff --git a/tensorflow/lite/experimental/ruy/wait_test.cc b/tensorflow/lite/experimental/ruy/wait_test.cc
new file mode 100644
index 00000000000..a19d8c85860
--- /dev/null
+++ b/tensorflow/lite/experimental/ruy/wait_test.cc
@@ -0,0 +1,96 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/ruy/wait.h"
+
+#include <atomic>
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <mutex>  // NOLINT(build/c++11)
+#include <thread>  // NOLINT(build/c++11)
+
+#include <gtest/gtest.h>
+
+namespace ruy {
+namespace {
+
+// A thread that takes a `value` atomic counter and increments it until it
+// equals `end_value`, then notifies the condition variable for as long as
+// `value == end_value`. If `end_value` is increased, it resumes
+// incrementing `value`, and so on. Terminates when `end_value == -1`.
+class ThreadCountingUpToValue {
+ public:
+  ThreadCountingUpToValue(const std::atomic<int>& end_value,
+                          std::atomic<int>* value,
+                          std::condition_variable* condvar, std::mutex* mutex)
+      : end_value_(end_value),
+        value_(value),
+        condvar_(condvar),
+        mutex_(mutex) {}
+  void operator()() {
+    while (end_value_.load() != -1) {
+      while (value_->fetch_add(1) < end_value_.load() - 1) {
+      }
+      while (value_->load() == end_value_.load()) {
+        std::lock_guard<std::mutex> lock(*mutex_);
+        condvar_->notify_all();
+      }
+    }
+  }
+
+ private:
+  const std::atomic<int>& end_value_;
+  std::atomic<int>* value_;
+  std::condition_variable* condvar_;
+  std::mutex* mutex_;
+};
+
+void WaitTest(const Duration& spin_duration) {
+  std::condition_variable condvar;
+  std::mutex mutex;
+  std::atomic<int> value(0);
+  std::atomic<int> end_value(0);
+  ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
+  std::thread thread(thread_callable);
+  for (int i = 1; i < 10; i++) {
+    end_value.store(1000 * i);
+    const auto& condition = [&value, &end_value]() {
+      return value.load() == end_value.load();
+    };
+    ruy::WaitUntil(condition, spin_duration, &condvar, &mutex);
+    EXPECT_EQ(value.load(), end_value.load());
+  }
+  end_value.store(-1);
+  thread.join();
+}
+
+TEST(WaitTest, WaitTestNoSpin) { WaitTest(DurationFromSeconds(0)); }
+
+TEST(WaitTest, WaitTestSpinOneMicrosecond) {
+  WaitTest(DurationFromSeconds(1e-6));
+}
+
+TEST(WaitTest, WaitTestSpinOneMillisecond) {
+  WaitTest(DurationFromSeconds(1e-3));
+}
+
+TEST(WaitTest, WaitTestSpinOneSecond) { WaitTest(DurationFromSeconds(1)); }
+
+}  // namespace
+}  // namespace ruy
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
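wait_test exercises WaitUntil directly; nothing in this change exercises BlockingCounter's new condition-variable path on its own. A companion test along the following lines could do that explicitly, by making the workers slower than the roughly 2 ms spin budget. This is a hypothetical sketch, not part of this change.

#include <chrono>  // NOLINT(build/c++11)
#include <thread>  // NOLINT(build/c++11)
#include <vector>

#include <gtest/gtest.h>

#include "tensorflow/lite/experimental/ruy/blocking_counter.h"

namespace ruy {
namespace {

// Hypothetical companion test: workers sleep past the spin budget so that
// Wait() has to fall back to waiting on the condition variable.
TEST(BlockingCounterTest, WaitOutlastsSpinning) {
  BlockingCounter counter;
  constexpr int kNumThreads = 4;
  counter.Reset(kNumThreads);
  std::vector<std::thread> threads;
  for (int i = 0; i < kNumThreads; ++i) {
    threads.emplace_back([&counter]() {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      counter.DecrementCount();
    });
  }
  counter.Wait();
  for (auto& thread : threads) {
    thread.join();
  }
}

}  // namespace
}  // namespace ruy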