Destroy batch_thread_pool_ first. For the new in-flight batches implementation, the callbacks scheduled on this thread pool refer to other class members, which must remain alive until the thread pool is empty.

Also fix a similar lifetime issue in the unit tests.

PiperOrigin-RevId: 179726389
This commit is contained in:
A. Unique TensorFlower 2017-12-20 13:31:06 -08:00 committed by TensorFlower Gardener
parent 47249f349d
commit 3d01a46171
2 changed files with 14 additions and 10 deletions

View File

@ -93,6 +93,11 @@ class AdaptiveSharedBatchScheduler
: public std::enable_shared_from_this< : public std::enable_shared_from_this<
AdaptiveSharedBatchScheduler<TaskType>> { AdaptiveSharedBatchScheduler<TaskType>> {
public: public:
~AdaptiveSharedBatchScheduler() {
  // Finish processing batches before destroying other class members.
batch_thread_pool_.reset();
}
struct Options { struct Options {
// The name to use for the pool of batch threads. // The name to use for the pool of batch threads.
string thread_pool_name = {"batch_threads"}; string thread_pool_name = {"batch_threads"};

View File

@ -450,10 +450,6 @@ TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesImplementation) {
options.use_in_flight_batches_implementation = true; options.use_in_flight_batches_implementation = true;
options.initial_in_flight_batches_limit = 2; options.initial_in_flight_batches_limit = 2;
options.batches_to_average_over = 1000; options.batches_to_average_over = 1000;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
mutex mu; mutex mu;
int processed_batches = 0; int processed_batches = 0;
Notification finish_processing; Notification finish_processing;
@ -474,7 +470,10 @@ TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesImplementation) {
} }
finish_processing.WaitForNotification(); finish_processing.WaitForNotification();
}; };
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
// Enqueue 3 batches. // Enqueue 3 batches.
@ -494,10 +493,6 @@ TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesLimitTuning) {
options.use_in_flight_batches_implementation = true; options.use_in_flight_batches_implementation = true;
options.initial_in_flight_batches_limit = 2; options.initial_in_flight_batches_limit = 2;
options.batches_to_average_over = 1; options.batches_to_average_over = 1;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
auto queue_callback = [&env](std::unique_ptr<Batch<FakeTask>> batch) { auto queue_callback = [&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed()); ASSERT_TRUE(batch->IsClosed());
switch (batch->size()) { switch (batch->size()) {
@ -515,8 +510,12 @@ TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesLimitTuning) {
break; break;
} }
}; };
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
TF_ASSERT_OK(ScheduleTask(0, queue.get())); TF_ASSERT_OK(ScheduleTask(0, queue.get()));
double in_flight_batches_limit = 2; double in_flight_batches_limit = 2;
while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) {