fix race condition in profiler_api_test.py.
PiperOrigin-RevId: 339559967 Change-Id: I9a3f7fd0d8117bc8ae835f549ed8cd057e7d1203
This commit is contained in:
parent
f9818f12c3
commit
ad891c3a24
@ -59,6 +59,11 @@ def _make_temp_log_dir(test_obj):
|
|||||||
|
|
||||||
class ProfilerApiTest(test_util.TensorFlowTestCase):
|
class ProfilerApiTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.worker_start = threading.Event()
|
||||||
|
self.profile_done = False
|
||||||
|
|
||||||
def _check_tools_pb_exist(self, logdir):
|
def _check_tools_pb_exist(self, logdir):
|
||||||
expected_files = [
|
expected_files = [
|
||||||
'overview_page.pb',
|
'overview_page.pb',
|
||||||
@ -86,16 +91,21 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
|
|||||||
def test_single_worker_sampling_mode(self, delay_ms=None):
|
def test_single_worker_sampling_mode(self, delay_ms=None):
|
||||||
"""Test single worker sampling mode."""
|
"""Test single worker sampling mode."""
|
||||||
|
|
||||||
def on_worker(port):
|
def on_worker(port, worker_start):
|
||||||
logging.info('worker starting server on {}'.format(port))
|
logging.info('worker starting server on {}'.format(port))
|
||||||
profiler.start_server(port)
|
profiler.start_server(port)
|
||||||
_, steps, train_ds, model = _model_setup()
|
_, steps, train_ds, model = _model_setup()
|
||||||
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
worker_start.set()
|
||||||
|
while True:
|
||||||
|
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
||||||
|
if self.profile_done:
|
||||||
|
break
|
||||||
|
|
||||||
def on_profile(port, logdir):
|
def on_profile(port, logdir, worker_start):
|
||||||
# Request for 30 milliseconds of profile.
|
# Request for 30 milliseconds of profile.
|
||||||
duration_ms = 30
|
duration_ms = 30
|
||||||
|
|
||||||
|
worker_start.wait()
|
||||||
options = profiler.ProfilerOptions(
|
options = profiler.ProfilerOptions(
|
||||||
host_tracer_level=2,
|
host_tracer_level=2,
|
||||||
python_tracer_level=0,
|
python_tracer_level=0,
|
||||||
@ -106,16 +116,29 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
|
|||||||
profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms,
|
profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms,
|
||||||
'', 100, options)
|
'', 100, options)
|
||||||
|
|
||||||
|
self.profile_done = True
|
||||||
|
|
||||||
logdir = self.get_temp_dir()
|
logdir = self.get_temp_dir()
|
||||||
port = portpicker.pick_unused_port()
|
port = portpicker.pick_unused_port()
|
||||||
thread_profiler = threading.Thread(target=on_profile, args=(port, logdir))
|
thread_profiler = threading.Thread(
|
||||||
thread_worker = threading.Thread(target=on_worker, args=(port,))
|
target=on_profile, args=(port, logdir, self.worker_start))
|
||||||
|
thread_worker = threading.Thread(
|
||||||
|
target=on_worker, args=(port, self.worker_start))
|
||||||
thread_worker.start()
|
thread_worker.start()
|
||||||
thread_profiler.start()
|
thread_profiler.start()
|
||||||
thread_profiler.join()
|
thread_profiler.join()
|
||||||
thread_worker.join(120)
|
thread_worker.join(120)
|
||||||
self._check_xspace_pb_exist(logdir)
|
self._check_xspace_pb_exist(logdir)
|
||||||
|
|
||||||
|
def test_single_worker_sampling_mode_short_delay(self):
|
||||||
|
"""Test single worker sampling mode with a short delay.
|
||||||
|
|
||||||
|
Expect that requested delayed start time will arrive late, and a subsequent
|
||||||
|
retry will issue an immediate start.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.test_single_worker_sampling_mode(delay_ms=1)
|
||||||
|
|
||||||
def test_single_worker_sampling_mode_long_delay(self):
|
def test_single_worker_sampling_mode_long_delay(self):
|
||||||
"""Test single worker sampling mode with a long delay."""
|
"""Test single worker sampling mode with a long delay."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user