fix race condition in profiler_api_test.py.

PiperOrigin-RevId: 339559967
Change-Id: I9a3f7fd0d8117bc8ae835f549ed8cd057e7d1203
This commit is contained in:
A. Unique TensorFlower 2020-10-28 16:20:09 -07:00 committed by TensorFlower Gardener
parent f9818f12c3
commit ad891c3a24

View File

@ -59,6 +59,11 @@ def _make_temp_log_dir(test_obj):
class ProfilerApiTest(test_util.TensorFlowTestCase):
def setUp(self):
super().setUp()
self.worker_start = threading.Event()
self.profile_done = False
def _check_tools_pb_exist(self, logdir):
expected_files = [
'overview_page.pb',
@ -86,16 +91,21 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
def test_single_worker_sampling_mode(self, delay_ms=None):
"""Test single worker sampling mode."""
def on_worker(port):
def on_worker(port, worker_start):
logging.info('worker starting server on {}'.format(port))
profiler.start_server(port)
_, steps, train_ds, model = _model_setup()
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
worker_start.set()
while True:
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
if self.profile_done:
break
def on_profile(port, logdir):
def on_profile(port, logdir, worker_start):
# Request for 30 milliseconds of profile.
duration_ms = 30
worker_start.wait()
options = profiler.ProfilerOptions(
host_tracer_level=2,
python_tracer_level=0,
@ -106,16 +116,29 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms,
'', 100, options)
self.profile_done = True
logdir = self.get_temp_dir()
port = portpicker.pick_unused_port()
thread_profiler = threading.Thread(target=on_profile, args=(port, logdir))
thread_worker = threading.Thread(target=on_worker, args=(port,))
thread_profiler = threading.Thread(
target=on_profile, args=(port, logdir, self.worker_start))
thread_worker = threading.Thread(
target=on_worker, args=(port, self.worker_start))
thread_worker.start()
thread_profiler.start()
thread_profiler.join()
thread_worker.join(120)
self._check_xspace_pb_exist(logdir)
def test_single_worker_sampling_mode_short_delay(self):
"""Test single worker sampling mode with a short delay.
Expect that requested delayed start time will arrive late, and a subsequent
retry will issue an immediate start.
"""
self.test_single_worker_sampling_mode(delay_ms=1)
def test_single_worker_sampling_mode_long_delay(self):
"""Test single worker sampling mode with a long delay."""