diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index f0f3766afe1..74d80b63e12 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1794,7 +1794,7 @@ py_test(
     name = "multi_process_runner_test",
     srcs = ["multi_process_runner_test.py"],
     python_version = "PY3",
-    tags = ["notsan"],  # TODO(b/158874970)
+    shard_count = 12,
     deps = [
         ":multi_process_runner",
         ":multi_worker_test_base",
diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py
index af527b67b4b..cb460c8fff5 100644
--- a/tensorflow/python/distribute/multi_process_runner.py
+++ b/tensorflow/python/distribute/multi_process_runner.py
@@ -423,6 +423,18 @@ class MultiProcessRunner(object):
   def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
     """Joins all the processes with timeout.
 
+    If any of the subprocesses has not exited roughly `timeout` seconds
+    after `join` is called, this raises a
+    `SubprocessTimeoutError`.
+
+    Note: At timeout, it uses SIGTERM to terminate the subprocesses, so that
+    the stack traces of the subprocesses are logged when they exit. However,
+    this results in a timeout when the test runs with tsan (thread sanitizer).
+    If tsan is run on test targets that rely on the timeout to assert
+    information, `MultiProcessRunner.terminate_all()` must be called after
+    `join()` and before the test exits, so that the subprocesses are terminated
+    with SIGKILL and the data race is removed.
+
     Args:
       timeout: if set and not all processes report status within roughly
         `timeout` seconds, a `SubprocessTimeoutError` exception will be raised.
diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py
index 6194ac527d5..529d7fd91a5 100644
--- a/tensorflow/python/distribute/multi_process_runner_test.py
+++ b/tensorflow/python/distribute/multi_process_runner_test.py
@@ -124,24 +124,6 @@ class MultiProcessRunnerTest(test.TestCase):
                   std_stream_results)
     self.assertIn('This is returned data.', return_value)
 
-  def test_process_that_exits(self):
-
-    def func_to_exit_in_25_sec():
-      logging.error('foo')
-      time.sleep(100)
-      logging.error('bar')
-
-    mpr = multi_process_runner.MultiProcessRunner(
-        func_to_exit_in_25_sec,
-        multi_worker_test_base.create_cluster_spec(num_workers=1),
-        list_stdout=True,
-        max_run_time=25)
-
-    mpr.start()
-    stdout = mpr.join().stdout
-    self.assertLen([msg for msg in stdout if 'foo' in msg], 1)
-    self.assertLen([msg for msg in stdout if 'bar' in msg], 0)
-
   def test_termination(self):
 
     def proc_func():
@@ -301,29 +283,21 @@ class MultiProcessRunnerTest(test.TestCase):
   def test_stdout_available_when_timeout(self):
 
     def proc_func():
-      for i in range(50):
-        logging.info('(logging) %s-%d, i: %d',
-                     multi_worker_test_base.get_task_type(), self._worker_idx(),
-                     i)
-        time.sleep(1)
+      logging.info('something printed')
+      time.sleep(10000)  # Intentionally make the test timeout.
 
     with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm:
-      multi_process_runner.run(
+      mpr = multi_process_runner.MultiProcessRunner(
           proc_func,
-          multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
-          list_stdout=True,
-          timeout=5)
+          multi_worker_test_base.create_cluster_spec(num_workers=1),
+          list_stdout=True)
+      mpr.start()
+      mpr.join(timeout=60)
+    mpr.terminate_all()
 
     list_to_assert = cm.exception.mpr_result.stdout
-    # We should see 5 iterations from worker and ps, however sometime on TAP
-    # due to CPU throttling and slugginess of msan/asan build, this became
-    # flaky. Therefore we allow more margin of errors to only check the first
-    # 3 iterations.
-    for job in ['worker', 'ps']:
-      for iteration in range(0, 3):
-        self.assertTrue(
-            any('(logging) {}-0, i: {}'.format(job, iteration) in line
-                for line in list_to_assert))
+    self.assertTrue(
+        any('something printed' in line for line in list_to_assert))
 
   def test_seg_fault_raises_error(self):