Make sampling mode test less likely to suffer from race where training completes before profiling starts.

PiperOrigin-RevId: 336569182
Change-Id: Ic7381f0442cc0fe000d68d86afdb40f7d1e3012c
This commit is contained in:
Yi Situ 2020-10-11 15:48:20 -07:00 committed by TensorFlower Gardener
parent fe094e9679
commit 98400c0a6a

View File

@ -86,23 +86,29 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
profiler.start_server(port)
_, steps, train_ds, model = _model_setup()
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
logging.info('worker finishing')
def on_profile(port, logdir):
# Request for 30 milliseconds of profile.
duration_ms = 30
options = profiler.ProfilerOptions(
host_tracer_level=2,
python_tracer_level=0,
device_tracer_level=1,
)
profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms,
'', 100, options)
port = portpicker.pick_unused_port()
thread = threading.Thread(target=on_worker, args=(port,))
thread.start()
# Request for 3 seconds of profile.
duration_ms = 3000
logdir = self.get_temp_dir()
options = profiler.ProfilerOptions(
host_tracer_level=2,
python_tracer_level=0,
device_tracer_level=1,
)
profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms, '',
3, options)
thread.join(30)
port = portpicker.pick_unused_port()
thread_profiler = threading.Thread(target=on_profile, args=(port, logdir))
thread_worker = threading.Thread(target=on_worker, args=(port,))
thread_worker.start()
thread_profiler.start()
thread_profiler.join()
thread_worker.join(120)
self._check_tools_pb_exist(logdir)
def test_single_worker_programmatic_mode(self):