Remove race in coordinator_test so it passes on tsan. Reduce sleep intervals

so it runs faster (11s on tsan instead of ~30s).
Change: 147893428
This commit is contained in:
A. Unique TensorFlower 2017-02-17 16:25:30 -08:00 committed by TensorFlower Gardener
parent 79c3b47319
commit eb9624017a

View File

@ -29,9 +29,10 @@ namespace {
using error::Code; using error::Code;
void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) { void WaitForStopThread(Coordinator* coord, Notification* about_to_wait,
Notification* done) {
about_to_wait->Notify();
coord->WaitForStop(); coord->WaitForStop();
*stopped = true;
done->Notify(); done->Notify();
} }
@ -39,17 +40,17 @@ TEST(CoordinatorTest, TestStopAndWaitOnStop) {
Coordinator coord; Coordinator coord;
EXPECT_EQ(coord.ShouldStop(), false); EXPECT_EQ(coord.ShouldStop(), false);
bool stopped = false; Notification about_to_wait;
Notification done; Notification done;
Env::Default()->SchedClosure( Env::Default()->SchedClosure(
std::bind(&WaitForStopThread, &coord, &stopped, &done)); std::bind(&WaitForStopThread, &coord, &about_to_wait, &done));
Env::Default()->SleepForMicroseconds(10000000); about_to_wait.WaitForNotification();
EXPECT_EQ(stopped, false); Env::Default()->SleepForMicroseconds(1000 * 1000);
EXPECT_FALSE(done.HasBeenNotified());
TF_EXPECT_OK(coord.RequestStop()); TF_EXPECT_OK(coord.RequestStop());
done.WaitForNotification(); done.WaitForNotification();
EXPECT_EQ(stopped, true); EXPECT_TRUE(coord.ShouldStop());
EXPECT_EQ(coord.ShouldStop(), true);
} }
class MockQueueRunner : public RunnerInterface { class MockQueueRunner : public RunnerInterface {
@ -66,14 +67,16 @@ class MockQueueRunner : public RunnerInterface {
join_counter_ = join_counter; join_counter_ = join_counter;
} }
void StartCounting(std::atomic<int>* counter, int until) { void StartCounting(std::atomic<int>* counter, int until,
Notification* start = nullptr) {
thread_pool_->Schedule( thread_pool_->Schedule(
std::bind(&MockQueueRunner::CountThread, this, counter, until)); std::bind(&MockQueueRunner::CountThread, this, counter, until, start));
} }
void StartSettingStatus(const Status& status, BlockingCounter* counter) { void StartSettingStatus(const Status& status, BlockingCounter* counter,
thread_pool_->Schedule( Notification* start) {
std::bind(&MockQueueRunner::SetStatusThread, this, status, counter)); thread_pool_->Schedule(std::bind(&MockQueueRunner::SetStatusThread, this,
status, counter, start));
} }
Status Join() { Status Join() {
@ -93,15 +96,17 @@ class MockQueueRunner : public RunnerInterface {
void Stop() { stopped_ = true; } void Stop() { stopped_ = true; }
private: private:
void CountThread(std::atomic<int>* counter, int until) { void CountThread(std::atomic<int>* counter, int until, Notification* start) {
if (start != nullptr) start->WaitForNotification();
while (!coord_->ShouldStop() && counter->load() < until) { while (!coord_->ShouldStop() && counter->load() < until) {
(*counter)++; (*counter)++;
Env::Default()->SleepForMicroseconds(100000); Env::Default()->SleepForMicroseconds(10 * 1000);
} }
coord_->RequestStop().IgnoreError(); coord_->RequestStop().IgnoreError();
} }
void SetStatusThread(const Status& status, BlockingCounter* counter) { void SetStatusThread(const Status& status, BlockingCounter* counter,
Env::Default()->SleepForMicroseconds(100000); Notification* start) {
start->WaitForNotification();
SetStatus(status); SetStatus(status);
counter->DecrementCount(); counter->DecrementCount();
} }
@ -130,7 +135,7 @@ TEST(CoordinatorTest, TestRealStop) {
TF_EXPECT_OK(coord.RequestStop()); TF_EXPECT_OK(coord.RequestStop());
int temp_counter = counter.load(); int temp_counter = counter.load();
Env::Default()->SleepForMicroseconds(10000000); Env::Default()->SleepForMicroseconds(1000 * 1000);
EXPECT_EQ(temp_counter, counter.load()); EXPECT_EQ(temp_counter, counter.load());
TF_EXPECT_OK(coord.Join()); TF_EXPECT_OK(coord.Join());
} }
@ -138,12 +143,14 @@ TEST(CoordinatorTest, TestRealStop) {
TEST(CoordinatorTest, TestRequestStop) { TEST(CoordinatorTest, TestRequestStop) {
Coordinator coord; Coordinator coord;
std::atomic<int> counter(0); std::atomic<int> counter(0);
Notification start;
std::unique_ptr<MockQueueRunner> qr; std::unique_ptr<MockQueueRunner> qr;
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
qr.reset(new MockQueueRunner(&coord)); qr.reset(new MockQueueRunner(&coord));
qr->StartCounting(&counter, 10); qr->StartCounting(&counter, 10, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr))); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr)));
} }
start.Notify();
coord.WaitForStop(); coord.WaitForStop();
EXPECT_EQ(coord.ShouldStop(), true); EXPECT_EQ(coord.ShouldStop(), true);
@ -168,20 +175,22 @@ TEST(CoordinatorTest, TestJoin) {
TEST(CoordinatorTest, StatusReporting) { TEST(CoordinatorTest, StatusReporting) {
Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE}); Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});
Notification start;
BlockingCounter counter(3); BlockingCounter counter(3);
std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord)); std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter); qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1))); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1)));
std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord)); std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter); qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2))); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2)));
std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord)); std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord));
qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter); qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr3))); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr3)));
start.Notify();
counter.Wait(); counter.Wait();
TF_EXPECT_OK(coord.RequestStop()); TF_EXPECT_OK(coord.RequestStop());
EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT); EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);