Fix some linter errors for slurm_cluster_resolver.
PiperOrigin-RevId: 313873815
Change-Id: I15ae65bb27af2ee9d60b3629c91c0234fbc8943f
This commit is contained in:
parent
bacd18849e
commit
3fbd5ac42e
@ -19,8 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
|
||||
@ -29,7 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def expand_hostlist(hostlist):
|
||||
"""Create a list of hosts out of a SLURM hostlist
|
||||
"""Create a list of hosts out of a SLURM hostlist.
|
||||
|
||||
The order of nodes is preserved and no deduplication is done
|
||||
Input: 'n[1-2],m5,o[3-4,6,7-9]')
|
||||
@ -37,7 +37,7 @@ def expand_hostlist(hostlist):
|
||||
"""
|
||||
|
||||
def split_hostlist(hostlist):
|
||||
"""Split hostlist at commas outside of range expressions ('[3-5]')"""
|
||||
"""Split hostlist at commas outside of range expressions ('[3-5]')."""
|
||||
in_brackets = False
|
||||
cur_host = ''
|
||||
for c in hostlist:
|
||||
@ -57,7 +57,7 @@ def expand_hostlist(hostlist):
|
||||
yield cur_host
|
||||
|
||||
def expand_range_expression(range_exp):
|
||||
"""Expand a range expression like '3-5' to values 3,4,5"""
|
||||
"""Expand a range expression like '3-5' to values 3,4,5."""
|
||||
for part in range_exp.split(','):
|
||||
sub_range = part.split('-')
|
||||
if len(sub_range) == 1:
|
||||
@ -87,7 +87,7 @@ def expand_hostlist(hostlist):
|
||||
|
||||
|
||||
def expand_tasks_per_node(tasks_per_node):
|
||||
"""Expand the tasks per node expression from SLURM
|
||||
"""Expands the tasks per node expression from SLURM.
|
||||
|
||||
The order is preserved so it can be matched to the hostlist
|
||||
Input: '3(x2),2,1'
|
||||
@ -108,7 +108,7 @@ def expand_tasks_per_node(tasks_per_node):
|
||||
|
||||
|
||||
def _get_slurm_var(name):
|
||||
"""Get the SLURM variable from the environment
|
||||
"""Gets the SLURM variable from the environment.
|
||||
|
||||
Args:
|
||||
name: Name of the step variable
|
||||
@ -126,8 +126,8 @@ def _get_slurm_var(name):
|
||||
'Not running inside a SLURM step?' % name)
|
||||
|
||||
|
||||
def get_num_slurm_tasks():
|
||||
"""Return the number of SLURM tasks of the current job step
|
||||
def _get_num_slurm_tasks():
|
||||
"""Returns the number of SLURM tasks of the current job step.
|
||||
|
||||
Returns:
|
||||
The number of tasks as an int
|
||||
@ -136,7 +136,7 @@ def get_num_slurm_tasks():
|
||||
|
||||
|
||||
def _get_num_nvidia_gpus():
|
||||
"""Get the number of NVIDIA GPUs by using CUDA_VISIBLE_DEVICES and nvidia-smi
|
||||
"""Gets the number of NVIDIA GPUs by using CUDA_VISIBLE_DEVICES and nvidia-smi.
|
||||
|
||||
Returns:
|
||||
Number of GPUs available on the node
|
||||
@ -157,9 +157,9 @@ def _get_num_nvidia_gpus():
|
||||
|
||||
|
||||
def get_num_gpus():
|
||||
"""Return the number of GPUs visible on the current node
|
||||
"""Returns the number of GPUs visible on the current node.
|
||||
|
||||
Currently only implemented for NVIDIA GPUs
|
||||
Currently only implemented for NVIDIA GPUs.
|
||||
"""
|
||||
return _get_num_nvidia_gpus()
|
||||
|
||||
@ -176,7 +176,6 @@ class SlurmClusterResolver(ClusterResolver):
|
||||
used for distributed TensorFlow.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self,
|
||||
jobs=None,
|
||||
port_base=8888,
|
||||
@ -276,19 +275,19 @@ class SlurmClusterResolver(ClusterResolver):
|
||||
sum(self._jobs.values()), num_tasks))
|
||||
|
||||
def _resolve_own_rank(self):
|
||||
"""Return the rank of the current task in range [0, num_tasks)"""
|
||||
"""Returns the rank of the current task in range [0, num_tasks)."""
|
||||
return int(_get_slurm_var('PROCID'))
|
||||
|
||||
def _resolve_num_tasks(self):
|
||||
"""Return the number of tasks for the current job step"""
|
||||
return get_num_slurm_tasks()
|
||||
"""Returns the number of tasks for the current job step."""
|
||||
return _get_num_slurm_tasks()
|
||||
|
||||
def _resolve_hostlist(self):
|
||||
"""Return a list of hostnames for nodes running the current job step"""
|
||||
"""Returns a list of hostnames for nodes running the current job step."""
|
||||
return expand_hostlist(_get_slurm_var('STEP_NODELIST'))
|
||||
|
||||
def _resolve_task_configuration(self):
|
||||
"""Create a mapping of hostnames to the number of tasks allocated on it
|
||||
"""Creates a mapping of hostnames to the number of tasks allocated on it.
|
||||
|
||||
Reads the SLURM environment to determine the nodes involved in the current
|
||||
job step and number of tasks running on each node.
|
||||
@ -352,7 +351,7 @@ class SlurmClusterResolver(ClusterResolver):
|
||||
|
||||
cluster_rank_offset_start = cluster_rank_offset_end
|
||||
|
||||
if self._auto_set_gpu is True:
|
||||
if self._auto_set_gpu:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank]
|
||||
|
||||
return ClusterSpec(self._cluster_allocation)
|
||||
|
Loading…
Reference in New Issue
Block a user