fix race condition in profiler_api_test.py.
PiperOrigin-RevId: 339559967 Change-Id: I9a3f7fd0d8117bc8ae835f549ed8cd057e7d1203
This commit is contained in:
parent
f9818f12c3
commit
ad891c3a24
@ -59,6 +59,11 @@ def _make_temp_log_dir(test_obj):
|
||||
|
||||
class ProfilerApiTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.worker_start = threading.Event()
|
||||
self.profile_done = False
|
||||
|
||||
def _check_tools_pb_exist(self, logdir):
|
||||
expected_files = [
|
||||
'overview_page.pb',
|
||||
@ -86,16 +91,21 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
|
||||
def test_single_worker_sampling_mode(self, delay_ms=None):
|
||||
"""Test single worker sampling mode."""
|
||||
|
||||
def on_worker(port):
|
||||
def on_worker(port, worker_start):
|
||||
logging.info('worker starting server on {}'.format(port))
|
||||
profiler.start_server(port)
|
||||
_, steps, train_ds, model = _model_setup()
|
||||
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
||||
worker_start.set()
|
||||
while True:
|
||||
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
||||
if self.profile_done:
|
||||
break
|
||||
|
||||
def on_profile(port, logdir):
|
||||
def on_profile(port, logdir, worker_start):
|
||||
# Request for 30 milliseconds of profile.
|
||||
duration_ms = 30
|
||||
|
||||
worker_start.wait()
|
||||
options = profiler.ProfilerOptions(
|
||||
host_tracer_level=2,
|
||||
python_tracer_level=0,
|
||||
@ -106,16 +116,29 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
|
||||
profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms,
|
||||
'', 100, options)
|
||||
|
||||
self.profile_done = True
|
||||
|
||||
logdir = self.get_temp_dir()
|
||||
port = portpicker.pick_unused_port()
|
||||
thread_profiler = threading.Thread(target=on_profile, args=(port, logdir))
|
||||
thread_worker = threading.Thread(target=on_worker, args=(port,))
|
||||
thread_profiler = threading.Thread(
|
||||
target=on_profile, args=(port, logdir, self.worker_start))
|
||||
thread_worker = threading.Thread(
|
||||
target=on_worker, args=(port, self.worker_start))
|
||||
thread_worker.start()
|
||||
thread_profiler.start()
|
||||
thread_profiler.join()
|
||||
thread_worker.join(120)
|
||||
self._check_xspace_pb_exist(logdir)
|
||||
|
||||
def test_single_worker_sampling_mode_short_delay(self):
|
||||
"""Test single worker sampling mode with a short delay.
|
||||
|
||||
Expect that requested delayed start time will arrive late, and a subsequent
|
||||
retry will issue an immediate start.
|
||||
"""
|
||||
|
||||
self.test_single_worker_sampling_mode(delay_ms=1)
|
||||
|
||||
def test_single_worker_sampling_mode_long_delay(self):
|
||||
"""Test single worker sampling mode with a long delay."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user