diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc index 79f2a955d51..a87913deafe 100644 --- a/tensorflow/cc/training/coordinator_test.cc +++ b/tensorflow/cc/training/coordinator_test.cc @@ -29,9 +29,10 @@ namespace { using error::Code; -void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) { +void WaitForStopThread(Coordinator* coord, Notification* about_to_wait, + Notification* done) { + about_to_wait->Notify(); coord->WaitForStop(); - *stopped = true; done->Notify(); } @@ -39,17 +40,17 @@ TEST(CoordinatorTest, TestStopAndWaitOnStop) { Coordinator coord; EXPECT_EQ(coord.ShouldStop(), false); - bool stopped = false; + Notification about_to_wait; Notification done; Env::Default()->SchedClosure( - std::bind(&WaitForStopThread, &coord, &stopped, &done)); - Env::Default()->SleepForMicroseconds(10000000); - EXPECT_EQ(stopped, false); + std::bind(&WaitForStopThread, &coord, &about_to_wait, &done)); + about_to_wait.WaitForNotification(); + Env::Default()->SleepForMicroseconds(1000 * 1000); + EXPECT_FALSE(done.HasBeenNotified()); TF_EXPECT_OK(coord.RequestStop()); done.WaitForNotification(); - EXPECT_EQ(stopped, true); - EXPECT_EQ(coord.ShouldStop(), true); + EXPECT_TRUE(coord.ShouldStop()); } class MockQueueRunner : public RunnerInterface { @@ -66,14 +67,16 @@ class MockQueueRunner : public RunnerInterface { join_counter_ = join_counter; } - void StartCounting(std::atomic<int>* counter, int until) { + void StartCounting(std::atomic<int>* counter, int until, + Notification* start = nullptr) { thread_pool_->Schedule( - std::bind(&MockQueueRunner::CountThread, this, counter, until)); + std::bind(&MockQueueRunner::CountThread, this, counter, until, start)); } - void StartSettingStatus(const Status& status, BlockingCounter* counter) { - thread_pool_->Schedule( - std::bind(&MockQueueRunner::SetStatusThread, this, status, counter)); + void StartSettingStatus(const Status& status, BlockingCounter* counter, + Notification* start) { + thread_pool_->Schedule(std::bind(&MockQueueRunner::SetStatusThread, this, + status, counter, start)); } Status Join() { @@ -93,15 +96,17 @@ class MockQueueRunner : public RunnerInterface { void Stop() { stopped_ = true; } private: - void CountThread(std::atomic<int>* counter, int until) { + void CountThread(std::atomic<int>* counter, int until, Notification* start) { + if (start != nullptr) start->WaitForNotification(); while (!coord_->ShouldStop() && counter->load() < until) { (*counter)++; - Env::Default()->SleepForMicroseconds(100000); + Env::Default()->SleepForMicroseconds(10 * 1000); } coord_->RequestStop().IgnoreError(); } - void SetStatusThread(const Status& status, BlockingCounter* counter) { - Env::Default()->SleepForMicroseconds(100000); + void SetStatusThread(const Status& status, BlockingCounter* counter, + Notification* start) { + start->WaitForNotification(); SetStatus(status); counter->DecrementCount(); } @@ -130,7 +135,7 @@ TEST(CoordinatorTest, TestRealStop) { TF_EXPECT_OK(coord.RequestStop()); int temp_counter = counter.load(); - Env::Default()->SleepForMicroseconds(10000000); + Env::Default()->SleepForMicroseconds(1000 * 1000); EXPECT_EQ(temp_counter, counter.load()); TF_EXPECT_OK(coord.Join()); } @@ -138,12 +143,14 @@ TEST(CoordinatorTest, TestRealStop) { TEST(CoordinatorTest, TestRequestStop) { Coordinator coord; std::atomic<int> counter(0); + Notification start; std::unique_ptr<MockQueueRunner> qr; for (int i = 0; i < 10; i++) { qr.reset(new MockQueueRunner(&coord)); - qr->StartCounting(&counter, 10); + qr->StartCounting(&counter, 10, &start); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr))); } + start.Notify(); coord.WaitForStop(); EXPECT_EQ(coord.ShouldStop(), true); @@ -168,20 +175,22 @@ TEST(CoordinatorTest, TestJoin) { TEST(CoordinatorTest, StatusReporting) { Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE}); + Notification start; BlockingCounter counter(3); std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord)); - qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter); + qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter, &start); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1))); std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord)); - qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter); + qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter, &start); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2))); std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord)); - qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter); + qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter, &start); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr3))); + start.Notify(); counter.Wait(); TF_EXPECT_OK(coord.RequestStop()); EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);