# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""Utilities for multi-worker distribution strategies."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.core.protobuf import cluster_pb2
|
|
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
|
from tensorflow.python.training import server_lib
|
|
|
|
|
|
def normalize_cluster_spec(cluster_spec):
  """Converts `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object. An input that already is a `ClusterSpec` is returned
    unchanged; a dict or `ClusterDef` is wrapped in a new `ClusterSpec`.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, server_lib.ClusterSpec):
    # Already in normalized form: hand back the very same object.
    return cluster_spec
  if not isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    raise ValueError(
        "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")
  return server_lib.ClusterSpec(cluster_spec)
|
|
|
|
|
|
# TODO(yuefengz): add more validations.
def _validate_cluster_spec(cluster_spec, task_type, task_id):
  """Validates `cluster_spec`.

  It checks:
  0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
  1) task type is one of "chief", "worker", "evaluator" or "ps".
  2) whether there is such a task type as `task_type` in the `cluster_spec`. The
     only exception is `evaluator`. In other words, it is still a valid
     configuration when `task_type` is `evaluator` but it doesn't appear in
     `cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator.
  3) whether there is at most one "chief" task.
  4) whether there is at most one "evaluator" task.
  5) whether `task_id` is non-negative.
  6) whether the `task_id` is smaller than the number of tasks for that
     particular `task_type` (only checked when `task_type` appears in
     `cluster_spec`).

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  if cluster_spec is None or task_type is None or task_id is None:
    raise ValueError(
        "None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")

  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
  if task_type not in ("chief", "worker", "evaluator", "ps"):
    raise ValueError(
        "Unrecognized task_type: %r, valid task types are: \"chief\", "
        "\"worker\", \"evaluator\" and \"ps\"." % task_type)

  # `task_type` is one of the known non-empty names at this point; only
  # "evaluator" is allowed to be absent from `cluster_spec`.
  if task_type not in cluster_spec and task_type != "evaluator":
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)

  if len(cluster_spec.get("chief", [])) > 1:
    raise ValueError("There must be at most one 'chief' job.")

  if len(cluster_spec.get("evaluator", [])) > 1:
    raise ValueError("There must be at most one 'evaluator' job.")

  # A negative id is never a valid task index. Without this check it would
  # pass validation and, e.g., make `id_in_cluster` return a negative id.
  if task_id < 0:
    raise ValueError("The `task_id` %d must be non-negative." % task_id)

  # The `evaluator` job is allowed to be missing in `cluster_spec`.
  if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
    raise ValueError(
        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
|
|
|
|
|
|
def is_chief(cluster_spec=None, task_type=None, task_id=None):
  """Returns whether the given task is the chief of its cluster.

  The evaluator always counts as a chief: there is at most one evaluator and it
  is meant to be independent of the training cluster, so it is the chief of its
  own cluster.

  When called inside a `_WorkerContext` of the distribute coordinator, the
  arguments may be omitted because the context already holds the answer.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
      the maximum id of the `task_type`.
  """
  if has_worker_context():
    # The distribute coordinator has already decided; defer to its answer.
    return dc_context.get_current_worker_context().is_chief

  _validate_cluster_spec(cluster_spec, task_type, task_id)
  jobs = normalize_cluster_spec(cluster_spec).as_dict()

  # "chief" and "evaluator" tasks are always chiefs of their own cluster.
  if task_type in ("chief", "evaluator"):
    return True

  # Without an explicit "chief" job, the first worker acts as chief. This is
  # common in CollectiveAllReduceStrategy.
  return bool(task_type == "worker" and task_id == 0 and "chief" not in jobs)
|
|
|
|
|
|
def collective_leader(cluster_spec, task_type, task_id):
  """Returns the device string of the leader task for collective ops.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a string naming the leader task, either "/job:chief/replica:0/task:0" or
    "/job:worker/replica:0/task:0". It is an empty string when no leader needs
    to be set: for an empty (local) cluster, or for an "evaluator" task.

  Raises:
    ValueError: if the arguments fail cluster-spec validation, or if the
      cluster has neither a "chief" nor a "worker" job to lead collectives.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # No need to set collective leader for local.
  if not cluster_spec.as_dict():
    return ""

  _validate_cluster_spec(cluster_spec, task_type, task_id)

  # Only one evaluator, so no need to set collective leader.
  if task_type == "evaluator":
    return ""

  # Use chief if chief is in the cluster.
  if "chief" in cluster_spec.jobs:
    return "/job:chief/replica:0/task:0"

  # Otherwise use worker 0. Validation passes for e.g. a "ps"-only cluster, so
  # this must be a real check: an `assert` is stripped under `python -O` and
  # would then yield a leader that does not exist.
  if "worker" not in cluster_spec.jobs:
    raise ValueError(
        "Cannot determine the collective leader: `cluster_spec` has neither a "
        "'chief' nor a 'worker' job.")
  return "/job:worker/replica:0/task:0"
|
|
|
|
|
|
def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster.

  For a "chief" or "worker" task this is the total number of "chief" and
  "worker" tasks, since the chief is also a worker. For an "evaluator" task it
  is the number of "evaluator" tasks, because the evaluator is assumed to run
  in its own cluster or its own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.

  Returns:
    an int, the number of workers relevant to `task_type`.

  Raises:
    ValueError: if the arguments fail cluster-spec validation, if `task_type`
      is not "chief", "worker" or "evaluator", or if `task_type` is
      "evaluator" but `cluster_spec` has no "evaluator" job.
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  # Other jobs such as "ps" shouldn't call this function.
  if task_type not in ["chief", "worker", "evaluator"]:
    raise ValueError("Unexpected `task_type` %r" % task_type)

  if task_type == "evaluator":
    # Validation accepts an "evaluator" task that is missing from
    # `cluster_spec`; report that clearly instead of failing with a KeyError.
    if "evaluator" not in cluster_spec:
      raise ValueError(
          "Cannot count workers for task_type 'evaluator': no 'evaluator' job "
          "found in cluster_spec.")
    # The "evaluator" is in its own cluster or its own partition of a cluster,
    # so "chief" and "worker" tasks are not counted.
    return len(cluster_spec["evaluator"])

  # In the non-evaluator case, count both "chief" and "worker" tasks as the
  # "chief" is also a worker.
  return (len(cluster_spec.get("chief", [])) +
          len(cluster_spec.get("worker", [])))
|
|
|
|
|
|
def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  The id ranges over [0, `worker_count(cluster_spec, task_type)`):
  - the "chief" task (at most one) always gets id 0;
  - "worker" tasks are numbered after the chief, if there is one;
  - an "evaluator" task keeps its own `task_id`.

  Note: this function assumes that the "evaluator" job is in its own cluster
  or its own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if the arguments fail cluster-spec validation, or if
      `task_type` is not "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  jobs = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type == "chief":
    # There is at most one chief and it always comes first.
    return 0

  if task_type == "worker":
    # Workers are numbered right after the (optional) chief.
    num_chiefs = len(jobs.get("chief", []))
    return num_chiefs + task_id

  if task_type == "evaluator":
    # The evaluator lives in its own cluster/partition, so its id is unchanged.
    return task_id

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)
|
|
|
|
|
|
def should_save_checkpoint():
  """Returns whether the current worker should save checkpoints.

  In multi-worker training, when checkpoint saving is requested by the user or
  needed for fault tolerance, the cluster as a whole saves checkpoints, but not
  necessarily every worker in it does.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as there
  can be other files to save such as summary.

  Returns:
    Whether this particular worker in the cluster should save checkpoints, as
    reported by the current worker context's `should_checkpoint`.
  """
  # NOTE(review): assumes a worker context is active. With none, the context
  # is None and the attribute access below raises AttributeError.
  worker_context = dc_context.get_current_worker_context()
  return worker_context.should_checkpoint
|
|
|
|
|
|
def should_load_checkpoint():
  """Returns whether the current worker should load checkpoints.

  In multi-worker training, when checkpoint loading is requested by the user or
  needed for fault tolerance, the cluster as a whole loads checkpoints, but not
  necessarily every worker in it does.

  Returns:
    Whether this particular worker in the cluster should load checkpoints, as
    reported by the current worker context's `experimental_should_init`.
  """
  # NOTE(review): assumes a worker context is active. With none, the context
  # is None and the attribute access below raises AttributeError.
  worker_context = dc_context.get_current_worker_context()
  return worker_context.experimental_should_init
|
|
|
|
|
|
def wait_for_other_workers():
  """Blocks until the other workers reach the same call to this method.

  Returns:
    Whatever the current worker context's `wait_for_other_workers()` returns.
  """
  # NOTE(review): assumes a worker context is active; with none this raises
  # AttributeError on the None context.
  worker_context = dc_context.get_current_worker_context()
  return worker_context.wait_for_other_workers()
|
|
|
|
|
|
def has_worker_context():
  """Returns True if a distribute-coordinator worker context has been entered.

  Returns:
    True when `dc_context.get_current_worker_context()` returns something other
    than None, False otherwise.
  """
  current_context = dc_context.get_current_worker_context()
  return current_context is not None
|