STT-tensorflow/tensorflow/python/distribute/multi_process_runner_test.py
Ran Chen e357988ab3 Re-enable multi process pool runner tests
PiperOrigin-RevId: 339084496
Change-Id: I0085dcc86ddcf8ff237234bc15e50815587442b5
2020-10-26 11:58:41 -07:00

607 lines
19 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `multi_process_runner`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ctypes
import json
import os
import sys
import threading
import time
import unittest
from absl import logging
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test
def fn_that_adds_task_type_in_return_data():
  """Report this subprocess's task type ('worker'/'ps'/'chief') to the runner."""
  task_type = multi_worker_test_base.get_task_type()
  return task_type
def fn_that_errors():
  """Unconditionally raise, so tests can check error propagation to the parent."""
  message = 'This is an error.'
  raise ValueError(message)
def fn_that_does_nothing():
  """No-op task body; used to verify a clean run produces no return data."""
  return None
def fn_that_adds_simple_return_data():
  """Return a fixed marker string for the runner to collect per task."""
  marker = 'dummy_data'
  return marker
def fn_that_returns_args_and_kwargs(*args, **kwargs):
  """Echo positional args followed by (key, value) pairs of keyword args."""
  combined = [arg for arg in args]
  combined.extend(kwargs.items())
  return combined
def fn_with_barrier():
  """Fetch the cross-process barrier from inside a subprocess."""
  barrier = multi_process_runner.get_barrier()
  return barrier
def fn_that_returns_pid():
  """Return the OS pid of the process running the fn, to identify workers."""
  pid = os.getpid()
  return pid
# Module-level slot used by fn_that_sets_global to observe per-process state.
V = None


def fn_that_sets_global(val):
  """Swap the module-global `V` to `val` and return its previous value."""
  global V
  previous = V
  V = val
  return previous
