STT-tensorflow/tensorflow/lite/experimental/ruy/wait_test.cc
Benoit Jacob 9385325594 Refactor WaitForVariableChange: abstract away the atomic operations from it
by putting them behind a `condition` function that it waits for, rename it
`WaitUntil`, move it to its own library with own unit-test.

A design benefit from this change is that now the atomic operations are
grouped on the caller's side, which makes it easier to verify that they
are correct (i.e. that 'release' stores match 'acquire' loads).

Another design benefit is that WaitUntil now handles all of the spin-waiting that ruy does, including the spin-wait in BlockingCounter, which previously used separate code. Because that code did not rely on a condition variable, it occasionally fell back to simply waiting for fixed durations to give other threads a chance to be scheduled, which could hurt performance. This change removes that concern by making BlockingCounter use WaitUntil. Making WaitUntil take a std::function for the wait condition was necessary to make it general enough for this purpose: the original usage in ThreadPool needs to wait until a value is *different* from a given one, whereas the usage in BlockingCounter needs to wait until a value is *equal* to a given one.

An overload of WaitUntil allows controlling the spin_duration, in preparation for future changes exposing this setting in the ruy API itself.

PiperOrigin-RevId: 254184689
2019-06-20 06:37:12 -07:00

97 lines
3.0 KiB
C++

/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/ruy/wait.h"
#include <atomic>
#include <condition_variable> // NOLINT(build/c++11)
#include <mutex> // NOLINT(build/c++11)
#include <thread> // NOLINT(build/c++11)
#include <gtest/gtest.h>
namespace ruy {
namespace {
// Callable meant to run on a helper thread; it drives the atomic counter
// `*value` up to the target `end_value`.
//
// While `*value < end_value`, it increments `*value`. Once
// `*value == end_value`, it repeatedly locks `*mutex` and notifies `*condvar`
// for as long as the two stay equal. When the caller raises `end_value`, it
// resumes counting toward the new target. It returns once it observes
// `end_value == -1`.
//
// `*value` is never incremented past `end_value`. If this thread starts or
// resumes before the caller has published a target above the current count,
// it sits in the notify loop instead of counting away from the target.
//
// Caller contract: only raise `end_value`, except for the final -1 that
// terminates the thread. This code assumes it is the only writer of `*value`.
class ThreadCountingUpToValue {
 public:
  ThreadCountingUpToValue(const std::atomic<int>& end_value,
                          std::atomic<int>* value,
                          std::condition_variable* condvar, std::mutex* mutex)
      : end_value_(end_value),
        value_(value),
        condvar_(condvar),
        mutex_(mutex) {}

  void operator()() {
    while (end_value_.load() != -1) {
      // Count up to, but never beyond, the current target. Checking before
      // incrementing (rather than fetch_add-then-compare) keeps the counter
      // from overshooting when the target has not been raised yet, e.g. when
      // this thread starts before the caller stores its first target. An
      // overshoot would make `*value == end_value` unreachable and hang any
      // waiter.
      while (value_->load() < end_value_.load()) {
        value_->fetch_add(1);
      }
      // Target reached: keep waking waiters until the target changes.
      while (value_->load() == end_value_.load()) {
        std::lock_guard<std::mutex> lock(*mutex_);
        condvar_->notify_all();
      }
    }
  }

 private:
  const std::atomic<int>& end_value_;
  std::atomic<int>* value_;
  std::condition_variable* condvar_;
  std::mutex* mutex_;
};
void WaitTest(const Duration& spin_duration) {
std::condition_variable condvar;
std::mutex mutex;
std::atomic<int> value(0);
std::atomic<int> end_value(0);
ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
std::thread thread(thread_callable);
for (int i = 1; i < 10; i++) {
end_value.store(1000 * i);
const auto& condition = [&value, &end_value]() {
return value.load() == end_value.load();
};
ruy::WaitUntil(condition, spin_duration, &condvar, &mutex);
EXPECT_EQ(value.load(), end_value.load());
}
end_value.store(-1);
thread.join();
}
// Each case runs WaitTest with a different spin_duration handed to
// ruy::WaitUntil, ranging from zero up to one second.

TEST(WaitTest, WaitTestNoSpin) {
  WaitTest(DurationFromSeconds(0));
}

TEST(WaitTest, WaitTestSpinOneMicrosecond) {
  WaitTest(DurationFromSeconds(0.000001));
}

TEST(WaitTest, WaitTestSpinOneMillisecond) {
  WaitTest(DurationFromSeconds(0.001));
}

TEST(WaitTest, WaitTestSpinOneSecond) {
  WaitTest(DurationFromSeconds(1));
}
} // namespace
} // namespace ruy
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}