PSv2: Check that there is no more than one chief, and at least one ps/worker. Combine the validation logic with multi_worker_util.

PiperOrigin-RevId: 341740027
Change-Id: I7e3125f8eaefb12c96f37b7fa3a54afbfc1e4334
Rick Chao 2020-11-10 18:31:02 -08:00 committed by TensorFlower Gardener
parent 229a8af422
commit faf44b5391
3 changed files with 87 additions and 34 deletions
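
For context, a minimal sketch (illustrative only, with placeholder host:port addresses) of a cluster spec that satisfies the constraints this change enforces: at most one chief, at least one worker, and at least one ps.

    # Illustrative only; the addresses are placeholders.
    import tensorflow as tf

    cluster_spec = tf.train.ClusterSpec({
        "chief": ["host0:2222"],                 # at most one chief
        "worker": ["host1:2222", "host2:2222"],  # at least one worker
        "ps": ["host3:2222"],                    # at least one ps
    })

    cluster_spec.jobs              # job names present in the spec
    cluster_spec.num_tasks("ps")   # 1; raises ValueError for a job not in the spec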

tensorflow/python/distribute/multi_worker_util.py

@@ -46,13 +46,21 @@ def normalize_cluster_spec(cluster_spec):
   return cluster_spec
 
 
 # TODO(yuefengz): add more validations.
-def _validate_cluster_spec(cluster_spec, task_type, task_id):
+def task_count(cluster_spec, task_type):
+  try:
+    return cluster_spec.num_tasks(task_type)
+  except ValueError:
+    return 0
+
+
+def _validate_cluster_spec(cluster_spec,
+                           task_type,
+                           task_id):
   """Validates `cluster_spec`.
 
   It checks:
+  0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
-  1) task type is one of "chief", "worker" or "evaluator".
+  1) task type is one of "chief", "worker", "ps", "evaluator", or not provided
+     (None).
   2) whether there is such a task type as `task_type` in the `cluster_spec`. The
      only exception is `evaluator`. In other words, it is still a valid
      configuration when `task_type` is `evaluator` but it doesn't appear in
@@ -65,31 +73,38 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
   Args:
     cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
     task_type: string indicating the type of the task.
-    task_id: task_id: the id of the `task_type` in this cluster.
+    task_id: the id of the `task_type` in this cluster.
 
-  Throws:
+  Raises:
     ValueError: if `cluster_spec` fails any check.
   """
+  if cluster_spec is None or task_type is None or task_id is None:
+    raise ValueError(
+        "None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")
+
+  allowed_task_types = ("chief", "worker", "evaluator", "ps", None)
+
-  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
-  if task_type not in ("chief", "worker", "evaluator", "ps"):
-    raise ValueError(
-        "Unrecognized task_type: %r, valid task types are: \"chief\", "
-        "\"worker\", \"evaluator\" and \"ps\"." % task_type)
+  cluster_spec = normalize_cluster_spec(cluster_spec)
 
-  if task_type and task_type not in cluster_spec and task_type != "evaluator":
+  if any([job not in allowed_task_types for job in cluster_spec.jobs]):
+    raise ValueError("Disallowed task type found in cluster spec. Allowed "
+                     "types are {} and the cluster spec is {}.".format(
+                         allowed_task_types, cluster_spec))
+
+  if task_type not in allowed_task_types:
+    raise ValueError(
+        "Unrecognized task_type: {}, valid task types are: {}".format(
+            task_type, allowed_task_types))
+
+  if (task_type and task_type not in cluster_spec.jobs and
+      task_type != "evaluator"):
     raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
 
-  if len(cluster_spec.get("chief", [])) > 1:
+  if task_count(cluster_spec, "chief") > 1:
     raise ValueError("There must be at most one 'chief' job.")
-  if len(cluster_spec.get("evaluator", [])) > 1:
+  if task_count(cluster_spec, "evaluator") > 1:
     raise ValueError("There must be at most one 'evaluator' job.")
 
   # The `evaluator` job is allowed to be missing in `cluster_spec`.
-  if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
+  if task_type in cluster_spec.jobs and task_id >= task_count(
+      cluster_spec, task_type):
     raise ValueError(
         "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))

tensorflow/python/distribute/parameter_server_strategy_v2.py

@@ -26,6 +26,7 @@ import os
 
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.eager import remote
@@ -486,22 +487,19 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
     if self.extended._num_gpus_per_worker > 1:  # pylint: disable=protected-access
       raise NotImplementedError("Multi-gpu is not supported yet.")
 
+    cluster_spec = cluster_resolver.cluster_spec()
+
     # The following checks if the task types are allowed (chief, ps, worker).
-    disallowed_task_type_error_str = (
-        "Disallowed task type found in "
-        "`tf.distribute.cluster_resolver.ClusterResolver` provided to "
-        "`tf.distribute.experimental.ParameterServerStrategy`. Allowed types "
-        "are {},".format(ALLOWED_TASK_TYPES))
-    if any([
-        job not in ALLOWED_TASK_TYPES
-        for job in cluster_resolver.cluster_spec().jobs
-    ]):
-      raise ValueError("{} and the cluster spec is {}.".format(
-          disallowed_task_type_error_str, cluster_resolver.cluster_spec()))
-    if (cluster_resolver.task_type and
-        cluster_resolver.task_type not in ALLOWED_TASK_TYPES):
-      raise ValueError("{} and current task type is {}.".format(
-          disallowed_task_type_error_str, cluster_resolver.task_type))
+    multi_worker_util._validate_cluster_spec(  # pylint: disable=protected-access
+        cluster_spec,
+        cluster_resolver.task_type,
+        cluster_resolver.task_id)
+
+    if multi_worker_util.task_count(cluster_spec, "ps") < 1:
+      raise ValueError("There must be at least one ps.")
+
+    if multi_worker_util.task_count(cluster_spec, "worker") < 1:
+      raise ValueError("There must be at least one worker.")
 
 
 class ParameterServerStrategyV2Extended(
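
A hedged end-to-end sketch of the behavior these checks give (illustrative; the strategy module path is TensorFlow-internal, the addresses are placeholders, and no servers are running): a resolver whose spec has no ps job now fails fast with the new error message when the strategy is constructed, before any attempt to connect to the cluster.

    # Illustrative sketch only.
    import tensorflow as tf
    from tensorflow.python.distribute import parameter_server_strategy_v2

    cluster_spec = tf.train.ClusterSpec({
        "chief": ["localhost:2222"],
        "worker": ["localhost:2223"],
        # No "ps" job.
    })
    resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
        cluster_spec, rpc_layer="grpc", task_type="chief", task_id=0)

    try:
      parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)
    except ValueError as e:
      print(e)  # There must be at least one ps.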

tensorflow/python/distribute/parameter_server_strategy_v2_test.py

@@ -412,7 +412,47 @@ class ClusterTypeNameTest(test.TestCase):
     ]
     cluster_resolver = SimpleClusterResolver(
         ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
-    with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"):
+    with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
       parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
 
+  def testMoreThanOneChief(self):
+    cluster_def = multi_worker_test_base._create_cluster(
+        num_workers=1, num_ps=1)
+    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
+    cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
+    cluster_resolver = SimpleClusterResolver(
+        ClusterSpec(cluster_def),
+        rpc_layer="grpc",
+        task_type="chief",
+        task_id=1)
+    with self.assertRaisesRegexp(ValueError,
+                                 "There must be at most one 'chief' job."):
+      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
+
+  def testLessThanOneWorker(self):
+    cluster_def = multi_worker_test_base._create_cluster(
+        num_workers=0, num_ps=1)
+    cluster_def["chief"] = [
+        "localhost:%d" % multi_worker_test_base.pick_unused_port()
+    ]
+    cluster_resolver = SimpleClusterResolver(
+        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
+    with self.assertRaisesRegexp(ValueError,
+                                 "There must be at least one worker."):
+      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
+
+  def testLessThanOnePs(self):
+    cluster_def = multi_worker_test_base._create_cluster(
+        num_workers=1, num_ps=0)
+    cluster_def["chief"] = [
+        "localhost:%d" % multi_worker_test_base.pick_unused_port()
+    ]
+    cluster_resolver = SimpleClusterResolver(
+        ClusterSpec(cluster_def),
+        rpc_layer="grpc",
+        task_type="worker",
+        task_id=0)
+    with self.assertRaisesRegexp(ValueError, "There must be at least one ps."):
+      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)