diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 69353e0fcb2..d273d7176b3 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1321,6 +1321,8 @@ py_library(
         ":multi_process_runner_util",
         ":multi_worker_test_base",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:tf2",
+        "//tensorflow/python/compat:v2_compat",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py
index 253c17d1c25..c57741eadd1 100644
--- a/tensorflow/python/distribute/multi_process_runner.py
+++ b/tensorflow/python/distribute/multi_process_runner.py
@@ -24,12 +24,13 @@ import json
 import os
 import signal
 import sys
+import time
 
-from absl import flags
 import six
 from six.moves import queue as Queue
 
-from tensorflow.python.distribute import multi_worker_test_base
+from tensorflow.python import tf2
+from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute import multi_process_lib
 from tensorflow.python.eager import context
 from tensorflow.python.platform import test
@@ -37,23 +38,20 @@ from tensorflow.python.platform import test
 _FINISH_PROPERLY_MESSAGE = 'OK'
 _ExcInfoWrapper = collections.namedtuple('_ExcInfoWrapper', ['exc_info'])
 
-
-class _AvailableQueues(object):
-  """Names of the available queues used by `multi_process_runner`."""
-  # Internal queue is used by `multi_process_runner` internally for
-  # communication from subprocesses to the parent process. The message
-  # can be _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended successfully, or
-  # the detailed message of an exception if the subprocess has raised
-  # one so it can be re-raised by the parent process.
-  INTERNAL_QUEUE = 'internal_queue'
-  # Public queue is intended to be used by users of `multi_process_runner`
-  # for the process function to return information to the caller of
-  # `multi_process_runner.run()`.
-  PUBLIC_QUEUE = 'public_queue'
-  # Standard stream queue is used by `multi_process_runner` to collect
-  # information streamed to stdout and stderr to be reported back to the
-  # parent process.
-  STD_STREAM_QUEUE = 'std_stream_queue'
+# Process status queue is used by `multi_process_runner` internally for
+# communication from subprocesses to the parent process. The message can be
+# _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended
+# successfully, or the detailed message of an exception if the subprocess has
+# raised one so it can be re-raised by the parent process.
+PROCESS_STATUS_QUEUE = 'process_status_queue'
+# Return value queue is intended to be used by users of `multi_process_runner`
+# for the process function to return information to the caller of
+# `multi_process_runner.run()`.
+RETURN_VALUE_QUEUE = 'return_value_queue'
+# Standard stream queue is used by `multi_process_runner` to collect
+# information streamed to stdout and stderr to be reported back to the
+# parent process.
+STD_STREAM_QUEUE = 'std_stream_queue'
 
 
 class _LogCollector(object):
@@ -72,29 +70,32 @@ class _LogCollector(object):
 
 
 class MultiProcessRunner(object):
-  """A utility to start multiple subprocesses to simulate multiple workers.
+  """A utility class to start multiple processes to simulate a cluster.
 
-  Training with multiple workers with eager runtime can be tested by simulating
-  using multiple processes. See `run()` for more information about the usage
-  of this class.
+  We need to use multiple processes to simulate a cluster in TF 2.0 tests
+  because TF 2.0 has some process-global data structures that have to be
+  separated by processes. We also need child processes to test out our fault
+  tolerance because shutting down a standard TensorFlow server within its
+  process is not supported.
+
+  Note: the main test program that uses this runner class must run main program
+  via `test_main` defined in this file. Using this runner in non-test binaries
+  is not supported yet.
+
+  This class is not thread-safe. Child processes will inherit TF2 behavior flag.
   """
 
-  def run(self,
-          proc_func,
-          cluster_spec,
-          proc_flags=None,
-          timeout=200,
-          time_to_exit=None,
-          return_std_stream=False,
-          args=None,
-          kwargs=None):
-    """Run functions on local sub-processes.
-
-    Experimental. API subject to change. To fully inspect logging from
-    subprocesses, use `--test_arg=--logtostderr` flag with bazel test.
+  def __init__(self,
+               proc_func,
+               cluster_spec,
+               max_run_time=None,
+               capture_std_stream=False,
+               args=None,
+               kwargs=None):
+    """Creates a multi-process runner.
 
     Args:
-      proc_func: Function to be run on the processes. This will be run on
+      proc_func: Function to be run on child processes. This will be run on
         processes for all task types.
       cluster_spec: Dict for cluster spec. The following is an example of
         cluster with three workers and two ps's.
@@ -103,39 +104,19 @@ class MultiProcessRunner(object):
                     "worker2.example.com:2222"],
          "ps": ["ps0.example.com:2222",
                 "ps1.example.com:2222"]}
-      proc_flags: Dict that contains the key/values of the flags used on the
-        processes.
-      timeout: Time out in seconds. If the sub-process takes more than this time
-        to complete, raise an error.
-      time_to_exit: If set, sub-processes is forced to exit at approximately
-        this many seconds after `run()` is called, through `signal.alarm()` api.
-        This is for simulation of interruption on a process so in such cases no
-        error is raised. Note that this is best effort at Python level since
-        Python signal handler does not get executed inside the low-level (C)
-        signal handler, so it can be delayed.
-      return_std_stream: Boolean, whether the messages streamed to stdout and
-        stderr in subprocesses are captured. If True, the messages are stored in
-        a list returned as the second element.
+      max_run_time: If set, child processes is forced to exit at approximately
+        this many seconds after `start` is called. We achieve this through
+        `signal.alarm()` api. Note that this is best effort at Python level
+        since Python signal handler does not get executed when it runs lower
+        level C/C++ code. So it can be delayed for arbitrarily long time.
+      capture_std_stream: Boolean, whether the messages streamed to stdout and
+        stderr in subprocesses are captured.
       args: Positional arguments to be sent to functions run on processes.
       kwargs: Keyword arguments to be sent to functions run on processes.
 
-    Returns:
-      If `return_std_stream` is False, a list that stores the return data added
-      by subprocesses through `multi_process_runner._add_return_data(data)`
-      call,
-      or through normal function return; if `return_std_stream` is True, a
-      two-element tuple of `(return_data_list, std_stream_data_list)`, where
-      `return_data_list` stores the return data added by processes through
-      `multi_process_runner._add_return_data(data)` call or through normal
-      function
-      return, and `std_stream_data_list` stores the messages streamed to stdout
-      and stderr in the subprocesses.
-
     Raises:
-      RuntimeError: If any of the subprocesses raise an error, or if any of the
-        subprocesses does not return or error out within `timeout` seconds.
+      RuntimeError: if `multi_process_runner.test_main()` is not called.
     """
-
     assert cluster_spec is not None
     assert callable(proc_func)
 
@@ -146,192 +127,233 @@ class MultiProcessRunner(object):
                          'in your python module to properly initialize '
                          '`multi_process_runner`.')
 
