MultiProcessRunner: provide APIs for getting subprocess' ID.

PiperOrigin-RevId: 314436570
Change-Id: Ic2b0ccfcf660a731ce6739fb5af283d2acbb7917
This commit is contained in:
Rick Chao 2020-06-02 17:27:48 -07:00 committed by TensorFlower Gardener
parent 81e3084012
commit e06251b493

View File

@ -377,6 +377,28 @@ class MultiProcessRunner(object):
break
return list_to_return
def get_process_id(self, task_type, task_id):
"""Returns the subprocess id given the task type and task id."""
if not hasattr(self, '_pid_dict'):
self._pid_dict = {}
subprocess_infos = []
while True:
try:
subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False)
subprocess_infos.append(subprocess_info)
except Queue.Empty:
break
for subprocess_info in subprocess_infos:
self._pid_dict[(subprocess_info.task_type,
subprocess_info.task_id)] = subprocess_info.pid
for subprocess_info in subprocess_infos:
_resource(SUBPROCESS_INFO_QUEUE).put(subprocess_info)
return self._pid_dict.get((task_type, task_id), None)
def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
"""Joins all the processes with timeout.