PSv2: Check that there is at most one chief, at least one ps, and at least one worker. Combine the validation logic with multi_worker_util.
PiperOrigin-RevId: 341740027 Change-Id: I7e3125f8eaefb12c96f37b7fa3a54afbfc1e4334
This commit is contained in:
parent
229a8af422
commit
faf44b5391
tensorflow/python/distribute
@ -46,13 +46,21 @@ def normalize_cluster_spec(cluster_spec):
|
||||
return cluster_spec
|
||||
|
||||
|
||||
# TODO(yuefengz): add more validations.
|
||||
def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||
def task_count(cluster_spec, task_type):
  """Returns how many tasks of `task_type` the `cluster_spec` declares.

  Args:
    cluster_spec: an object exposing `num_tasks(job_name)`, e.g. a
      `ClusterSpec`.
    task_type: string naming the job to count, e.g. "chief", "ps" or "worker".

  Returns:
    The result of `cluster_spec.num_tasks(task_type)`, or 0 when that call
    raises `ValueError`.
  """
  try:
    return cluster_spec.num_tasks(task_type)
  except ValueError:
    # A `ValueError` from `num_tasks` is treated as "no such job", so callers
    # can compare the count without first checking that the job exists.
    return 0
|
||||
|
||||
|
||||
def _validate_cluster_spec(cluster_spec,
|
||||
task_type,
|
||||
task_id):
|
||||
"""Validates `cluster_spec`.
|
||||
|
||||
It checks:
|
||||
0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
|
||||
1) task type is one of "chief", "worker" or "evaluator".
|
||||
1) task type is one of "chief", "worker", "ps", "evaluator", or not provided
|
||||
(None).
|
||||
2) whether there is such a task type as `task_type` in the `cluster_spec`. The
|
||||
only exception is `evaluator`. In other words, it is still a valid
|
||||
configuration when `task_type` is `evaluator` but it doesn't appear in
|
||||
@ -65,31 +73,38 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||
Args:
|
||||
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
|
||||
task_type: string indicating the type of the task.
|
||||
task_id: task_id: the id of the `task_type` in this cluster.
|
||||
Throws:
|
||||
task_id: the id of the `task_type` in this cluster.
|
||||
|
||||
Raises:
|
||||
ValueError: if `cluster_spec` fails any check.
|
||||
"""
|
||||
if cluster_spec is None or task_type is None or task_id is None:
|
||||
raise ValueError(
|
||||
"None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")
|
||||
allowed_task_types = ("chief", "worker", "evaluator", "ps", None)
|
||||
|
||||
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
|
||||
if task_type not in ("chief", "worker", "evaluator", "ps"):
|
||||
raise ValueError(
|
||||
"Unrecognized task_type: %r, valid task types are: \"chief\", "
|
||||
"\"worker\", \"evaluator\" and \"ps\"." % task_type)
|
||||
cluster_spec = normalize_cluster_spec(cluster_spec)
|
||||
|
||||
if task_type and task_type not in cluster_spec and task_type != "evaluator":
|
||||
if any([job not in allowed_task_types for job in cluster_spec.jobs]):
|
||||
raise ValueError("Disallowed task type found in cluster spec. Allowed "
|
||||
"types are {} and the cluster spec is {}.".format(
|
||||
allowed_task_types, cluster_spec))
|
||||
|
||||
if task_type not in allowed_task_types:
|
||||
raise ValueError(
|
||||
"Unrecognized task_type: {}, valid task types are: {}".format(
|
||||
task_type, allowed_task_types))
|
||||
|
||||
if (task_type and task_type not in cluster_spec.jobs and
|
||||
task_type != "evaluator"):
|
||||
raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
|
||||
|
||||
if len(cluster_spec.get("chief", [])) > 1:
|
||||
if task_count(cluster_spec, "chief") > 1:
|
||||
raise ValueError("There must be at most one 'chief' job.")
|
||||
|
||||
if len(cluster_spec.get("evaluator", [])) > 1:
|
||||
if task_count(cluster_spec, "evaluator") > 1:
|
||||
raise ValueError("There must be at most one 'evaluator' job.")
|
||||
|
||||
# The `evaluator` job is allowed to be missing in `cluster_spec`.
|
||||
if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
|
||||
if task_type in cluster_spec.jobs and task_id >= task_count(
|
||||
cluster_spec, task_type):
|
||||
raise ValueError(
|
||||
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
|
||||
|
||||
|
@ -26,6 +26,7 @@ import os
|
||||
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribute_utils
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import parameter_server_strategy
|
||||
from tensorflow.python.distribute import sharded_variable
|
||||
from tensorflow.python.eager import remote
|
||||
@ -486,22 +487,19 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
|
||||
if self.extended._num_gpus_per_worker > 1: # pylint: disable=protected-access
|
||||
raise NotImplementedError("Multi-gpu is not supported yet.")
|
||||
|
||||
cluster_spec = cluster_resolver.cluster_spec()
|
||||
|
||||
# The following checks if the task types are allowed (chief, ps, worker).
|
||||
disallowed_task_type_error_str = (
|
||||
"Disallowed task type found in "
|
||||
"`tf.distribute.cluster_resolver.ClusterResolver` provided to "
|
||||
"`tf.distribute.experimental.ParameterServerStrategy`. Allowed types "
|
||||
"are {},".format(ALLOWED_TASK_TYPES))
|
||||
if any([
|
||||
job not in ALLOWED_TASK_TYPES
|
||||
for job in cluster_resolver.cluster_spec().jobs
|
||||
]):
|
||||
raise ValueError("{} and the cluster spec is {}.".format(
|
||||
disallowed_task_type_error_str, cluster_resolver.cluster_spec()))
|
||||
if (cluster_resolver.task_type and
|
||||
cluster_resolver.task_type not in ALLOWED_TASK_TYPES):
|
||||
raise ValueError("{} and current task type is {}.".format(
|
||||
disallowed_task_type_error_str, cluster_resolver.task_type))
|
||||
multi_worker_util._validate_cluster_spec( # pylint: disable=protected-access
|
||||
cluster_spec,
|
||||
cluster_resolver.task_type,
|
||||
cluster_resolver.task_id)
|
||||
|
||||
if multi_worker_util.task_count(cluster_spec, "ps") < 1:
|
||||
raise ValueError("There must be at least one ps.")
|
||||
|
||||
if multi_worker_util.task_count(cluster_spec, "worker") < 1:
|
||||
raise ValueError("There must be at least one worker.")
|
||||
|
||||
|
||||
class ParameterServerStrategyV2Extended(
|
||||
|
@ -412,7 +412,47 @@ class ClusterTypeNameTest(test.TestCase):
|
||||
]
|
||||
cluster_resolver = SimpleClusterResolver(
|
||||
ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
|
||||
with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"):
|
||||
with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
|
||||
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
||||
|
||||
def testMoreThanOneChief(self):
  """Building the strategy with three chief addresses raises ValueError."""
  cluster_def = multi_worker_test_base._create_cluster(
      num_workers=1, num_ps=1)
  # Register three chief addresses so the "at most one chief" check trips.
  chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
  cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def),
      rpc_layer="grpc",
      task_type="chief",
      task_id=1)
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias.
  with self.assertRaisesRegex(ValueError,
                              "There must be at most one 'chief' job."):
    parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
||||
|
||||
def testLessThanOneWorker(self):
  """Building the strategy with zero workers raises ValueError."""
  cluster_def = multi_worker_test_base._create_cluster(
      num_workers=0, num_ps=1)
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias.
  with self.assertRaisesRegex(ValueError,
                              "There must be at least one worker."):
    parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
||||
|
||||
def testLessThanOnePs(self):
  """Building the strategy with zero parameter servers raises ValueError."""
  cluster_def = multi_worker_test_base._create_cluster(
      num_workers=1, num_ps=0)
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def),
      rpc_layer="grpc",
      task_type="worker",
      task_id=0)
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias.
  with self.assertRaisesRegex(ValueError, "There must be at least one ps."):
    parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user