-    processes = []
-    args = args or ()
-    kwargs = kwargs or {}
+    self._proc_func = proc_func
+    self._cluster_spec = cluster_spec
+    self._max_run_time = max_run_time
+    self._capture_std_stream = capture_std_stream
+    self._args = args or ()
+    self._kwargs = kwargs or {}
+    self._processes = []
 
-    def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
-                     executing_eagerly, *arg, **kwargs):
-      """The wrapper function that actually gets run on the process(es)."""
+    # Child processes should have the same v2 and eager behavior.
+    self._v2_enabled = tf2.enabled()
+    self._executing_eagerly = context.executing_eagerly()
 
-      @contextlib.contextmanager
-      def runtime_mode(executing_eagerly):
-        if executing_eagerly:
-          with context.eager_mode():
-            yield
-        else:
-          with context.graph_mode():
-            yield
-
-      with runtime_mode(executing_eagerly):
-        os.environ['TF_CONFIG'] = tf_config_as_json
-        if proc_flags is not None:
-          for flag_key, flag_value in proc_flags.items():
-            setattr(flags.FLAGS, flag_key, flag_value)
-
-        stdout_collector = _LogCollector(
-            sys.__stdout__) if return_std_stream else None
-        stderr_collector = _LogCollector(
-            sys.__stderr__) if return_std_stream else None
-
-        def finish_wrapper_func_properly(func_result):
-          """Call to finish `wrapper_func` properly."""
-          # Clear the alarm.
-          signal.alarm(0)
-          if (return_std_stream and stdout_collector is not None and
-              stderr_collector is not None):
-            # If stdout and stderr are to be collected, add them to std stream
-            # queue.
-            self._add_std_stream_data_flattened(stdout_collector.log)
-            self._add_std_stream_data_flattened(stderr_collector.log)
-            # Un-redirect stdout and stderr.
-            sys.stdout = sys.__stdout__
-            sys.stderr = sys.__stderr__
-          self._get_internal_queue().put(func_result)
-
-        if time_to_exit is not None:
-
-          def handler(signum, frame):
-            del signum, frame
-            finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
-            # pylint: disable=protected-access
-            os._exit(0)
-
-          signal.signal(signal.SIGALRM, handler)
-          signal.alarm(time_to_exit)
-
-        if return_std_stream:
-          sys.stdout = stdout_collector
-          sys.stderr = stderr_collector
-
-        try:
-          return_data = proc_func(*arg, **kwargs)
-          if return_data is not None:
-            self._add_return_data(return_data)
-        # pylint: disable=broad-except
-        except Exception:
-          # Capture all exceptions to be reported to parent process.
-          finish_wrapper_func_properly(_ExcInfoWrapper(sys.exc_info()))
-
-          # Re-raise the exception in addition to reporting it to the parent
-          # process, so that even if `--test_timeout` flag is set and the
-          # error doesn't make it to be shown in parent process before bazel's
-          # timeout, the log would still show what happens in this subprocess,
-          # instead of silently suppressing the error due to early bazel
-          # timeout. Raising an error in the subprocess produces stack trace in
-          # the log, but the program continues running.
-          raise
-
-        finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
-
-    # Start number of processes according to `count_dict`.
-    for job_type, addresses in cluster_spec.items():
-      for task_id, _ in enumerate(addresses):
-        tf_config_as_json = json.dumps({
-            'cluster': cluster_spec,
-            'task': {
-                'type': job_type,
-                'index': task_id
-            }
-        })
-        p = multi_process_lib.Process(
-            target=wrapper_func,
-            args=(tf_config_as_json, proc_func, proc_flags, time_to_exit,
-                  context.executing_eagerly()) + args,
-            kwargs=kwargs)
-        p.start()
-        processes.append(p)
-
-    internal_queue_results = []
-    for _ in range(len(processes)):
-      try:
-        internal_queue_results.append(
-            self._get_internal_queue().get(timeout=timeout))
-      except Queue.Empty:
-        # First check if any of the subprocesses raised exception.
-        for internal_queue_result in internal_queue_results:
-          if isinstance(internal_queue_result, _ExcInfoWrapper):
-            six.reraise(*internal_queue_result.exc_info)
-        # If none of those did, report time out to user.
-        raise RuntimeError(
-            'One or more subprocesses timed out. Please use '
-            '`--test_arg=--logtostderr` bazel flag to inspect logs for '
-            'subprocess debugging info. Timeout = {} sec.'.format(timeout))
-
-    for internal_queue_result in internal_queue_results:
-      if isinstance(internal_queue_result, _ExcInfoWrapper):
-        six.reraise(*internal_queue_result.exc_info)
-      assert internal_queue_result == _FINISH_PROPERLY_MESSAGE
-
-    def queue_to_list(queue_to_convert):
-      """Convert `queue.Queue` to `list`."""
-      list_to_return = []
-      while True:
-        try:
-          list_to_return.append(queue_to_convert.get(block=False))
-        except Queue.Empty:
-          break
-      return list_to_return
-
-    if return_std_stream:
-      return tuple(
-          queue_to_list(multi_process_lib.get_user_data()[queue_name])
-          for queue_name in
-          [_AvailableQueues.PUBLIC_QUEUE, _AvailableQueues.STD_STREAM_QUEUE])
+  @contextlib.contextmanager
+  def _runtime_mode(self):
+    if self._executing_eagerly:
+      with context.eager_mode():
+        yield
     else:
