Fix some linter errors for slurm_cluster_resolver.
PiperOrigin-RevId: 313873815
Change-Id: I15ae65bb27af2ee9d60b3629c91c0234fbc8943f
This commit is contained in:
parent
bacd18849e
commit
3fbd5ac42e
@ -19,8 +19,8 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import re
|
import re
|
||||||
|
import subprocess
|
||||||
|
|
||||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
|
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
|
||||||
@ -29,7 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
|
|
||||||
|
|
||||||
def expand_hostlist(hostlist):
|
def expand_hostlist(hostlist):
|
||||||
"""Create a list of hosts out of a SLURM hostlist
|
"""Create a list of hosts out of a SLURM hostlist.
|
||||||
|
|
||||||
The order of nodes is preserved and no deduplication is done
|
The order of nodes is preserved and no deduplication is done
|
||||||
Input: 'n[1-2],m5,o[3-4,6,7-9]')
|
Input: 'n[1-2],m5,o[3-4,6,7-9]')
|
||||||
@ -37,7 +37,7 @@ def expand_hostlist(hostlist):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def split_hostlist(hostlist):
|
def split_hostlist(hostlist):
|
||||||
"""Split hostlist at commas outside of range expressions ('[3-5]')"""
|
"""Split hostlist at commas outside of range expressions ('[3-5]')."""
|
||||||
in_brackets = False
|
in_brackets = False
|
||||||
cur_host = ''
|
cur_host = ''
|
||||||
for c in hostlist:
|
for c in hostlist:
|
||||||
@ -57,7 +57,7 @@ def expand_hostlist(hostlist):
|
|||||||
yield cur_host
|
yield cur_host
|
||||||
|
|
||||||
def expand_range_expression(range_exp):
|
def expand_range_expression(range_exp):
|
||||||
"""Expand a range expression like '3-5' to values 3,4,5"""
|
"""Expand a range expression like '3-5' to values 3,4,5."""
|
||||||
for part in range_exp.split(','):
|
for part in range_exp.split(','):
|
||||||
sub_range = part.split('-')
|
sub_range = part.split('-')
|
||||||
if len(sub_range) == 1:
|
if len(sub_range) == 1:
|
||||||
@ -87,7 +87,7 @@ def expand_hostlist(hostlist):
|
|||||||
|
|
||||||
|
|
||||||
def expand_tasks_per_node(tasks_per_node):
|
def expand_tasks_per_node(tasks_per_node):
|
||||||
"""Expand the tasks per node expression from SLURM
|
"""Expands the tasks per node expression from SLURM.
|
||||||
|
|
||||||
The order is preserved so it can be matched to the hostlist
|
The order is preserved so it can be matched to the hostlist
|
||||||
Input: '3(x2),2,1'
|
Input: '3(x2),2,1'
|
||||||
@ -108,7 +108,7 @@ def expand_tasks_per_node(tasks_per_node):
|
|||||||
|
|
||||||
|
|
||||||
def _get_slurm_var(name):
|
def _get_slurm_var(name):
|
||||||
"""Get the SLURM variable from the environment
|
"""Gets the SLURM variable from the environment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Name of the step variable
|
name: Name of the step variable
|
||||||
@ -126,8 +126,8 @@ def _get_slurm_var(name):
|
|||||||
'Not running inside a SLURM step?' % name)
|
'Not running inside a SLURM step?' % name)
|
||||||
|
|
||||||
|
|
||||||
def get_num_slurm_tasks():
|
def _get_num_slurm_tasks():
|
||||||
"""Return the number of SLURM tasks of the current job step
|
"""Returns the number of SLURM tasks of the current job step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The number of tasks as an int
|
The number of tasks as an int
|
||||||
@ -136,7 +136,7 @@ def get_num_slurm_tasks():
|
|||||||
|
|
||||||
|
|
||||||
def _get_num_nvidia_gpus():
|
def _get_num_nvidia_gpus():
|
||||||
"""Get the number of NVIDIA GPUs by using CUDA_VISIBLE_DEVICES and nvidia-smi
|
"""Gets the number of NVIDIA GPUs by using CUDA_VISIBLE_DEVICES and nvidia-smi.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of GPUs available on the node
|
Number of GPUs available on the node
|
||||||
@ -157,9 +157,9 @@ def _get_num_nvidia_gpus():
|
|||||||
|
|
||||||
|
|
||||||
def get_num_gpus():
|
def get_num_gpus():
|
||||||
"""Return the number of GPUs visible on the current node
|
"""Returns the number of GPUs visible on the current node.
|
||||||
|
|
||||||
Currently only implemented for NVIDIA GPUs
|
Currently only implemented for NVIDIA GPUs.
|
||||||
"""
|
"""
|
||||||
return _get_num_nvidia_gpus()
|
return _get_num_nvidia_gpus()
|
||||||
|
|
||||||
@ -176,7 +176,6 @@ class SlurmClusterResolver(ClusterResolver):
|
|||||||
used for distributed TensorFlow.
|
used for distributed TensorFlow.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
jobs=None,
|
jobs=None,
|
||||||
port_base=8888,
|
port_base=8888,
|
||||||
@ -276,19 +275,19 @@ class SlurmClusterResolver(ClusterResolver):
|
|||||||
sum(self._jobs.values()), num_tasks))
|
sum(self._jobs.values()), num_tasks))
|
||||||
|
|
||||||
def _resolve_own_rank(self):
|
def _resolve_own_rank(self):
|
||||||
"""Return the rank of the current task in range [0, num_tasks)"""
|
"""Returns the rank of the current task in range [0, num_tasks)."""
|
||||||
return int(_get_slurm_var('PROCID'))
|
return int(_get_slurm_var('PROCID'))
|
||||||
|
|
||||||
def _resolve_num_tasks(self):
|
def _resolve_num_tasks(self):
|
||||||
"""Return the number of tasks for the current job step"""
|
"""Returns the number of tasks for the current job step."""
|
||||||
return get_num_slurm_tasks()
|
return _get_num_slurm_tasks()
|
||||||
|
|
||||||
def _resolve_hostlist(self):
|
def _resolve_hostlist(self):
|
||||||
"""Return a list of hostnames for nodes running the current job step"""
|
"""Returns a list of hostnames for nodes running the current job step."""
|
||||||
return expand_hostlist(_get_slurm_var('STEP_NODELIST'))
|
return expand_hostlist(_get_slurm_var('STEP_NODELIST'))
|
||||||
|
|
||||||
def _resolve_task_configuration(self):
|
def _resolve_task_configuration(self):
|
||||||
"""Create a mapping of hostnames to the number of tasks allocated on it
|
"""Creates a mapping of hostnames to the number of tasks allocated on it.
|
||||||
|
|
||||||
Reads the SLURM environment to determine the nodes involved in the current
|
Reads the SLURM environment to determine the nodes involved in the current
|
||||||
job step and number of tasks running on each node.
|
job step and number of tasks running on each node.
|
||||||
@ -352,7 +351,7 @@ class SlurmClusterResolver(ClusterResolver):
|
|||||||
|
|
||||||
cluster_rank_offset_start = cluster_rank_offset_end
|
cluster_rank_offset_start = cluster_rank_offset_end
|
||||||
|
|
||||||
if self._auto_set_gpu is True:
|
if self._auto_set_gpu:
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank]
|
os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank]
|
||||||
|
|
||||||
return ClusterSpec(self._cluster_allocation)
|
return ClusterSpec(self._cluster_allocation)
|
||||||
|
Loading…
Reference in New Issue
Block a user