def get_process_id(self, task_type, task_id):
  """Looks up the recorded pid of the subprocess running a given task.

  Every entry currently available on the subprocess info queue is taken off
  the queue, its pid is recorded in `self._pid_dict` under the key
  `(task_type, task_id)` of that entry, and the entry is then put back on
  the queue.

  Args:
    task_type: Task type of the subprocess to look up.
    task_id: Task id of the subprocess to look up.

  Returns:
    The pid recorded for `(task_type, task_id)`, or `None` when no entry for
    that pair has been recorded.
  """
  # The pid cache is created lazily the first time this method runs.
  if not hasattr(self, '_pid_dict'):
    self._pid_dict = {}

  # Non-blocking drain: stop as soon as the queue reports it is empty.
  drained = []
  while True:
    try:
      info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False)
    except Queue.Empty:
      break
    drained.append(info)

  # Later entries for the same (task_type, task_id) overwrite earlier ones.
  self._pid_dict.update(
      ((info.task_type, info.task_id), info.pid) for info in drained)

  # Put everything back in the order it was read.
  # NOTE(review): presumably so other readers of the queue still see these
  # entries -- confirm against the queue's other consumers.
  for info in drained:
    _resource(SUBPROCESS_INFO_QUEUE).put(info)

  return self._pid_dict.get((task_type, task_id))