-      return queue_to_list(
-          multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE])
+      with context.graph_mode():
+        yield
 
-  def _add_return_data(self, data):
-    """Add return data that will be returned by `multi_process_runner.run()`.
+  def _finish_process(self, func_status, return_value, stdout_collector,
+                      stderr_collector):
+    """Adds data to queues before program exits."""
+    # Clear the alarm.
+    signal.alarm(0)
+    if return_value is not None:
+      self._add_return_data(return_value)
+    if self._capture_std_stream:
+      # If stdout and stderr are to be collected, add them to std stream
+      # queue.
+      self._add_std_stream_data_flattened(stdout_collector.log)
+      self._add_std_stream_data_flattened(stderr_collector.log)
+    self._get_process_status_queue().put(func_status)
 
-    The function provides a way for processes started by
-    `multi_process_runner.run()` to communicate with the original process
-    that started the sub-processes. Data passed to `_add_return_data` will
-    be available in a python Queue.Queue that is eventually returned by
-    `multi_process_runner.run()`.
+  def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs):
+    """The wrapper function that actually gets run in child process(es)."""
+    os.environ['TF_CONFIG'] = json.dumps({
+        'cluster': self._cluster_spec,
+        'task': {
+            'type': task_type,
+            'index': task_id,
+        }
+    })
+
+    if self._capture_std_stream:
+      # TODO(yuefengz): consider a lighter way of capturing std streams.
+      stdout_collector = _LogCollector(sys.__stdout__)
+      stderr_collector = _LogCollector(sys.__stderr__)
+      sys.stdout = stdout_collector
+      sys.stderr = stderr_collector
+    else:
+      stdout_collector = None
+      stderr_collector = None
+
+    if self._v2_enabled:
+      v2_compat.enable_v2_behavior()
+
+    return_value = None
+
+    if self._max_run_time is not None:
+      # Register an sigalarm handler to exit the process when it reaches
+      # `timeout` seconds. A program reaching `timeout` doesn't necessarily
+      # indicate an issue.
+      def handler(signum, frame):
+        del signum, frame
+        self._finish_process(_FINISH_PROPERLY_MESSAGE, None, stdout_collector,
+                             stderr_collector)
+        os._exit(0)  # pylint: disable=protected-access
+
+      signal.signal(signal.SIGALRM, handler)
+      signal.alarm(self._max_run_time)
+
+    try:
+      with self._runtime_mode():
+        return_value = self._proc_func(*arg, **kwargs)
+    except Exception:  # pylint: disable=broad-except
+      # Capture all exceptions to be reported to parent process.
+      self._finish_process(
+          _ExcInfoWrapper(sys.exc_info()), return_value, stdout_collector,
+          stderr_collector)
+
+      # Re-raise the exception in addition to reporting it to the parent
+      # process, so that even if `--test_timeout` flag is set and the
+      # error doesn't make it to be shown in parent process before bazel's
+      # timeout, the log would still show what happens in this subprocess,
+      # instead of silently suppressing the error due to early bazel
+      # timeout. Raising an error in the subprocess produces stack trace in
+      # the log, but the program continues running.
+      raise
+
+    self._finish_process(_FINISH_PROPERLY_MESSAGE, return_value,
+                         stdout_collector, stderr_collector)
+
+  def start(self):
+    """Starts processes, one for each task in `cluster_spec`."""
+    for task_type, addresses in self._cluster_spec.items():
+      for task_id, _ in enumerate(addresses):
+        p = multi_process_lib.Process(
+            target=self._proc_func_wrapper,
+            args=(task_type, task_id) + self._args,
+            kwargs=self._kwargs)
+        p.start()
+        self._processes.append(p)
+
+  def _queue_to_list(self, queue_to_convert):
+    """Convert `queue.Queue` to `list`."""
+    list_to_return = []
+    # Calling `queue.empty()` is not reliable.
+    while True:
+      try:
+        list_to_return.append(queue_to_convert.get(block=False))
+      except Queue.Empty:
+        break
+    return list_to_return
+
+  def join(self, timeout=None):
+    """Joins all the processes with timeout.
 
     Args:
