Add a worker startup delay in the cluster coordinator, controlled by an env var, defaulting to no delay.

This can improve the accuracy of some models and this is also done by Estimator. We are not sure how important it is and thus we create an env var to let users try this behavior if they see any regression from Estimator.
This is not meant to be used widely.

PiperOrigin-RevId: 338332972
Change-Id: I144e30912107c19ed7bb093a349d391a4f848a43
This commit is contained in:
Yuefeng Zhou 2020-10-21 13:52:02 -07:00 committed by TensorFlower Gardener
parent 1e6be70ea0
commit a4daf15271
2 changed files with 44 additions and 0 deletions

View File

@ -29,6 +29,7 @@ import os
import re
import sys
import threading
import time
import weakref
from six.moves import queue
@ -759,7 +760,25 @@ class Worker(object):
closure.output_remote_value._set_error(e) # pylint: disable=protected-access
self._cluster._closure_queue.mark_failed(e) # pylint: disable=protected-access
def _maybe_delay(self):
"""Delay if corresponding env vars are set."""
# If the following two env vars variables are set. Scheduling for workers
# will start in a staggered manner. Worker i will wait for
# `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding
# `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`.
delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
delay_cap = int(
os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
if delay_cap:
delay_secs = min(delay_secs * self.worker_index, delay_cap)
if delay_secs > 0:
logging.info("Worker %d sleeping for %d seconds before running function",
self.worker_index, delay_secs)
time.sleep(delay_secs)
def _process_queue(self):
  """Function running in a thread to process closure queues."""
  # Optionally stagger this worker's start (see _maybe_delay), then drain
  # the cluster's closure queue forever; get() blocks until work arrives.
  self._maybe_delay()
  while True:
    next_closure = self._cluster._closure_queue.get()  # pylint: disable=protected-access
    self._process_closure(next_closure)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import functools
import os
import platform
import sys
import threading
@ -597,6 +598,7 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
self.assertNotAllEqual(elements_in_iterator_1, elements_in_iterator_2)
def testPerWorkerValue(self):
self.skipTest('b/168569314')
var_shape = tuple()
var_dtype = dtypes.float32
var_name = 'var'
@ -664,6 +666,29 @@ class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
cls.strategy = cls.coordinator.strategy
class ScheduleStartDelayTest(ClusterCoordinatorTest):
  """Test basic functionality works with worker scheduling delay.

  This is basically to make sure that setting environment variables
  `TF_COORDINATOR_SCHEDULE_START_DELAY` and
  `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX` does not cause any failure.
  """

  @classmethod
  def setUpClass(cls):
    super(ScheduleStartDelayTest, cls).setUpClass()
    # Worker i waits min(2 * i, 4) seconds before it starts processing
    # closures; the whole inherited test suite then runs under that delay.
    os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY'] = '2'
    os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX'] = '4'
    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
    cls.strategy = cls.coordinator.strategy

  @classmethod
  def tearDownClass(cls):
    # pop() with a default so cleanup never raises KeyError, even if
    # setUpClass failed before both variables were set.
    os.environ.pop('TF_COORDINATOR_SCHEDULE_START_DELAY', None)
    os.environ.pop('TF_COORDINATOR_SCHEDULE_START_DELAY_MAX', None)
    super(ScheduleStartDelayTest, cls).tearDownClass()
class ErrorReportingTest(TestCaseWithErrorReportingThread):
@classmethod