diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index f0f3766afe1..74d80b63e12 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1794,7 +1794,7 @@ py_test( name = "multi_process_runner_test", srcs = ["multi_process_runner_test.py"], python_version = "PY3", - tags = ["notsan"], # TODO(b/158874970) + shard_count = 12, deps = [ ":multi_process_runner", ":multi_worker_test_base", diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index af527b67b4b..cb460c8fff5 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -423,6 +423,18 @@ class MultiProcessRunner(object): def join(self, timeout=_DEFAULT_TIMEOUT_SEC): """Joins all the processes with timeout. + If any of the subprocesses does not exit within approximately `timeout` + seconds after the `join` call, this raises a + `SubprocessTimeoutError`. + + Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order to + log the stack traces of the subprocesses when they exit. However, this + results in a timeout when the test runs with tsan (thread sanitizer); if tsan + is being run on the test targets that rely on timeout to assert information, + `MultiProcessRunner.terminate_all()` must be called after `join()`, before + the test exits, so the subprocesses are terminated with SIGKILL, and the data + race is removed. + Args: timeout: if set and not all processes report status within roughly `timeout` seconds, a `SubprocessTimeoutError` exception will be raised. 
diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index 6194ac527d5..529d7fd91a5 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -124,24 +124,6 @@ class MultiProcessRunnerTest(test.TestCase): std_stream_results) self.assertIn('This is returned data.', return_value) - def test_process_that_exits(self): - - def func_to_exit_in_25_sec(): - logging.error('foo') - time.sleep(100) - logging.error('bar') - - mpr = multi_process_runner.MultiProcessRunner( - func_to_exit_in_25_sec, - multi_worker_test_base.create_cluster_spec(num_workers=1), - list_stdout=True, - max_run_time=25) - - mpr.start() - stdout = mpr.join().stdout - self.assertLen([msg for msg in stdout if 'foo' in msg], 1) - self.assertLen([msg for msg in stdout if 'bar' in msg], 0) - def test_termination(self): def proc_func(): @@ -301,29 +283,21 @@ class MultiProcessRunnerTest(test.TestCase): def test_stdout_available_when_timeout(self): def proc_func(): - for i in range(50): - logging.info('(logging) %s-%d, i: %d', - multi_worker_test_base.get_task_type(), self._worker_idx(), - i) - time.sleep(1) + logging.info('something printed') + time.sleep(10000) # Intentionally make the test timeout. with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm: - multi_process_runner.run( + mpr = multi_process_runner.MultiProcessRunner( proc_func, - multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1), - list_stdout=True, - timeout=5) + multi_worker_test_base.create_cluster_spec(num_workers=1), + list_stdout=True) + mpr.start() + mpr.join(timeout=60) + mpr.terminate_all() list_to_assert = cm.exception.mpr_result.stdout - # We should see 5 iterations from worker and ps, however sometime on TAP - # due to CPU throttling and slugginess of msan/asan build, this became - # flaky. 
Therefore we allow more margin of errors to only check the first - # 3 iterations. - for job in ['worker', 'ps']: - for iteration in range(0, 3): - self.assertTrue( - any('(logging) {}-0, i: {}'.format(job, iteration) in line - for line in list_to_assert)) + self.assertTrue( + any('something printed' in line for line in list_to_assert)) def test_seg_fault_raises_error(self):