diff --git a/tensorflow/python/profiler/integration_test/profiler_api_test.py b/tensorflow/python/profiler/integration_test/profiler_api_test.py index 1d79b660ba5..3a603e7ae71 100644 --- a/tensorflow/python/profiler/integration_test/profiler_api_test.py +++ b/tensorflow/python/profiler/integration_test/profiler_api_test.py @@ -59,6 +59,11 @@ def _make_temp_log_dir(test_obj): class ProfilerApiTest(test_util.TensorFlowTestCase): + def setUp(self): + super().setUp() + self.worker_start = threading.Event() + self.profile_done = False + def _check_tools_pb_exist(self, logdir): expected_files = [ 'overview_page.pb', @@ -86,16 +91,21 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): def test_single_worker_sampling_mode(self, delay_ms=None): """Test single worker sampling mode.""" - def on_worker(port): + def on_worker(port, worker_start): logging.info('worker starting server on {}'.format(port)) profiler.start_server(port) _, steps, train_ds, model = _model_setup() - model.fit(x=train_ds, epochs=2, steps_per_epoch=steps) + worker_start.set() + while True: + model.fit(x=train_ds, epochs=2, steps_per_epoch=steps) + if self.profile_done: + break - def on_profile(port, logdir): + def on_profile(port, logdir, worker_start): # Request for 30 milliseconds of profile. duration_ms = 30 + worker_start.wait() options = profiler.ProfilerOptions( host_tracer_level=2, python_tracer_level=0, @@ -106,16 +116,29 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms, '', 100, options) + self.profile_done = True + logdir = self.get_temp_dir() port = portpicker.pick_unused_port() - thread_profiler = threading.Thread(target=on_profile, args=(port, logdir)) - thread_worker = threading.Thread(target=on_worker, args=(port,)) + thread_profiler = threading.Thread( + target=on_profile, args=(port, logdir, self.worker_start)) + thread_worker = threading.Thread( + target=on_worker, args=(port, self.worker_start)) thread_worker.start() thread_profiler.start() thread_profiler.join() thread_worker.join(120) self._check_xspace_pb_exist(logdir) + def test_single_worker_sampling_mode_short_delay(self): + """Test single worker sampling mode with a short delay. + + Expect that requested delayed start time will arrive late, and a subsequent + retry will issue an immediate start. + """ + + self.test_single_worker_sampling_mode(delay_ms=1) + def test_single_worker_sampling_mode_long_delay(self): """Test single worker sampling mode with a long delay."""