Make sampling mode test less likely to suffer from race where training completes before profiling starts.
PiperOrigin-RevId: 336569182 Change-Id: Ic7381f0442cc0fe000d68d86afdb40f7d1e3012c
This commit is contained in:
parent
fe094e9679
commit
98400c0a6a
@ -86,23 +86,29 @@ class ProfilerApiTest(test_util.TensorFlowTestCase):
|
||||
profiler.start_server(port)
|
||||
_, steps, train_ds, model = _model_setup()
|
||||
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
||||
logging.info('worker finishing')
|
||||
|
||||
def on_profile(port, logdir):
|
||||
# Request for 30 milliseconds of profile.
|
||||
duration_ms = 30
|
||||
|
||||
options = profiler.ProfilerOptions(
|
||||
host_tracer_level=2,
|
||||
python_tracer_level=0,
|
||||
device_tracer_level=1,
|
||||
)
|
||||
|
||||
profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms,
|
||||
'', 100, options)
|
||||
|
||||
port = portpicker.pick_unused_port()
|
||||
thread = threading.Thread(target=on_worker, args=(port,))
|
||||
thread.start()
|
||||
# Request for 3 seconds of profile.
|
||||
duration_ms = 3000
|
||||
logdir = self.get_temp_dir()
|
||||
|
||||
options = profiler.ProfilerOptions(
|
||||
host_tracer_level=2,
|
||||
python_tracer_level=0,
|
||||
device_tracer_level=1,
|
||||
)
|
||||
|
||||
profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms, '',
|
||||
3, options)
|
||||
thread.join(30)
|
||||
port = portpicker.pick_unused_port()
|
||||
thread_profiler = threading.Thread(target=on_profile, args=(port, logdir))
|
||||
thread_worker = threading.Thread(target=on_worker, args=(port,))
|
||||
thread_worker.start()
|
||||
thread_profiler.start()
|
||||
thread_profiler.join()
|
||||
thread_worker.join(120)
|
||||
self._check_tools_pb_exist(logdir)
|
||||
|
||||
def test_single_worker_programmatic_mode(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user