-      data: data to be made available in the queue returned by
-        `multi_process_runner.run()`.
+      timeout: if set and not all processes report status within roughly
+        `timeout` seconds, a `RuntimeError` exception will be thrown.
+
+    Returns:
+      It returns a tuple. The first element is a list that stores the return
+      data added by subprocesses through `_add_return_data` or through normal
+      function return; The second element is a list of the messages streamed to
+      stdout and stderr in the subprocesses if `capture_std_stream` is True or
+      `None` otherwise.
+
+    Raises:
+      RuntimeError: if not all processes report status within `timeout` seconds.
+      Or the exception propagated from any child process.
+    """
+    if not timeout:
+      if self._max_run_time:
+        timeout = self._max_run_time + 10  # add 10 seconds grace period
+      else:
+        timeout = float('inf')
+    num_returned = 0
+    start_time = time.time()
+    while num_returned < len(self._processes):
+      while True:
+        try:
+          process_status = self._get_process_status_queue().get(timeout=10)
+          break
+        except Queue.Empty:
+          if time.time() - start_time > timeout:
+            # If none of those did, report timeout to user.
+            raise RuntimeError(
+                'One or more subprocesses timed out. Please use '
+                '`--test_arg=--logtostderr` bazel flag to inspect logs for '
+                'subprocess debugging info. Number of returned processes is '
+                '%d.' % num_returned)
+
+      num_returned += 1
+      if isinstance(process_status, _ExcInfoWrapper):
+        six.reraise(*process_status.exc_info)
+      assert process_status == _FINISH_PROPERLY_MESSAGE
+
+    self._processes = []
+
+    if self._capture_std_stream:
+      # TODO(yuefengz): we need to make sure elements match the same process in
+      # the two returned lists so as to not surprise users. Consider creating a
+      # `ReturnData` class.
+      return tuple(
+          self._queue_to_list(multi_process_lib.get_user_data()[queue_name])
+          for queue_name in [RETURN_VALUE_QUEUE, STD_STREAM_QUEUE])
+    else:
+      return (self._queue_to_list(
+          multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE]), None)
+
+  def _add_return_data(self, data):
+    """Adds return data that will be returned by `join`.
+
+    The function provides a way for child processes to communicate with the
+    parent process. Data passed to `_add_return_data` will be available in a
+    Python Queue.Queue that is eventually returned by `join`.
+
+    Args:
+      data: data to be made available in the queue returned by `join`.
     """
     # TODO(rchao): Incorporate the task type and id information in a data
     # wrapper that becomes what is stored in the queue so we can tell where
     # the data is from.
-    multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE].put(data)
+    multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE].put(data)
 
   def _add_std_stream_data_flattened(self, data):
-    std_stream_queue = multi_process_lib.get_user_data()[
-        _AvailableQueues.STD_STREAM_QUEUE]
+    # TODO(yuefengz): currently the same queue is used by multiple processes. It
+    # is difficult for users to distinguish between logs from different
+    # processes.
+    std_stream_queue = multi_process_lib.get_user_data()[STD_STREAM_QUEUE]
     for d in list(data):
       std_stream_queue.put(d)
 
-  def _get_internal_queue(self):
-    return multi_process_lib.get_user_data()[_AvailableQueues.INTERNAL_QUEUE]
+  def _get_process_status_queue(self):
+    return multi_process_lib.get_user_data()[PROCESS_STATUS_QUEUE]
+
+
+def run(proc_func,
+        cluster_spec,
+        max_run_time=None,
+        capture_std_stream=False,
+        args=None,
+        kwargs=None):  # pylint: disable=g-doc-args
+  """Runs functions in local child processes.
+
+  It is a convenience method that creates a `MultiProcessRunner` object and
+  invokes `start` and `join` method. Please see these methods for detailed
+  documentations.
+
+  Returns:
+    A tuple returned from `MultiProcessRunner.join()`.
+  """
+  runner = MultiProcessRunner(
+      proc_func,
+      cluster_spec,
+      max_run_time=max_run_time,
+      capture_std_stream=capture_std_stream,
+      args=args,
+      kwargs=kwargs)
+  runner.start()
+  return runner.join()
 
 
 def test_main():
   """Main function to be called within `__main__` of a test file."""
   with multi_process_lib.context_manager():
     test.main()
-
-
-def job_count_to_cluster_spec(job_count_dict):
-  """Convert a job count dict to cluster spec.
-
-  Args:
-    job_count_dict: Dict for task_type/count of such task type.
-        {'worker': 1, 'ps': 1} is an example of a cluster with a worker and a
-          ps.
-
-  Returns:
-    The converted cluster spec dict.
-  """
-
-  cluster_spec = {}
-  for task_type, count in job_count_dict.items():
-    cluster_spec[task_type] = [
-        'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
-        for _ in range(count)
-    ]
-  return cluster_spec
diff --git a/tensorflow/python/distribute/multi_process_runner_no_init_test.py b/tensorflow/python/distribute/multi_process_runner_no_init_test.py
index c5820271d1a..475255d5e0a 100644
--- a/tensorflow/python/distribute/multi_process_runner_no_init_test.py
+++ b/tensorflow/python/distribute/multi_process_runner_no_init_test.py
@@ -19,23 +19,22 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner
+from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import test
 
 
 class MultiProcessRunnerNoInitTest(test.TestCase):
 
-  def test_stdout_captured(self):
+  def test_not_calling_correct_main(self):
 
     def simple_func():
       return 'foobar'
 
-    job_count_dict = {'worker': 1}
     with self.assertRaisesRegexp(RuntimeError,
                                  '`multi_process_runner` is not initialized.'):
-      MultiProcessRunner().run(
+      multi_process_runner.run(
           simple_func,
-          multi_process_runner.job_count_to_cluster_spec(job_count_dict))
+          multi_worker_test_base.create_cluster_spec(num_workers=1))
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py
index 98ca282b7b3..4144eb6f040 100644
--- a/tensorflow/python/distribute/multi_process_runner_test.py
+++ b/tensorflow/python/distribute/multi_process_runner_test.py
@@ -20,19 +20,15 @@ from __future__ import print_function
 
 import time
 
-from absl import flags
 from six.moves import queue as Queue
 
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
-from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner
 from tensorflow.python.eager import test
 
-flags.DEFINE_boolean(name='test_flag', default=0, help='Test flag')
 
-
-def proc_func_that_adds_task_type_in_return_data(test_obj):
-  test_obj.assertTrue(flags.FLAGS.test_flag == 3)
+def proc_func_that_adds_task_type_in_return_data(test_obj, val):
+  test_obj.assertEqual(val, 3)
   return multi_worker_test_base.get_task_type()
 
 
@@ -55,16 +51,13 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs):
 class MultiProcessRunnerTest(test.TestCase):
 
   def test_multi_process_runner(self):
-    job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 2}
-    proc_flags = {
-        'test_flag': 3,
-    }
-    returned_data = MultiProcessRunner().run(
+    returned_data, _ = multi_process_runner.run(
         proc_func_that_adds_task_type_in_return_data,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-        proc_flags=proc_flags,
-        args=(self,))
+        multi_worker_test_base.create_cluster_spec(
+            num_workers=2, num_ps=3, has_eval=1),
+        args=(self, 3))
 
+    job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
     for data in returned_data:
       job_count_dict[data] -= 1
 
@@ -73,31 +66,29 @@ class MultiProcessRunnerTest(test.TestCase):
     self.assertEqual(job_count_dict['evaluator'], 0)
 
   def test_multi_process_runner_error_propagates_from_subprocesses(self):
-    job_count_dict = {'worker': 1, 'ps': 1}
+    runner = multi_process_runner.MultiProcessRunner(
+        proc_func_that_errors,
+        multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
+        max_run_time=20)
+    runner.start()
     with self.assertRaisesRegexp(ValueError, 'This is an error.'):
-      MultiProcessRunner().run(
-          proc_func_that_errors,
-          multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-          timeout=20)
+      runner.join()
 
   def test_multi_process_runner_queue_emptied_between_runs(self):
-    job_count_dict = {'worker': 2}
-    cluster_spec = multi_process_runner.job_count_to_cluster_spec(
-        job_count_dict)
-    returned_data = MultiProcessRunner().run(
+    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
+    returned_data, _ = multi_process_runner.run(
         proc_func_that_adds_simple_return_data, cluster_spec)
     self.assertTrue(returned_data)
     self.assertEqual(returned_data[0], 'dummy_data')
     self.assertEqual(returned_data[1], 'dummy_data')
-    returned_data = MultiProcessRunner().run(proc_func_that_does_nothing,
-                                             cluster_spec)
+    returned_data, _ = multi_process_runner.run(proc_func_that_does_nothing,
+                                                cluster_spec)
     self.assertFalse(returned_data)
 
   def test_multi_process_runner_args_passed_correctly(self):
-    job_count_dict = {'worker': 1}
-    returned_data = MultiProcessRunner().run(
+    returned_data, _ = multi_process_runner.run(
         proc_func_that_return_args_and_kwargs,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
         args=('a', 'b'),
         kwargs={'c_k': 'c_v'})
     self.assertEqual(returned_data[0][0], 'a')
@@ -110,11 +101,10 @@ class MultiProcessRunnerTest(test.TestCase):
       print('This is something printed.')
       return 'This is returned data.'
 
-    job_count_dict = {'worker': 2}
-    returned_data, std_stream_data = MultiProcessRunner().run(
+    returned_data, std_stream_data = multi_process_runner.run(
         simple_print_func,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-        return_std_stream=True)
+        multi_worker_test_base.create_cluster_spec(num_workers=2),
+        capture_std_stream=True)
     num_string_std_stream = len(
         [d for d in std_stream_data if d == 'This is something printed.'])
     num_string_returned_data = len(
@@ -123,34 +113,32 @@ class MultiProcessRunnerTest(test.TestCase):
     self.assertEqual(num_string_returned_data, 2)
 
   def test_process_that_exits(self):
-
-    mpr = MultiProcessRunner()
-
     def func_to_exit_in_10_sec():
       time.sleep(5)
       mpr._add_return_data('foo')
       time.sleep(20)
       mpr._add_return_data('bar')
 
-    job_count_dict = {'worker': 1}
-    returned_data = mpr.run(
+    mpr = multi_process_runner.MultiProcessRunner(
         func_to_exit_in_10_sec,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-        time_to_exit=10)
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        max_run_time=10)
+
+    mpr.start()
+    returned_data, _ = mpr.join()
     self.assertLen(returned_data, 1)
 
   def test_signal_doesnt_fire_after_process_exits(self):
-    job_count_dict = {'worker': 1}
-    mpr = MultiProcessRunner()
-    mpr.run(
+    mpr = multi_process_runner.MultiProcessRunner(
         proc_func_that_does_nothing,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-        time_to_exit=10)
-    time.sleep(15)
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        max_run_time=10)
+    mpr.start()
+    mpr.join()
     with self.assertRaisesRegexp(Queue.Empty, ''):
       # If the signal was fired, another message would be added to internal
       # queue, so verifying it's empty.
-      mpr._get_internal_queue().get(block=False)
+      mpr._get_process_status_queue().get(block=False)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/distribute/multi_worker_continuous_run_test.py b/tensorflow/python/distribute/multi_worker_continuous_run_test.py
index 19790a0d69f..9668bc23351 100644
--- a/tensorflow/python/distribute/multi_worker_continuous_run_test.py
+++ b/tensorflow/python/distribute/multi_worker_continuous_run_test.py
@@ -78,7 +78,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
 
     # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
     with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.MultiProcessRunner().run(
+      multi_process_runner.run(
           worker_fn,
           cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))