class MultiProcessRunnerTest(test.TestCase):
  """Tests for `multi_process_runner.MultiProcessRunner`.

  Covers return-value collection, stdout capture/streaming, termination and
  restart of individual tasks, error and exit-code propagation from
  subprocesses, auto-restart behavior, and join timeouts.
  """

  def _worker_idx(self):
    """Return this process's task index as read from the TF_CONFIG env var."""
    config_task = json.loads(os.environ['TF_CONFIG'])['task']
    return config_task['index']

  def test_multi_process_runner(self):
    """Every task in the cluster spec runs the fn exactly once."""
    mpr_result = multi_process_runner.run(
        fn_that_adds_task_type_in_return_data,
        multi_worker_test_base.create_cluster_spec(
            num_workers=2, num_ps=3, has_chief=True))

    # Each subprocess returns its task type; decrement per type and expect
    # every counter to reach exactly zero.
    job_count_dict = {'worker': 2, 'ps': 3, 'chief': 1}
    for data in mpr_result.return_value:
      job_count_dict[data] -= 1

    self.assertEqual(job_count_dict['worker'], 0)
    self.assertEqual(job_count_dict['ps'], 0)
    self.assertEqual(job_count_dict['chief'], 0)

  def test_multi_process_runner_error_propagates_from_subprocesses(self):
    """A ValueError raised in a subprocess re-raises from join()."""
    runner = multi_process_runner.MultiProcessRunner(
        fn_that_errors,
        multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
        max_run_time=20)
    runner.start()
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      runner.join()

  def test_multi_process_runner_queue_emptied_between_runs(self):
    """Return values from one run() do not leak into the next run()."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    return_value = multi_process_runner.run(fn_that_adds_simple_return_data,
                                            cluster_spec).return_value
    self.assertTrue(return_value)
    self.assertEqual(return_value[0], 'dummy_data')
    self.assertEqual(return_value[1], 'dummy_data')
    # The second run returns nothing, so stale data would show up here.
    return_value = multi_process_runner.run(fn_that_does_nothing,
                                            cluster_spec).return_value
    self.assertFalse(return_value)

  def test_multi_process_runner_args_passed_correctly(self):
    """args and kwargs reach the subprocess fn intact."""
    return_value = multi_process_runner.run(
        fn_that_returns_args_and_kwargs,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        args=('a', 'b'),
        kwargs={
            'c_k': 'c_v'
        }).return_value
    self.assertEqual(return_value[0][0], 'a')
    self.assertEqual(return_value[0][1], 'b')
    self.assertEqual(return_value[0][2], ('c_k', 'c_v'))

  def test_stdout_captured(self):
    """With return_output=True, subprocess prints are captured and tagged."""

    def simple_print_func():
      print('This is something printed.', flush=True)
      return 'This is returned data.'

    mpr_result = multi_process_runner.run(
        simple_print_func,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    std_stream_results = mpr_result.stdout
    return_value = mpr_result.return_value
    # Captured lines are prefixed with the originating task, e.g. [worker-0].
    self.assertIn('[worker-0]: This is something printed.\n',
                  std_stream_results)
    self.assertIn('[worker-1]: This is something printed.\n',
                  std_stream_results)
    self.assertIn('This is returned data.', return_value)

  def test_termination(self):
    """terminate() stops one task; the other task runs to completion."""

    def fn():
      for i in range(0, 10):
        print(
            'index {}, iteration {}'.format(self._worker_idx(), i), flush=True)
        time.sleep(5)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    mpr.start()
    time.sleep(5)
    mpr.terminate('worker', 0)
    std_stream_results = mpr.join().stdout

    # Worker 0 is terminated in the middle, so it should not have iteration 9
    # printed.
    self.assertIn('[worker-0]: index 0, iteration 0\n', std_stream_results)
    self.assertNotIn('[worker-0]: index 0, iteration 9\n',
                     std_stream_results)
    self.assertIn('[worker-1]: index 1, iteration 0\n', std_stream_results)
    self.assertIn('[worker-1]: index 1, iteration 9\n', std_stream_results)

  def test_termination_and_start_single_process(self):
    """A terminated task can be restarted and runs from the beginning."""

    def fn():
      for i in range(0, 10):
        print(
            'index {}, iteration {}'.format(self._worker_idx(), i), flush=True)
        time.sleep(1)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    mpr.start()
    time.sleep(3)
    mpr.terminate('worker', 0)
    mpr.start_single_process('worker', 0)
    std_stream_results = mpr.join().stdout

    # Worker 0 is terminated in the middle, but a new worker 0 is added, so it
    # should still have iteration 9 printed. Moreover, iteration 0 of worker 0
    # should happen twice.
    self.assertLen(
        [s for s in std_stream_results if 'index 0, iteration 0' in s], 2)
    self.assertIn('[worker-0]: index 0, iteration 9\n', std_stream_results)
    self.assertIn('[worker-1]: index 1, iteration 0\n', std_stream_results)
    self.assertIn('[worker-1]: index 1, iteration 9\n', std_stream_results)

  def test_streaming(self):
    """Both logging and print output stream from all tasks, and only from
    tasks that actually exist."""

    def fn():
      for i in range(5):
        logging.info('(logging) %s-%d, i: %d',
                     multi_worker_test_base.get_task_type(), self._worker_idx(),
                     i)
        print(
            '(print) {}-{}, i: {}'.format(
                multi_worker_test_base.get_task_type(), self._worker_idx(), i),
            flush=True)
        time.sleep(1)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=2, num_ps=2),
        return_output=True)
    # Without chief dependence, join() waits for all tasks rather than
    # stopping when the chief exits.
    mpr._dependence_on_chief = False

    mpr.start()
    # Add a third worker and a third ps on top of the initial cluster spec.
    mpr.start_single_process('worker', 2)
    mpr.start_single_process('ps', 2)
    mpr_result = mpr.join()

    list_to_assert = mpr_result.stdout
    for job in ['chief']:
      for iteration in range(5):
        self.assertTrue(
            any('(logging) {}-0, i: {}'.format(job, iteration) in line
                for line in list_to_assert))
        self.assertTrue(
            any('(print) {}-0, i: {}'.format(job, iteration) in line
                for line in list_to_assert))

    for job in ['worker', 'ps']:
      for iteration in range(5):
        for task in range(3):
          self.assertTrue(
              any('(logging) {}-{}, i: {}'.format(job, task, iteration) in line
                  for line in list_to_assert))
          self.assertTrue(
              any('(print) {}-{}, i: {}'.format(job, task, iteration) in line
                  for line in list_to_assert))
        # Task index 3 was never started, so no output may mention it.
        task = 3
        self.assertFalse(
            any('(logging) {}-{}, i: {}'.format(job, task, iteration) in line
                for line in list_to_assert))
        self.assertFalse(
            any('(print) {}-{}, i: {}'.format(job, task, iteration) in line
                for line in list_to_assert))

  def test_start_in_process_as(self):
    """The main process can serve as one of the cluster's tasks."""

    def fn():
      for i in range(5):
        logging.info('%s-%d, i: %d', multi_worker_test_base.get_task_type(),
                     self._worker_idx(), i)
        time.sleep(1)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1),
        return_output=True)

    def eval_func():
      time.sleep(1)
      mpr.start_single_process(task_type='evaluator', task_id=0)

    # Start the evaluator from a side thread while the main thread blocks
    # acting as the chief.
    eval_thread = threading.Thread(target=eval_func)
    eval_thread.start()
    mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
    eval_thread.join()
    list_to_assert = mpr.join().stdout
    for job in ['worker', 'evaluator']:
      for iteration in range(5):
        self.assertTrue(
            any('{}-0, i: {}'.format(job, iteration) in line
                for line in list_to_assert))

  def test_terminate_all_does_not_ignore_error(self):
    """Errors raised before terminate_all() still surface from join()."""
    mpr = multi_process_runner.MultiProcessRunner(
        fn_that_errors,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    mpr.start()
    # Give the subprocesses ample time to raise before killing everything.
    time.sleep(60)
    mpr.terminate_all()
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      mpr.join()

  def test_barrier(self):
    """get_barrier() is usable inside subprocesses without error."""
    multi_process_runner.run(
        fn_with_barrier,
        cluster_spec=multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1),
    )

  def test_barrier_called_in_main_process(self):
    """get_barrier() outside a subprocess is a ValueError."""
    with self.assertRaises(ValueError):
      multi_process_runner.get_barrier()

  def test_stdout_available_when_timeout(self):
    """On SubprocessTimeoutError, captured stdout rides on the exception."""

    def fn():
      logging.info('something printed')
      time.sleep(10000)  # Intentionally make the test timeout.

    with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm:
      mpr = multi_process_runner.MultiProcessRunner(
          fn,
          multi_worker_test_base.create_cluster_spec(num_workers=1),
          return_output=True)
      mpr.start()
      mpr.join(timeout=60)
    mpr.terminate_all()

    list_to_assert = cm.exception.mpr_result.stdout
    self.assertTrue(
        any('something printed' in line for line in list_to_assert))

  def test_seg_fault_raises_error(self):
    """A segfaulting subprocess raises UnexpectedSubprocessExitError."""
    if multi_process_runner.is_oss():
      self.skipTest('TODO(b/171004637): Failing in OSS')

    def fn_expected_to_seg_fault():
      ctypes.string_at(0)  # Intentionally made seg fault.

    with self.assertRaises(
        multi_process_runner.UnexpectedSubprocessExitError) as cm:
      multi_process_runner.run(
          fn_expected_to_seg_fault,
          multi_worker_test_base.create_cluster_spec(num_workers=1),
          return_output=True)
    self.assertIn('Subprocess worker-0 exited with exit code',
                  str(cm.exception))
    list_to_assert = cm.exception.mpr_result.stdout
    self.assertTrue(
        any('Segmentation fault' in line for line in list_to_assert))

  def test_seg_fault_in_chief_raises_error(self):
    """A segfault in the chief is reported even while workers still run."""
    if multi_process_runner.is_oss():
      self.skipTest('TODO(b/171004637): Failing in OSS')

    def fn_expected_to_seg_fault():
      if multi_worker_test_base.get_task_type() == 'worker':
        time.sleep(10000)
      ctypes.string_at(0)  # Intentionally made seg fault.

    with self.assertRaises(
        multi_process_runner.UnexpectedSubprocessExitError) as cm:
      multi_process_runner.run(
          fn_expected_to_seg_fault,
          multi_worker_test_base.create_cluster_spec(
              has_chief=True, num_workers=1),
          return_output=True)
    self.assertIn('Subprocess chief-0 exited with exit code',
                  str(cm.exception))
    list_to_assert = cm.exception.mpr_result.stdout
    self.assertTrue(
        any('Segmentation fault' in line for line in list_to_assert))

  def test_exit_code_is_reported_by_chief_subprocess(self):
    """A chief exiting nonzero is reported with its exit code."""

    def fn_expected_to_exit_with_20():
      if multi_worker_test_base.get_task_type() == 'worker':
        time.sleep(10000)
      sys.exit(20)

    mpr = multi_process_runner.MultiProcessRunner(
        fn_expected_to_exit_with_20,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1))
    mpr.start()

    with self.assertRaisesRegex(
        multi_process_runner.UnexpectedSubprocessExitError,
        'Subprocess chief-0 exited with exit code 20'):
      mpr.join()

  def test_exit_code_is_reported_by_subprocess(self):
    """A worker exiting nonzero is reported with its exit code."""

    def fn_expected_to_exit_with_10():
      sys.exit(10)

    mpr = multi_process_runner.MultiProcessRunner(
        fn_expected_to_exit_with_10,
        multi_worker_test_base.create_cluster_spec(num_workers=1))
    mpr.start()

    with self.assertRaisesRegex(
        multi_process_runner.UnexpectedSubprocessExitError,
        'Subprocess worker-0 exited with exit code 10'):
      mpr.join()

  def test_auto_restart(self):
    """With auto_restart, a failing task is retried until it succeeds."""

    def fn(counter):
      counter.value += 1
      if counter.value == 1:
        raise ValueError

    manager = multi_process_runner.manager()
    counter = manager.Value(int, 0)
    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        args=(counter,),
        auto_restart=True)
    mpr.start()
    mpr.join()
    # First attempt raised, second succeeded: the fn ran exactly twice.
    self.assertEqual(counter.value, 2)

  def test_auto_restart_and_timeout(self):
    """If the fn keeps failing, join(timeout) raises the fn's error and the
    output shows multiple restarts."""

    def fn():
      logging.info('Running')
      time.sleep(1)
      raise ValueError

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        auto_restart=True,
        return_output=True)
    mpr.start()
    with self.assertRaises(ValueError) as cm:
      mpr.join(timeout=10)
    self.assertGreater(
        sum(['Running' in msg for msg in cm.exception.mpr_result.stdout]), 1)

  def test_auto_restart_and_chief(self):
    # If the chief has exited with zero exit code, auto restart should stop
    # restarting other tasks even if they fail.

    def fn():
      time.sleep(1)
      if multi_worker_test_base.get_task_type() != 'chief':
        raise ValueError

    # NOTE(review): `manager` is unused below — likely leftover; confirm
    # before removing.
    manager = multi_process_runner.manager()
    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1),
        auto_restart=True)
    mpr.start()
    with self.assertRaises(ValueError):
      mpr.join(timeout=10)

  def test_auto_restart_failure_immediate_after_restart(self):
    # Test the case when worker-0 fails immediately after worker-1 restarts.

    def fn():
      time.sleep(5)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=False, num_workers=2),
        auto_restart=True)
    mpr.start()
    pid = mpr.get_process_id('worker', 1)
    mpr.terminate('worker', 1)
    # Poll until worker-1 has actually been restarted (pid changes), then
    # kill worker-0 to hit the race being tested.
    while mpr.get_process_id('worker', 1) == pid:
      time.sleep(0.1)
    mpr.terminate('worker', 0)
    mpr.join(timeout=20)

  def test_auto_restart_terminate(self):
    # Tasks terminated by the user should also be restarted.

    def fn(counter):
      counter.value += 1
      if counter.value == 1:
        time.sleep(100)

    manager = multi_process_runner.manager()
    counter = manager.Value(int, 0)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=False, num_workers=1),
        args=(counter,),
        auto_restart=True)
    mpr.start()
    time.sleep(3)
    mpr.terminate('worker', 0)
    mpr.join(timeout=20)
    # Original run plus one restart after the manual terminate.
    self.assertEqual(counter.value, 2)

  def test_error_reporting_overrides_timeout_reporting(self):
    """A real subprocess error wins over a join() timeout."""

    def fn():
      if self._worker_idx() == 1:
        time.sleep(10000)
      raise ValueError('Worker 0 errored')

    mpr = multi_process_runner.MultiProcessRunner(
        fn, multi_worker_test_base.create_cluster_spec(num_workers=2))
    mpr.start()

    with self.assertRaisesRegex(
        ValueError,
        'Worker 0 errored'):
      mpr.join(timeout=20)

  def test_process_exists(self):
    """process_exists() tracks a task's liveness across terminate()."""

    def fn():
      time.sleep(100000)

    mpr = multi_process_runner.MultiProcessRunner(
        fn, multi_worker_test_base.create_cluster_spec(num_workers=1))
    mpr.start()
    self.assertTrue(mpr.process_exists('worker', 0))
    mpr.terminate('worker', 0)
    # Worker 0 should exit at some point, or else the test would time out.
    while mpr.process_exists('worker', 0):
      time.sleep(1)

  def test_timeout_none(self):
    """join(timeout=None) waits indefinitely and still propagates errors."""
    if multi_process_runner.is_oss():
      self.skipTest('Intentionally skipping longer test in OSS.')

    def fn():
      time.sleep(250)
      raise ValueError('Worker 0 errored')

    mpr = multi_process_runner.MultiProcessRunner(
        fn, multi_worker_test_base.create_cluster_spec(num_workers=1))
    mpr.start()
    with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
      mpr.join(timeout=None)
