diff --git a/tensorflow/python/profiler/integration_test/profiler_api_test.py b/tensorflow/python/profiler/integration_test/profiler_api_test.py index e7785dbc697..b19f0595583 100644 --- a/tensorflow/python/profiler/integration_test/profiler_api_test.py +++ b/tensorflow/python/profiler/integration_test/profiler_api_test.py @@ -86,23 +86,29 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): profiler.start_server(port) _, steps, train_ds, model = _model_setup() model.fit(x=train_ds, epochs=2, steps_per_epoch=steps) + logging.info('worker finishing') + + def on_profile(port, logdir): + # Request for 30 milliseconds of profile. + duration_ms = 30 + + options = profiler.ProfilerOptions( + host_tracer_level=2, + python_tracer_level=0, + device_tracer_level=1, + ) + + profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms, + '', 100, options) - port = portpicker.pick_unused_port() - thread = threading.Thread(target=on_worker, args=(port,)) - thread.start() - # Request for 3 seconds of profile. - duration_ms = 3000 logdir = self.get_temp_dir() - - options = profiler.ProfilerOptions( - host_tracer_level=2, - python_tracer_level=0, - device_tracer_level=1, - ) - - profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms, '', - 3, options) - thread.join(30) + port = portpicker.pick_unused_port() + thread_profiler = threading.Thread(target=on_profile, args=(port, logdir)) + thread_worker = threading.Thread(target=on_worker, args=(port,)) + thread_worker.start() + thread_profiler.start() + thread_profiler.join() + thread_worker.join(120) self._check_tools_pb_exist(logdir) def test_single_worker_programmatic_mode(self):