Fix tsan failure in multi_process_runner_test.
PiperOrigin-RevId: 317747609 Change-Id: I8bf2e493431a69a0cf581012666045df2879055e
This commit is contained in:
parent
1c12a84e20
commit
5f9ae657ed
tensorflow/python/distribute
@ -1794,7 +1794,7 @@ py_test(
|
|||||||
name = "multi_process_runner_test",
|
name = "multi_process_runner_test",
|
||||||
srcs = ["multi_process_runner_test.py"],
|
srcs = ["multi_process_runner_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = ["notsan"], # TODO(b/158874970)
|
shard_count = 12,
|
||||||
deps = [
|
deps = [
|
||||||
":multi_process_runner",
|
":multi_process_runner",
|
||||||
":multi_worker_test_base",
|
":multi_worker_test_base",
|
||||||
|
@ -423,6 +423,18 @@ class MultiProcessRunner(object):
|
|||||||
def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
|
def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
|
||||||
"""Joins all the processes with timeout.
|
"""Joins all the processes with timeout.
|
||||||
|
|
||||||
|
If any of the subprocesses does not exit approximately after `timeout`
|
||||||
|
seconds has passed after `join` call, this raises a
|
||||||
|
`SubprocessTimeoutError`.
|
||||||
|
|
||||||
|
Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order to
|
||||||
|
log the stack traces of the subprocesses when they exit. However, this
|
||||||
|
results in timeout when the test runs with tsan (thread sanitizer); if tsan
|
||||||
|
is being run on the test targets that rely on timeout to assert information,
|
||||||
|
`MultiProcessRunner.terminate_all()` must be called after `join()`, before
|
||||||
|
the test exits, so the subprocesses are terminated with SIGKILL, and data
|
||||||
|
race is removed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timeout: if set and not all processes report status within roughly
|
timeout: if set and not all processes report status within roughly
|
||||||
`timeout` seconds, a `SubprocessTimeoutError` exception will be raised.
|
`timeout` seconds, a `SubprocessTimeoutError` exception will be raised.
|
||||||
|
@ -124,24 +124,6 @@ class MultiProcessRunnerTest(test.TestCase):
|
|||||||
std_stream_results)
|
std_stream_results)
|
||||||
self.assertIn('This is returned data.', return_value)
|
self.assertIn('This is returned data.', return_value)
|
||||||
|
|
||||||
def test_process_that_exits(self):
|
|
||||||
|
|
||||||
def func_to_exit_in_25_sec():
|
|
||||||
logging.error('foo')
|
|
||||||
time.sleep(100)
|
|
||||||
logging.error('bar')
|
|
||||||
|
|
||||||
mpr = multi_process_runner.MultiProcessRunner(
|
|
||||||
func_to_exit_in_25_sec,
|
|
||||||
multi_worker_test_base.create_cluster_spec(num_workers=1),
|
|
||||||
list_stdout=True,
|
|
||||||
max_run_time=25)
|
|
||||||
|
|
||||||
mpr.start()
|
|
||||||
stdout = mpr.join().stdout
|
|
||||||
self.assertLen([msg for msg in stdout if 'foo' in msg], 1)
|
|
||||||
self.assertLen([msg for msg in stdout if 'bar' in msg], 0)
|
|
||||||
|
|
||||||
def test_termination(self):
|
def test_termination(self):
|
||||||
|
|
||||||
def proc_func():
|
def proc_func():
|
||||||
@ -301,29 +283,21 @@ class MultiProcessRunnerTest(test.TestCase):
|
|||||||
def test_stdout_available_when_timeout(self):
|
def test_stdout_available_when_timeout(self):
|
||||||
|
|
||||||
def proc_func():
|
def proc_func():
|
||||||
for i in range(50):
|
logging.info('something printed')
|
||||||
logging.info('(logging) %s-%d, i: %d',
|
time.sleep(10000) # Intentionally make the test timeout.
|
||||||
multi_worker_test_base.get_task_type(), self._worker_idx(),
|
|
||||||
i)
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm:
|
with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm:
|
||||||
multi_process_runner.run(
|
mpr = multi_process_runner.MultiProcessRunner(
|
||||||
proc_func,
|
proc_func,
|
||||||
multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
|
multi_worker_test_base.create_cluster_spec(num_workers=1),
|
||||||
list_stdout=True,
|
list_stdout=True)
|
||||||
timeout=5)
|
mpr.start()
|
||||||
|
mpr.join(timeout=60)
|
||||||
|
mpr.terminate_all()
|
||||||
|
|
||||||
list_to_assert = cm.exception.mpr_result.stdout
|
list_to_assert = cm.exception.mpr_result.stdout
|
||||||
# We should see 5 iterations from worker and ps, however sometime on TAP
|
self.assertTrue(
|
||||||
# due to CPU throttling and slugginess of msan/asan build, this became
|
any('something printed' in line for line in list_to_assert))
|
||||||
# flaky. Therefore we allow more margin of errors to only check the first
|
|
||||||
# 3 iterations.
|
|
||||||
for job in ['worker', 'ps']:
|
|
||||||
for iteration in range(0, 3):
|
|
||||||
self.assertTrue(
|
|
||||||
any('(logging) {}-0, i: {}'.format(job, iteration) in line
|
|
||||||
for line in list_to_assert))
|
|
||||||
|
|
||||||
def test_seg_fault_raises_error(self):
|
def test_seg_fault_raises_error(self):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user