# Module-level pool created at import time so test_global_pool and
# test_nested_pool can exercise a shared MultiProcessPoolRunner (including
# from within subprocesses, which re-import this module).
_global_pool = multi_process_runner.MultiProcessPoolRunner(
    multi_worker_test_base.create_cluster_spec(num_workers=2))
class MultiProcessPoolRunnerTest(test.TestCase):
  """Tests for `multi_process_runner.MultiProcessPoolRunner`.

  The pool keeps its subprocesses alive across run() calls, so pids,
  TF_CONFIG, and per-process global state persist between runs.
  """

  def test_same_process_across_runs(self):
    """Successive run() calls reuse the same subprocesses (same pids)."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    pid = runner.run(fn_that_returns_pid)
    for _ in range(3):
      self.assertAllEqual(runner.run(fn_that_returns_pid), pid)

  def test_exceptions_in_sub_process(self):
    """A failing fn propagates its error but leaves the pool usable."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    pid = runner.run(fn_that_returns_pid)
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      runner.run(fn_that_errors)
    # Same pids after the failure: the pool processes survived.
    self.assertAllEqual(runner.run(fn_that_returns_pid), pid)

  def test_tf_config(self):
    """Each pool subprocess gets its own task entry in TF_CONFIG."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    result = runner.run(fn_that_adds_task_type_in_return_data)

    job_count_dict = {'worker': 2, 'chief': 1}
    for data in result:
      job_count_dict[data] -= 1

    self.assertEqual(job_count_dict['worker'], 0)
    self.assertEqual(job_count_dict['chief'], 0)

  @unittest.expectedFailure
  def test_exception_in_main_process(self):
    # When there's an exception in the main process, __del__() is not called.
    # This test is to verify MultiProcessPoolRunner can cope with __del__() not
    # being called.
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    runner.run(fn_that_returns_pid)
    raise ValueError('failure')

  def test_initializer(self):
    """The initializer runs once per subprocess before the first run()."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(
        cluster_spec, initializer=lambda: fn_that_sets_global(1))
    # The initializer set V=1 in each subprocess, so each run returns 1.
    result = runner.run(fn_that_sets_global, args=(2,))
    self.assertAllEqual(result, [1, 1])

  def test_global_pool(self):
    """The module-level pool is usable from the main process."""
    _global_pool.run(fn_that_does_nothing)

  def test_nested_pool(self):
    """A pool run from inside another pool's subprocesses works."""

    def fn():
      # This runs in sub processes, so they are each using their own
      # MultiProcessPoolRunner.
      _global_pool.run(fn_that_does_nothing)

    _global_pool.run(fn)
if __name__ == '__main__':
  # Sets up multiprocessing support before delegating to the test main.
  multi_process_runner.test_main()