Improve multi_process_runner
This is to prepare enabling it for OSS.

PiperOrigin-RevId: 315878874
Change-Id: Ib29bccf3c964462a7643df4b1cd011ddda79372b
commit c674577870 (parent 51f2a966cc)
@@ -36,6 +36,7 @@ tensorflow/third_party/coremltools.BUILD
 tensorflow/third_party/cub.BUILD
 tensorflow/third_party/curl.BUILD
 tensorflow/third_party/cython.BUILD
+tensorflow/third_party/dill.BUILD
 tensorflow/third_party/double_conversion.BUILD
 tensorflow/third_party/eigen.BUILD
 tensorflow/third_party/eigen3/BUILD
@@ -196,6 +197,7 @@ tensorflow/third_party/systemlibs/swig.BUILD
 tensorflow/third_party/systemlibs/syslibs_configure.bzl
 tensorflow/third_party/systemlibs/termcolor.BUILD
 tensorflow/third_party/systemlibs/zlib.BUILD
+tensorflow/third_party/tblib.BUILD
 tensorflow/third_party/tensorrt/BUILD
 tensorflow/third_party/tensorrt/BUILD.tpl
 tensorflow/third_party/tensorrt/LICENSE
@@ -1708,19 +1708,24 @@ cuda_py_test(
 py_library(
     name = "multi_process_runner",
     srcs = ["multi_process_runner.py"],
+    srcs_version = "PY3",
     deps = [
         ":multi_process_lib",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:tf2",
         "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/eager:context",
+        "@absl_py//absl/logging",
+        "@dill_archive//:dill",
         "@six_archive//:six",
+        "@tblib_archive//:tblib",
     ],
 )
 
 py_library(
     name = "multi_process_lib",
     srcs = ["multi_process_lib.py"],
-    deps = ["@six_archive//:six"],
+    deps = ["//tensorflow/python:client_testlib"],
 )
 
 py_test(
@@ -1745,11 +1750,12 @@ py_test(
     name = "multi_process_runner_test",
     srcs = ["multi_process_runner_test.py"],
     python_version = "PY3",
-    shard_count = 12,
     deps = [
         ":multi_process_runner",
         ":multi_worker_test_base",
         "//tensorflow/python/eager:test",
+        "@absl_py//absl/logging",
+        "@six_archive//:six",
     ],
 )
 
@@ -1757,6 +1763,7 @@ py_test(
     name = "multi_process_runner_no_init_test",
     srcs = ["multi_process_runner_no_init_test.py"],
     python_version = "PY3",
+    tags = ["no_oss"],
     deps = [
         ":multi_process_runner",
         ":multi_worker_test_base",
@@ -18,9 +18,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import contextlib
+import multiprocessing as _multiprocessing
 import unittest
 
+from tensorflow.python.platform import test
+
+
+try:
+  multiprocessing = _multiprocessing.get_context('forkserver')
+except ValueError:
+  # forkserver is not available on Windows.
+  multiprocessing = _multiprocessing.get_context('spawn')
+
 
 class Process(object):
   """A process simulating a worker for testing multi-worker training."""
@@ -28,23 +37,14 @@ class Process(object):
   def __init__(self, *args, **kwargs):
     del args, kwargs
     raise unittest.SkipTest(
-        'TODO(b/141874796): Implement OSS version of `multi_process_lib`')
+        'TODO(b/150264776): Implement OSS version of `multi_process_lib`')
 
 
-def get_user_data():
-  """Returns the data commonly shared by parent process and subprocesses."""
-  # TODO(b/141874796): Implement OSS version of `multi_process_lib`.
-  pass
+def test_main():
+  """Main function to be called within `__main__` of a test file."""
+  test.main()
 
 
-@contextlib.contextmanager
-def context_manager(max_subprocess_count=20, barrier_parties=0):
-  """No-op in OSS. This exists to maintain testing compatibility."""
-  del max_subprocess_count, barrier_parties
-  yield
-
-
-def using_context_manager():
-  """Whether the context manager is being used."""
-  raise unittest.SkipTest(
-      'TODO(b/141874796): Implement OSS version of `multi_process_lib`')
+def initialized():
+  """Returns whether the module is initialized."""
+  return True
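Note on the hunk above: the OSS stub of multi_process_lib now picks a multiprocessing start method up front, preferring forkserver and falling back to spawn where forkserver is unavailable (e.g. Windows). A standalone sketch of the same selection pattern; the worker function and process count below are illustrative, not part of the commit:

# Sketch: start-method selection as done by the OSS multi_process_lib stub.
# `worker` is a hypothetical function used only for illustration.
import multiprocessing as _multiprocessing

try:
  ctx = _multiprocessing.get_context('forkserver')
except ValueError:
  # forkserver is not available on Windows; fall back to spawn.
  ctx = _multiprocessing.get_context('spawn')


def worker(task_id):
  print('worker %d running' % task_id)


if __name__ == '__main__':
  procs = [ctx.Process(target=worker, args=(i,)) for i in range(2)]
  for p in procs:
    p.start()
  for p in procs:
    p.join()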
@@ -18,6 +18,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import collections
 import contextlib
 import json
@@ -26,6 +27,8 @@ import signal
 import sys
 import threading
 import time
+import unittest
 
 from absl import logging
 import six
 from six.moves import queue as Queue
@@ -34,7 +37,8 @@ from tensorflow.python import tf2
 from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute import multi_process_lib
 from tensorflow.python.eager import context
-from tensorflow.python.platform import test
+
+multiprocessing = multi_process_lib.multiprocessing
 
 # pylint: disable=g-import-not-at-top
 try:
@@ -43,61 +47,57 @@ try:
 except ImportError:
   faulthandler = None
 
+# TODO(b/150264776): Remove after resolving CI issue.
+try:
+  import dill
+except ImportError:
+  dill = None
+
+# TODO(b/150264776): Remove after resolving CI issue.
+try:
+  import tblib.pickling_support
+  # For pickling traceback objects.
+  tblib.pickling_support.install()
+except ImportError:
+  pass
+
 
 # _ProcessStatusInfo contains process status information. When is_successful
 # attribute is True, the subprocess has ended successfully, or if False, the
 # exception stack trace info is stored in exc_info to pass on to parent process
 # to be re-raised.
 _ProcessStatusInfo = collections.namedtuple(
-    '_ProcessStatusInfo', ['task_type', 'is_successful', 'exc_info'])
-
-# _SubprocessInfo collects basic information of a subprocess such as task type
-# and process id.
-_SubprocessInfo = collections.namedtuple('_SubprocessInfo',
-                                         ['pid', 'task_type', 'task_id'])
+    '_ProcessStatusInfo',
+    ['task_type', 'is_successful', 'exc_info', 'return_value'])
 
 # Information returned from a successful MultiProcessRunner run.
 MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
                                                   ['return_value', 'stdout'])
 
-# Process status queue is used by `multi_process_runner` internally for
+TestEnvironment = collections.namedtuple('TestEnvironment', [
+    'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast',
+    'v2_enabled', 'executing_eagerly'
+])
+
+# Resources for communication between worker processes and the main process.
+#
+# `process_status_queue` is used by `multi_process_runner` internally for
 # communication from subprocesses to the parent process for whether it's been
 # successful, and if not what the error stack trace is.
-PROCESS_STATUS_QUEUE = 'process_status_queue'
-
-# Return value queue is intended to be used by users of `multi_process_runner`
-# for the process function to return information to the caller of
-# `multi_process_runner.run()`.
-RETURN_VALUE_QUEUE = 'return_value_queue'
-
-# Subprocess info queue stores `_SubprocessInfo` for later potential
-# termination by the parent.
-SUBPROCESS_INFO_QUEUE = 'subprocess_info_queue'
-
-# Parent-to-sub queue is used for communications from parent to subprocess.
+# `parent_to_sub_queue` is used for communications from parent to subprocess.
 # Currently this is only used to terminate subprocesses.
 # TODO(rchao): Remove this once subprocess is terminated by SIGKILL.
-PARENT_TO_SUB_QUEUE = 'parent_to_sub_queue'
-
-# Streaming queue stores the logged and printed messages from subprocesses.
-STREAMING_QUEUE = 'streaming_queue'
-
-# Pipes to stream stdout and stderr from subprocesses to parent process.
-STREAMING_PIPE = 'streaming_pipe'
-
-# Barrier identifier.
-BARRIER = 'barrier'
-
-_DEFAULT_MAX_SUBPROCESS_COUNT = 20
+# `streaming_pipe_w` is to stream stdout and stderr from subprocesses to parent
+# process.
+# `barrier` is a barrier for the party of all subprocesses.
+Resources = collections.namedtuple('Resources', [
+    'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier'
+])
 
 # Default time out sec is selected so that it's handled before the default
 # "medium" timeout of the test runs.
 _DEFAULT_TIMEOUT_SEC = 200
-
-# Next pipe index to be global so that pipes are not reused across multiple
-# MultiProcessRunner usages.
-# TODO(rchao): Investigate possibility to remove this variable.
-_next_pipe_index = 0
 
 
 class MultiProcessRunner(object):
   """A utility class to start multiple processes to simulate a cluster.
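Note on the constants above: the string-keyed resources that used to be fetched through multi_process_lib.get_user_data() (PROCESS_STATUS_QUEUE, RETURN_VALUE_QUEUE, STREAMING_PIPE, BARRIER, ...) are replaced by explicit TestEnvironment and Resources namedtuples that the parent builds and hands to each subprocess. A rough standalone sketch of that hand-the-child-its-resources pattern, using only the standard library; all names below are illustrative:

# Sketch: passing an explicit resource bundle to a child process instead of
# looking resources up via module-level string keys.
import collections
import multiprocessing

Resources = collections.namedtuple(
    'Resources', ['process_status_queue', 'parent_to_sub_queue', 'barrier'])


def child(resources):
  # The child reports status through the queue it was handed.
  resources.barrier.wait()
  resources.process_status_queue.put('ok')


if __name__ == '__main__':
  manager = multiprocessing.Manager()
  resources = Resources(
      process_status_queue=manager.Queue(),
      parent_to_sub_queue=manager.Queue(),
      barrier=manager.Barrier(1))
  p = multiprocessing.Process(target=child, args=(resources,))
  p.start()
  p.join()
  print(resources.process_status_queue.get())  # -> ok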
@@ -123,6 +123,7 @@ class MultiProcessRunner(object):
                grpc_fail_fast=None,
                stream_stdout=True,
                list_stdout=False,
+               use_dill_for_args=True,
                args=None,
                kwargs=None):
     """Creates a multi-process runner.
@@ -153,6 +154,9 @@ class MultiProcessRunner(object):
         returned from `MultiProcessRunner.join()`. If True, the list of stdout
         can be retrieved via `MultiProcessRunnerResult.stdout` attribute.
         Defaults to False.
+      use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill
+        can pickle more objects, but doesn't work with types in
+        `multiprocessing` library like `Mutex`.
       args: Positional arguments to be sent to functions run on processes.
       kwargs: Keyword arguments to be sent to functions run on processes.
 
@@ -165,15 +169,14 @@ class MultiProcessRunner(object):
       raise ValueError('If chief exists in the cluster, there must be at most '
                        'one chief. Current `cluster_spec` has {} chiefs.'
                        .format(len(cluster_spec['chief'])))
-
-    assert callable(proc_func)
-
-    if not multi_process_lib.using_context_manager():
+    if not multi_process_lib.initialized():
       raise RuntimeError('`multi_process_runner` is not initialized. '
                          'Please call `multi_process_runner.test_main()` '
                          'within `if __name__ == \'__main__\':` block '
                          'in your python module to properly initialize '
                          '`multi_process_runner`.')
+    if not callable(proc_func):
+      raise ValueError('proc_func is not a callable')
 
     self._proc_func = proc_func
     self._cluster_spec = cluster_spec
@@ -184,62 +187,90 @@
     # TODO(rchao): Revisit list_stdout argument to consider other solution.
     self._list_stdout = list_stdout
     self._dependence_on_chief = True
+    self._use_dill_for_args = use_dill_for_args
     self._args = args or ()
     self._kwargs = kwargs or {}
 
-    self._outstanding_subprocess_count = 0
-
     # Child processes should have the same v2 and eager behavior.
     self._v2_enabled = tf2.enabled()
     self._executing_eagerly = context.executing_eagerly()
 
+    self._joined = False
+    self._processes = {}
+    self._outstanding_subprocess_count = 0
+    self._reading_threads = []
+
+    self._manager = multiprocessing.Manager()
+    self._process_status_queue = self._manager.Queue()
+    self._parent_to_sub_queue = self._manager.Queue()
+    parties = sum(len(addresses) for addresses in self._cluster_spec.values())
+    self._barrier = self._manager.Barrier(parties)
+
+    # We use a queue to collect outputs from worker processes since it's thread
+    # safe.
+    self._streaming_queue = self._manager.Queue()
+
     # This flag will be set to True once terminate_all() is called.
     self._all_forced_terminated = False
 
   def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
     """Function to continuously read lines from subprocesses."""
-    reader = os.fdopen(pipe_r.fileno(), 'r')
-    while True:
-      read_line = reader.readline()
-      if read_line == 'EOF':
-        reader.close()
-        # The thread that runs `_continuously_readline_from_sub` stops here.
-        # However the threads don't exit until the test exits, so we do not
-        # attempt to join the threads (which leads to timeout).
-        # TODO(rchao): Understand why and do thread joining.
-        break
-      task_string = '[{}-{}]:'.format(task_type, task_id)
-      formatted_line = '{} {}'.format(task_string.ljust(14), read_line)
-      if self._stream_stdout:
-        self._print_stdout_in_parent(formatted_line, task_type, task_id)
-      if self._list_stdout:
-        self._add_stdout_in_queue(formatted_line, task_type, task_id)
-
-  def _print_stdout_in_parent(self, formatted_line, task_type, task_id):
-    del task_type, task_id
-    # Flush True so the logging order from subprocesses is respected.
-    # TODO(rchao): Use a lock here to ensure the printed lines are not broken.
-    print(formatted_line, end='', flush=True)
-
-  def _add_stdout_in_queue(self, formatted_line, task_type, task_id):
-    del task_type, task_id
-    # A queue instead of a simple list is used here due to b/150652733.
-    _resource(STREAMING_QUEUE).put(formatted_line)
-
-  def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id,
-                                           cluster_spec, args, kwargs):
+    with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
+      for line in reader:
+        task_string = '[{}-{}]:'.format(task_type, task_id)
+        formatted_line = '{} {}'.format(task_string.ljust(14), line)
+        if self._stream_stdout:
+          # TODO(rchao): Use a lock here to ensure the printed lines are not
+          # broken.
+          print(formatted_line, end='', flush=True)
+        if self._list_stdout:
+          self._streaming_queue.put(formatted_line)
+
+  def _start_subprocess_and_reading_thread(self,
+                                           task_type,
+                                           task_id,
+                                           cluster_spec=None,
+                                           proc_func=None,
+                                           args=None,
+                                           kwargs=None):
     """Start a subprocess and a thread the reads lines from the subprocess."""
-    global _next_pipe_index
-    pipe_r, pipe_w = _resource(STREAMING_PIPE)[_next_pipe_index]
-    _next_pipe_index += 1
-
-    p = multi_process_lib.Process(
-        target=_Subprocess(),
-        args=(proc_func, task_type, task_id, cluster_spec, self._rpc_layer,
-              self._grpc_fail_fast, self._v2_enabled, self._executing_eagerly,
-              pipe_w) + args,
-        kwargs=kwargs)
+    if dill is None:
+      raise unittest.SkipTest(
+          'TODO(b/150264776): Resolve dependency issue in CI')
+
+    test_env = TestEnvironment(
+        task_type=task_type,
+        task_id=task_id,
+        cluster_spec=cluster_spec or self._cluster_spec,
+        rpc_layer=self._rpc_layer,
+        grpc_fail_fast=self._grpc_fail_fast,
+        v2_enabled=self._v2_enabled,
+        executing_eagerly=self._executing_eagerly,
+    )
+    pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
+    resources = Resources(
+        process_status_queue=self._process_status_queue,
+        parent_to_sub_queue=self._parent_to_sub_queue,
+        streaming_pipe_w=pipe_w,
+        barrier=self._barrier,
+    )
+    if proc_func is None:
+      proc_func, args, kwargs = self._proc_func, self._args, self._kwargs
+    # Always use dill to pickle proc_func so that we support more callable
+    # types, e.g. lambda.
+    proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL)
+    if self._use_dill_for_args:
+      args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
+      kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)
+
+    p = _Process(
+        test_env=test_env,
+        target=_ProcFunc(),
+        args=(resources, test_env, proc_func, args, kwargs,
+              self._use_dill_for_args))
     p.start()
+    self._processes[(task_type, task_id)] = p
     self._outstanding_subprocess_count += 1
 
     # For each subprocess, we dedicate a thread continuously reading lines
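Note on _start_subprocess_and_reading_thread above: proc_func is always serialized with dill (and args/kwargs too when use_dill_for_args is set), so lambdas and other callables that the standard pickle module rejects can cross the process boundary. A minimal round-trip sketch, separate from the runner:

# Sketch: dill can serialize a lambda, which plain pickle cannot.
import dill

fn = lambda x: x * 2
payload = dill.dumps(fn, dill.HIGHEST_PROTOCOL)  # bytes, safe to send to a child
restored = dill.loads(payload)
assert restored(21) == 42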
@@ -248,18 +279,15 @@
         target=self._continuously_readline_from_sub,
         args=(pipe_r, task_type, task_id))
     thread.start()
+    self._reading_threads.append(thread)
 
   def start(self):
     """Starts processes, one for each task in `cluster_spec`."""
-    global _next_pipe_index
-    self._starting_pipe_index = _next_pipe_index
+    if self._processes:
+      raise ValueError('MultiProcessRunner already started.')
 
     for task_type, addresses in self._cluster_spec.items():
       for task_id, _ in enumerate(addresses):
-        self._start_subprocess_and_reading_thread(self._proc_func, task_type,
-                                                  task_id, self._cluster_spec,
-                                                  self._args, self._kwargs)
+        self._start_subprocess_and_reading_thread(task_type, task_id)
 
     # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
     # without this the tests become very flaky.
@@ -309,33 +337,22 @@
       as_task_type: The task type to be run in the main process.
       as_task_id: The task id to be run in the main process.
     """
-    global _next_pipe_index
-    self._starting_pipe_index = _next_pipe_index
+    if self._processes:
+      raise ValueError('MultiProcessRunner already started.')
 
     for task_type, addresses in self._cluster_spec.items():
       for task_id, _ in enumerate(addresses):
         if not (task_type == as_task_type and task_id == as_task_id):
-          self._start_subprocess_and_reading_thread(self._proc_func, task_type,
-                                                    task_id, self._cluster_spec,
-                                                    self._args, self._kwargs)
-
-    tf_config_dict = {
-        'cluster': self._cluster_spec,
-        'task': {
-            'type': as_task_type,
-            'index': as_task_id,
-        },
-    }
-    if self._rpc_layer is not None:
-      tf_config_dict['rpc_layer'] = self._rpc_layer
-    os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
+          self._start_subprocess_and_reading_thread(task_type, task_id)
 
+    _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
+                   self._rpc_layer)
     self._proc_func(*self._args, **self._kwargs)
 
   def start_single_process(self,
                            task_type,
                            task_id,
-                           proc_func=None,
                            cluster_spec=None,
+                           proc_func=None,
                            args=None,
                            kwargs=None):
     """Starts a single process.
@@ -352,19 +369,22 @@
     Args:
       task_type: The task type.
       task_id: The task id.
-      proc_func: The process function to be run on the newly started
-        process. If `None`, the function provided at `__init__` will be used.
       cluster_spec: The cluster spec to be used on the newly started
         process. If `None`, the cluster spec provided at `__init__` will be
        used.
+      proc_func: The process function to be run on the newly started
+        process. If specified, specify `args` and `kwargs` as well. If `None`,
+        the function provided at `__init__` will be used.
       args: Optional positional arguments to be supplied in `proc_func`.
       kwargs: Optional keyword arguments to be supplied in `proc_func`.
     """
-    cluster_spec = cluster_spec or self._cluster_spec
-    proc_func = proc_func or self._proc_func
-    self._start_subprocess_and_reading_thread(proc_func, task_type, task_id,
-                                              cluster_spec, args or (),
-                                              kwargs or {})
+    self._start_subprocess_and_reading_thread(
+        task_type,
+        task_id,
+        cluster_spec=cluster_spec,
+        proc_func=proc_func,
+        args=args or (),
+        kwargs=kwargs or {})
 
   def _queue_to_list(self, queue_to_convert):
     """Convert `queue.Queue` to `list`."""
@@ -379,25 +399,20 @@
 
   def get_process_id(self, task_type, task_id):
     """Returns the subprocess id given the task type and task id."""
-    if not hasattr(self, '_pid_dict'):
-      self._pid_dict = {}
-      subprocess_infos = []
-
-      while True:
-        try:
-          subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False)
-          subprocess_infos.append(subprocess_info)
-        except Queue.Empty:
-          break
-
-      for subprocess_info in subprocess_infos:
-        self._pid_dict[(subprocess_info.task_type,
-                        subprocess_info.task_id)] = subprocess_info.pid
-
-      for subprocess_info in subprocess_infos:
-        _resource(SUBPROCESS_INFO_QUEUE).put(subprocess_info)
-
-    return self._pid_dict.get((task_type, task_id), None)
+    p = self._processes.get((task_type, task_id), None)
+    return p.pid if p else None
+
+  def _join_or_terminate(self, task_type, task_id, process, timeout):
+    """Joins a process. If it times out, terminate all procsses."""
+    logging.info('joining %s-%d', task_type, task_id)
+    process.join(timeout)
+    # If exitcode is None, the process aren't terminated and this is a
+    # timeout.
+    if process.exitcode is None:
+      # Force termination to dump worker processes stack trace.
+      self.terminate_all(sig=signal.SIGTERM)
+      raise RuntimeError('%s-%d and possibly more subprocesses timed out.' %
                          (task_type, task_id))
 
   def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
     """Joins all the processes with timeout.
@@ -417,84 +432,97 @@
       RuntimeError: if not all processes report status approximatelty within
       `timeout` seconds, or there's an exception propagated from any subprocess.
     """
+    if self._joined:
+      raise ValueError("MultiProcessRunner can't be joined twice.")
+    self._joined = True
 
-    if not timeout:
-      timeout = float('inf')
-    start_time = time.time()
-    while self._outstanding_subprocess_count > 0:
-      try:
-        process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
-
-        self._outstanding_subprocess_count -= 1
-        assert isinstance(process_status, _ProcessStatusInfo)
-        if not process_status.is_successful:
-          six.reraise(*process_status.exc_info)
-
-        if self._dependence_on_chief and process_status.task_type == 'chief':
-          self.terminate_all()
-          break
-      except Queue.Empty:
-        if self._all_forced_terminated:
-          break
-        if time.time() - start_time > timeout:
-          # Send SIGTERM signal to subprocesses to dump their current
-          # stack trace.
-          self.terminate_all(sig=signal.SIGTERM)
-          # If none of those did, report timeout to user.
-          raise RuntimeError('One or more subprocesses timed out. '
-                             'Number of outstanding subprocesses '
-                             'is %d.' % self._outstanding_subprocess_count)
+    chief = self._processes.get(('chief', 0), None)
+    if self._dependence_on_chief and chief:
+      self._join_or_terminate('chief', 0, chief, timeout)
+      # Give other processes a chance to exit on their own.
+      for p in self._processes.values():
+        p.join(timeout=3)
+      self.terminate_all()
+    else:
+      for (task_type, task_id), p in self._processes.items():
+        self._join_or_terminate(task_type, task_id, p, timeout)
+
+    for (task_type, task_id), p in self._processes.items():
+      logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)
+
+    process_statuses = self._queue_to_list(self._process_status_queue)
+    if not self._all_forced_terminated and len(
+        process_statuses) != self._outstanding_subprocess_count:
+      raise RuntimeError(
+          'missing statuses from %d subproceses.' %
+          (self._outstanding_subprocess_count - len(process_statuses)))
+    return_values = []
+    for process_status in process_statuses:
+      assert isinstance(process_status, _ProcessStatusInfo)
+      if not process_status.is_successful:
+        six.reraise(*process_status.exc_info)
+      if process_status.return_value is not None:
+        return_values.append(process_status.return_value)
+
+    logging.info('Joining log reading threads.')
+    for thread in self._reading_threads:
+      thread.join()
+    logging.info('Joined log reading threads.')
 
-    # Giving threads some time to finish the message reading from subprocesses.
-    time.sleep(5)
+    # Clear the alarm.
+    signal.alarm(0)
 
-    stdout = self._queue_to_list(_resource(STREAMING_QUEUE))
-    return_value = self._queue_to_list(_resource(RETURN_VALUE_QUEUE))
+    stdout = self._queue_to_list(self._streaming_queue)
 
-    # Notifying the threads that are reading lines that we should stop.
-    for pipe_index in range(self._starting_pipe_index, _next_pipe_index):  # pylint: disable=protected-access
-      _, pipe_w = _resource(STREAMING_PIPE)[pipe_index]
-      writer = os.fdopen(pipe_w.fileno(), 'w')
-      # Writing end of file message so the threads that's actively reading lines
-      # know to stop.
-      writer.writelines(['EOF'])
-      writer.close()
-
-    return MultiProcessRunnerResult(stdout=stdout, return_value=return_value)
+    return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)
 
   def terminate(self, task_type, task_id):
     """Terminates the process with `task_type` and `task_id`."""
-    _resource(PARENT_TO_SUB_QUEUE).put('terminate {} {}'.format(
-        task_type, task_id))
+    p = self._processes.get((task_type, task_id), None)
+    if p is None:
+      raise ValueError('{}-{} does not exist'.format(task_type, task_id))
+    # TODO(crccw): change to use Process.terminate() as well.
+    self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id))
+    p.join()
 
   def terminate_all(self, sig=None):
     """Terminates all subprocesses."""
     # Use SIGKILL as default. In systems where that's unavailable such as
     # windows, use SIGTERM.
     sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
-    subprocess_infos = []
-
-    while True:
-      try:
-        subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False)
-        subprocess_infos.append(subprocess_info)
-      except Queue.Empty:
-        break
-
-    for subprocess_info in subprocess_infos:
-      logging.info('Parent process is now killing PID: %d', subprocess_info.pid)
-      try:
-        os.kill(subprocess_info.pid, sig)
-      except ProcessLookupError:
-        # TODO(rchao): Remove subprocess info from the queue once a subprocess
-        # is terminated.
-        logging.info('PID %d does not exist.', subprocess_info.pid)
-
+    for (task_type, task_id), p in self._processes.items():
+      try:
+        os.kill(p.pid, sig)
+      except ProcessLookupError:
+        logging.info('Attempting to kill %s-%d but it does not exist.',
+                     task_type, task_id)
     self._all_forced_terminated = True
 
 
-class _Subprocess(object):
-  """Represents an internal subprocess used in MultiProcessRunner's context."""
+class _Process(multi_process_lib.Process):
+  """A modified `multiprocessing.Process` that can set up environment variables."""
+
+  # TODO(crccw): consider moving other logics in _ProcFunc to _Process.
+
+  def __init__(self, test_env, **kwargs):
+    super(_Process, self).__init__(**kwargs)
+    self._test_env = test_env
+    self._actual_run = getattr(self, 'run')
+    self.run = self._run_with_setenv
+
+  def _run_with_setenv(self):
+    # We need to set environment variables before doing anything because
+    # setenv() is not thread-safe.
+    test_env = self._test_env
+    if test_env.grpc_fail_fast is not None:
+      os.environ['GRPC_FAIL_FAST'] = str(test_env.grpc_fail_fast)
+    _set_tf_config(test_env.task_type, test_env.task_id, test_env.cluster_spec,
+                   test_env.rpc_layer)
+    return self._actual_run()
+
+
+class _ProcFunc(object):
+  """Represents a callable to run in a subprocess."""
 
   @contextlib.contextmanager
   def _runtime_mode(self, executing_eagerly):
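Note on the _Process class above: it rebinds run so that TF_CONFIG and GRPC_FAIL_FAST are set in the child before any user code or threads start. A simplified standalone sketch of the same idea that overrides run directly instead of rebinding it; class and function names are illustrative:

# Sketch: set environment variables in the child before the target runs.
import multiprocessing
import os


class EnvProcess(multiprocessing.Process):
  """Sets environment variables in the child before running the target."""

  def __init__(self, env, **kwargs):
    super(EnvProcess, self).__init__(**kwargs)
    self._env = env

  def run(self):
    # Do this first: setenv() is not thread-safe once threads exist.
    os.environ.update(self._env)
    super(EnvProcess, self).run()


def show_tf_config():
  print(os.environ.get('TF_CONFIG'))


if __name__ == '__main__':
  p = EnvProcess(
      env={'TF_CONFIG': '{"task": {"type": "worker", "index": 0}}'},
      target=show_tf_config)
  p.start()
  p.join()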
@@ -505,21 +533,12 @@ class _Subprocess(object):
       with context.graph_mode():
         yield
 
-  def _finish_process(self, process_status_info, return_value):
-    """Adds data to queues before program exits."""
-    # Clear the alarm.
-    signal.alarm(0)
-
-    if return_value is not None:
-      self._add_return_data(return_value)
-    _resource(PROCESS_STATUS_QUEUE).put(process_status_info)
-
   def _message_checking_func(self, task_type, task_id):
     """A function that regularly checks messages from parent process."""
     # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
     while True:
       try:
-        message = _resource(PARENT_TO_SUB_QUEUE).get(block=False)
+        message = self._resources.parent_to_sub_queue.get(block=False)
 
         # Currently the only possible message is termination.
         if not message.startswith('terminate'):
@@ -530,63 +549,75 @@
         else:
           # If the message is not targeting this process, put it back to the
           # queue.
-          _resource(PARENT_TO_SUB_QUEUE).put(message)
+          self._resources.parent_to_sub_queue.put(message)
           time.sleep(1)
       except Queue.Empty:
         time.sleep(0.1)
-    self._finish_process(
+    self._resources.process_status_queue.put(
         _ProcessStatusInfo(
-            task_type=task_type, is_successful=True, exc_info=None), None)
+            task_type=task_type,
+            is_successful=True,
+            exc_info=None,
+            return_value=None))
     # `os._exit(0)` is used to more reliably terminate a subprocess.
     os._exit(0)  # pylint: disable=protected-access
 
-  def __call__(self, proc_func, task_type, task_id, per_process_cluster_spec,
-               rpc_layer, grpc_fail_fast, v2_enabled, executing_eagerly, pipe_w,
-               *arg, **kwargs):
+  def _close_streaming(self):
+    """Close stdout, stderr and streaming pipe.
+
+    We need to explicitly close them since Tensorflow may take a while to exit,
+    so that the reading threads in the main process can exit more quickly.
+    """
+    sys.stdout.flush()
+    sys.stderr.flush()
+    sys.stdout.close()
+    sys.stderr.close()
+    self._resources.streaming_pipe_w.close()
+
+  def __call__(self, resources, test_env, proc_func, args, kwargs,
+               use_dill_for_args):
     """The wrapper function that actually gets run in child process(es)."""
+
+    global _barrier
+
+    self._resources = resources
+    _barrier = self._resources.barrier
+    proc_func = dill.loads(proc_func)
+    if use_dill_for_args:
+      args = dill.loads(args)
+      kwargs = dill.loads(kwargs)
+
     if faulthandler is not None:
       faulthandler.enable()
       faulthandler.register(signal.SIGTERM, chain=True)
+
+    # All logging should go to stderr to be streamed to the main process.
+    logging.set_stderrthreshold(logging.DEBUG)
+
+    # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
+    # print() and logging.*() write directly to `streaming_pipe_w`.
+    # Unfortunately since we cannot prepend task_type and task_id information to
+    # the streamed logs we will need a thread per subprocess to distinguish
+    # where the piece of message is from.
+    os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
+    os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())
+
     pid = os.getpid()
     logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid,
-                 task_type, task_id)
-    _resource(SUBPROCESS_INFO_QUEUE).put(
-        _SubprocessInfo(pid=pid, task_type=task_type, task_id=task_id))
-    # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print() and
-    # logging.*() write directly to `pipe_w`. Unfortunately since we cannot
-    # prepend task_type and task_id information to the streamed logs we will
-    # need a thread per subprocess to distinguish where the piece of message is
-    # from.
-    os.dup2(pipe_w.fileno(), sys.stdout.fileno())
-    os.dup2(pipe_w.fileno(), sys.stderr.fileno())
+                 test_env.task_type, test_env.task_id)
 
     # The thread will be dedicated to checking messages from the parent process.
     threading.Thread(  # pylint: disable=unexpected-keyword-arg
         target=self._message_checking_func,
-        args=(task_type, task_id),
+        args=(test_env.task_type, test_env.task_id),
         daemon=True).start()
 
-    if grpc_fail_fast is not None:
-      os.environ['GRPC_FAIL_FAST'] = str(grpc_fail_fast)
-    tf_config_dict = {
-        'cluster': per_process_cluster_spec,
-        'task': {
-            'type': task_type,
-            'index': task_id,
-        },
-    }
-    if rpc_layer is not None:
-      tf_config_dict['rpc_layer'] = rpc_layer
-    os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
-
-    if v2_enabled:
+    if test_env.v2_enabled:
       v2_compat.enable_v2_behavior()
 
     try:
-      with self._runtime_mode(executing_eagerly):
-        return_value = proc_func(*arg, **kwargs)
+      with self._runtime_mode(test_env.executing_eagerly):
+        return_value = proc_func(*args, **kwargs)
         is_successful = True
         exc_info = None
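Note on _ProcFunc.__call__ above: the child duplicates the streaming pipe's write end over its own stdout/stderr file descriptors, so print() and logging output flow to the reader thread in the parent. A standalone sketch of that fd-level redirection (POSIX; names are illustrative):

# Sketch: redirect a child's stdout into a pipe with os.dup2, read it in the parent.
import multiprocessing
import os
import sys


def child(pipe_w):
  # After dup2, file descriptor 1 (stdout) points at the pipe's write end.
  os.dup2(pipe_w.fileno(), sys.stdout.fileno())
  print('hello from pid %d' % os.getpid(), flush=True)


if __name__ == '__main__':
  pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
  p = multiprocessing.Process(target=child, args=(pipe_w,))
  p.start()
  p.join()
  pipe_w.close()
  with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
    print('parent read: ' + reader.readline().strip())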
@@ -606,35 +637,27 @@
         raise
 
     finally:
-      self._finish_process(
-          _ProcessStatusInfo(
-              task_type=task_type,
-              is_successful=is_successful,
-              exc_info=exc_info),
-          return_value)
-
-  def _add_return_data(self, data):
-    """Adds return data that will be returned by `join`.
-
-    The function provides a way for child processes to communicate with the
-    parent process. Data passed to `_add_return_data` will be available in a
-    Python Queue.Queue that is eventually returned by `join`.
-
-    Args:
-      data: data to be made available in the queue returned by `join`.
-    """
-    # TODO(rchao): Incorporate the task type and id information in a data
-    # wrapper that becomes what is stored in the queue so we can tell where
-    # the data is from.
-    _resource(RETURN_VALUE_QUEUE).put(data)
-
-
-def barrier():
-  return multi_process_lib.get_user_data()[BARRIER]
-
-
-def _resource(resource_name):
-  return multi_process_lib.get_user_data()[resource_name]
+      info = _ProcessStatusInfo(
+          task_type=test_env.task_type,
+          is_successful=is_successful,
+          exc_info=exc_info,
+          return_value=return_value)
+      self._resources.process_status_queue.put(info)
+      self._close_streaming()
+
+
+def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None):
+  """Set TF_CONFIG environment variable."""
+  tf_config_dict = {
+      'cluster': cluster_spec,
+      'task': {
+          'type': task_type,
+          'index': task_id,
+      },
+  }
+  if rpc_layer is not None:
+    tf_config_dict['rpc_layer'] = rpc_layer
+  os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
 
 
 def run(proc_func,
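Note on _set_tf_config above: it centralizes the TF_CONFIG construction that was previously duplicated in start_in_process_as and the old _Subprocess.__call__. For a hypothetical two-worker cluster spec, the environment variable it writes looks like this:

# Sketch: the TF_CONFIG value _set_tf_config produces for worker 0 of a
# hypothetical two-worker cluster (addresses are placeholders).
import json

cluster_spec = {'worker': ['localhost:12345', 'localhost:23456']}
tf_config = {
    'cluster': cluster_spec,
    'task': {'type': 'worker', 'index': 0},
    'rpc_layer': 'grpc',
}
print(json.dumps(tf_config))
# {"cluster": {"worker": ["localhost:12345", "localhost:23456"]},
#  "task": {"type": "worker", "index": 0}, "rpc_layer": "grpc"}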
@@ -670,16 +693,19 @@ def run(proc_func,
   return runner.join(timeout)
 
 
-def test_main(max_subprocess_count=_DEFAULT_MAX_SUBPROCESS_COUNT,
-              barrier_parties=0):
-  """Main function to be called within `__main__` of a test file.
-
-  Args:
-    max_subprocess_count: Maximum number of subprocesses that will be used. User
-      of multi_process_runner needs to determine a number at calling this
-      method, and the subprocesses involved later should not exceed this number.
-    barrier_parties: Number of parties the barrier will be used toward. User of
-      multi_process_runner needs to determine a number at calling this method.
-  """
-  with multi_process_lib.context_manager(max_subprocess_count, barrier_parties):
-    test.main()
+# This is set by MultiProcessRunner in worker processes.
+_barrier = None
+
+
+def barrier():
+  if _barrier is None:
+    raise ValueError(
+        'barrier is not defined. It is likely because you are calling barrier()'
+        'in the main process. barrier() can only be called in the subprocesses.'
+    )
+  return _barrier
+
+
+def test_main():
+  """Main function to be called within `__main__` of a test file."""
+  multi_process_lib.test_main()
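Taken together, a test built on the reworked runner supplies only a cluster spec and a proc_func; pickling, TF_CONFIG, log streaming and the barrier are handled internally. A usage sketch mirroring the patterns in the test file below (the test class itself is illustrative, not part of the commit):

# Sketch: a test using the reworked runner and the new barrier() helper.
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test


def proc_func_with_barrier():
  # Wait until every task in the cluster reaches this point.
  multi_process_runner.barrier().wait()
  return multi_worker_test_base.get_task_type()


class ExampleTest(test.TestCase):

  def test_run_with_barrier(self):
    result = multi_process_runner.run(
        proc_func_with_barrier,
        cluster_spec=multi_worker_test_base.create_cluster_spec(num_workers=2))
    self.assertLen(result.return_value, 2)


if __name__ == '__main__':
  multi_process_runner.test_main()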
@@ -23,15 +23,13 @@ import os
 import threading
 import time
 from absl import logging
-from six.moves import queue as Queue
 
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import test
 
 
-def proc_func_that_adds_task_type_in_return_data(test_obj, val):
-  test_obj.assertEqual(val, 3)
+def proc_func_that_adds_task_type_in_return_data():
   return multi_worker_test_base.get_task_type()
 
 
@@ -51,6 +49,10 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs):
   return list(args) + list(kwargs.items())
 
 
+def proc_func_with_barrier():
+  return multi_process_runner.barrier()
+
+
 class MultiProcessRunnerTest(test.TestCase):
 
   def _worker_idx(self):
@@ -61,8 +63,7 @@ class MultiProcessRunnerTest(test.TestCase):
     mpr_result = multi_process_runner.run(
         proc_func_that_adds_task_type_in_return_data,
         multi_worker_test_base.create_cluster_spec(
-            num_workers=2, num_ps=3, has_eval=1),
-        args=(self, 3))
+            num_workers=2, num_ps=3, has_eval=1))
 
     job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
     for data in mpr_result.return_value:
@@ -124,36 +125,22 @@ class MultiProcessRunnerTest(test.TestCase):
 
   def test_process_that_exits(self):
 
-    def func_to_exit_in_15_sec():
-      time.sleep(5)
-      print('foo', flush=True)
-      time.sleep(20)
-      print('bar', flush=True)
+    def func_to_exit_in_5_sec():
+      logging.error('foo')
+      time.sleep(10)
+      logging.error('bar')
 
     mpr = multi_process_runner.MultiProcessRunner(
-        func_to_exit_in_15_sec,
+        func_to_exit_in_5_sec,
         multi_worker_test_base.create_cluster_spec(num_workers=1),
         list_stdout=True,
-        max_run_time=15)
+        max_run_time=5)
 
     mpr.start()
     stdout = mpr.join().stdout
     self.assertLen([msg for msg in stdout if 'foo' in msg], 1)
     self.assertLen([msg for msg in stdout if 'bar' in msg], 0)
 
-  def test_signal_doesnt_fire_after_process_exits(self):
-    mpr = multi_process_runner.MultiProcessRunner(
-        proc_func_that_does_nothing,
-        multi_worker_test_base.create_cluster_spec(num_workers=1),
-        max_run_time=10)
-    mpr.start()
-    mpr.join()
-    with self.assertRaisesRegexp(Queue.Empty, ''):
-      # If the signal was fired, another message would be added to internal
-      # queue, so verifying it's empty.
-      multi_process_runner._resource(
-          multi_process_runner.PROCESS_STATUS_QUEUE).get(block=False)
-
   def test_termination(self):
 
     def proc_func():
@@ -192,7 +179,7 @@ class MultiProcessRunnerTest(test.TestCase):
         multi_worker_test_base.create_cluster_spec(num_workers=2),
         list_stdout=True)
     mpr.start()
-    time.sleep(5)
+    time.sleep(3)
     mpr.terminate('worker', 0)
     mpr.start_single_process('worker', 0)
     std_stream_results = mpr.join().stdout
@@ -273,11 +260,14 @@ class MultiProcessRunnerTest(test.TestCase):
             has_chief=True, num_workers=1),
         list_stdout=True)
 
-    def follow_ups():
+    def eval_func():
+      time.sleep(1)
       mpr.start_single_process(task_type='evaluator', task_id=0)
 
-    threading.Thread(target=follow_ups).start()
+    eval_thread = threading.Thread(target=eval_func)
+    eval_thread.start()
     mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
+    eval_thread.join()
     list_to_assert = mpr.join().stdout
     for job in ['worker', 'evaluator']:
       for iteration in range(5):
@@ -291,9 +281,21 @@ class MultiProcessRunnerTest(test.TestCase):
         multi_worker_test_base.create_cluster_spec(num_workers=2),
         list_stdout=True)
     mpr.start()
+    time.sleep(3)
     mpr.terminate_all()
     with self.assertRaisesRegexp(ValueError, 'This is an error.'):
       mpr.join()
 
+  def test_barrier(self):
+    multi_process_runner.run(
+        proc_func_with_barrier,
+        cluster_spec=multi_worker_test_base.create_cluster_spec(
+            has_chief=True, num_workers=1),
+    )
+
+  def test_barrier_called_in_main_process(self):
+    with self.assertRaises(ValueError):
+      multi_process_runner.barrier()
+
 
 if __name__ == '__main__':
   multi_process_runner.test_main()
@@ -127,4 +127,4 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
 
 
 if __name__ == '__main__':
-  multi_process_runner.test_main(barrier_parties=NUM_WORKERS)
+  multi_process_runner.test_main()
@@ -364,7 +364,7 @@ py_test(
     name = "multi_worker_callback_tf2_test",
     srcs = ["multi_worker_callback_tf2_test.py"],
     python_version = "PY3",
-    shard_count = 10,
+    shard_count = 5,
     deps = [
         "//tensorflow/python/distribute:collective_all_reduce_strategy",
         "//tensorflow/python/distribute:combinations",
@@ -345,4 +345,4 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
 
 
 if __name__ == '__main__':
-  multi_process_runner.test_main(barrier_parties=2)
+  multi_process_runner.test_main()
@@ -183,6 +183,7 @@ filegroup(
         "@com_google_protobuf//:LICENSE",
         "@com_googlesource_code_re2//:LICENSE",
         "@curl//:COPYING",
+        "@dill_archive//:LICENSE",
        "@dlpack//:LICENSE",
        "@double_conversion//:LICENSE",
        "@eigen_archive//:COPYING.MPL2",
@@ -212,6 +213,7 @@ filegroup(
        "@six_archive//:LICENSE",
        "@snappy//:COPYING",
        "@sobol_data//:LICENSE",
+        "@tblib_archive//:LICENSE",
        "@termcolor_archive//:COPYING.txt",
        "@zlib//:zlib.h",
        "@clog//:LICENSE",
@@ -532,6 +532,28 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
         ],
     )
 
+    tf_http_archive(
+        name = "dill_archive",
+        build_file = clean_dep("//third_party:dill.BUILD"),
+        urls = [
+            "http://mirror.tensorflow.org/files.pythonhosted.org/packages/c7/11/345f3173809cea7f1a193bfbf02403fff250a3360e0e118a1630985e547d/dill-0.3.1.1.tar.gz",
+            "https://files.pythonhosted.org/packages/c7/11/345f3173809cea7f1a193bfbf02403fff250a3360e0e118a1630985e547d/dill-0.3.1.1.tar.gz",
+        ],
+        sha256 = "42d8ef819367516592a825746a18073ced42ca169ab1f5f4044134703e7a049c",
+        strip_prefix = "dill-0.3.1.1",
+    )
+
+    tf_http_archive(
+        name = "tblib_archive",
+        build_file = clean_dep("//third_party:tblib.BUILD"),
+        urls = [
+            "http://mirror.tensorflow.org/files.pythonhosted.org/packages/ec/c4/8c651f3240a73c28a218194f3d527eb2be5a173d08501060cdee84ade33f/tblib-1.3.2.tar.gz",
+            "https://files.pythonhosted.org/packages/ec/c4/8c651f3240a73c28a218194f3d527eb2be5a173d08501060cdee84ade33f/tblib-1.3.2.tar.gz",
+        ],
+        sha256 = "436e4200e63d92316551179dc540906652878df4ff39b43db30fcf6400444fe7",
+        strip_prefix = "tblib-1.3.2",
+    )
+
     filegroup_external(
         name = "org_python_license",
         licenses = ["notice"],  # Python 2.0
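Note on the two new archives above: dill lets the runner pickle arbitrary proc_funcs, and tblib lets a worker's exception traceback be pickled and re-raised in the parent. A standalone sketch of what tblib's pickling support enables (not part of the diff):

# Sketch: with tblib installed, (type, value, traceback) survives pickling, so a
# failure in a worker can be re-raised in the parent with its original traceback.
import pickle
import sys

import six
import tblib.pickling_support

tblib.pickling_support.install()  # registers picklers for traceback objects

try:
  raise ValueError('failure inside a worker')
except ValueError:
  exc_info = sys.exc_info()

payload = pickle.dumps(exc_info)  # would normally travel through a process queue
restored = pickle.loads(payload)
try:
  six.reraise(*restored)
except ValueError as e:
  print('re-raised in parent: %s' % e)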
@@ -0,0 +1,10 @@
+licenses(["notice"])  # BSD 3-clause
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "dill",
+    srcs = glob(["dill/*.py"]),
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+)
@@ -0,0 +1,11 @@
+licenses(["notice"])  # BSD
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "tblib",
+    srcs = glob(["src/tblib/*.py"]),
+    imports = ["src"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+)