PSv2: Check that there is no more than one chief, at least one ps, and at least one worker. Combine the validation logic with multi_worker_util.
PiperOrigin-RevId: 341740027 Change-Id: I7e3125f8eaefb12c96f37b7fa3a54afbfc1e4334
This commit is contained in:
parent
229a8af422
commit
faf44b5391
tensorflow/python/distribute
@ -46,13 +46,21 @@ def normalize_cluster_spec(cluster_spec):
|
|||||||
return cluster_spec
|
return cluster_spec
|
||||||
|
|
||||||
|
|
||||||
# TODO(yuefengz): add more validations.
|
def task_count(cluster_spec, task_type):
|
||||||
def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
try:
|
||||||
|
return cluster_spec.num_tasks(task_type)
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_cluster_spec(cluster_spec,
|
||||||
|
task_type,
|
||||||
|
task_id):
|
||||||
"""Validates `cluster_spec`.
|
"""Validates `cluster_spec`.
|
||||||
|
|
||||||
It checks:
|
It checks:
|
||||||
0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
|
1) task type is one of "chief", "worker", "ps", "evaluator", or not provided
|
||||||
1) task type is one of "chief", "worker" or "evaluator".
|
(None).
|
||||||
2) whether there is such a task type as `task_type` in the `cluster_spec`. The
|
2) whether there is such a task type as `task_type` in the `cluster_spec`. The
|
||||||
only exception is `evaluator`. In other words, it is still a valid
|
only exception is `evaluator`. In other words, it is still a valid
|
||||||
configuration when `task_type` is `evaluator` but it doesn't appear in
|
configuration when `task_type` is `evaluator` but it doesn't appear in
|
||||||
@ -65,31 +73,38 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
|||||||
Args:
|
Args:
|
||||||
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
|
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
|
||||||
task_type: string indicating the type of the task.
|
task_type: string indicating the type of the task.
|
||||||
task_id: task_id: the id of the `task_type` in this cluster.
|
task_id: the id of the `task_type` in this cluster.
|
||||||
Throws:
|
|
||||||
|
Raises:
|
||||||
ValueError: if `cluster_spec` fails any check.
|
ValueError: if `cluster_spec` fails any check.
|
||||||
"""
|
"""
|
||||||
if cluster_spec is None or task_type is None or task_id is None:
|
allowed_task_types = ("chief", "worker", "evaluator", "ps", None)
|
||||||
raise ValueError(
|
|
||||||
"None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")
|
|
||||||
|
|
||||||
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
|
cluster_spec = normalize_cluster_spec(cluster_spec)
|
||||||
if task_type not in ("chief", "worker", "evaluator", "ps"):
|
|
||||||
raise ValueError(
|
|
||||||
"Unrecognized task_type: %r, valid task types are: \"chief\", "
|
|
||||||
"\"worker\", \"evaluator\" and \"ps\"." % task_type)
|
|
||||||
|
|
||||||
if task_type and task_type not in cluster_spec and task_type != "evaluator":
|
if any([job not in allowed_task_types for job in cluster_spec.jobs]):
|
||||||
|
raise ValueError("Disallowed task type found in cluster spec. Allowed "
|
||||||
|
"types are {} and the cluster spec is {}.".format(
|
||||||
|
allowed_task_types, cluster_spec))
|
||||||
|
|
||||||
|
if task_type not in allowed_task_types:
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized task_type: {}, valid task types are: {}".format(
|
||||||
|
task_type, allowed_task_types))
|
||||||
|
|
||||||
|
if (task_type and task_type not in cluster_spec.jobs and
|
||||||
|
task_type != "evaluator"):
|
||||||
raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
|
raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
|
||||||
|
|
||||||
if len(cluster_spec.get("chief", [])) > 1:
|
if task_count(cluster_spec, "chief") > 1:
|
||||||
raise ValueError("There must be at most one 'chief' job.")
|
raise ValueError("There must be at most one 'chief' job.")
|
||||||
|
|
||||||
if len(cluster_spec.get("evaluator", [])) > 1:
|
if task_count(cluster_spec, "evaluator") > 1:
|
||||||
raise ValueError("There must be at most one 'evaluator' job.")
|
raise ValueError("There must be at most one 'evaluator' job.")
|
||||||
|
|
||||||
# The `evaluator` job is allowed to be missing in `cluster_spec`.
|
# The `evaluator` job is allowed to be missing in `cluster_spec`.
|
||||||
if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
|
if task_type in cluster_spec.jobs and task_id >= task_count(
|
||||||
|
cluster_spec, task_type):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
|
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ import os
|
|||||||
|
|
||||||
from tensorflow.python.distribute import distribute_lib
|
from tensorflow.python.distribute import distribute_lib
|
||||||
from tensorflow.python.distribute import distribute_utils
|
from tensorflow.python.distribute import distribute_utils
|
||||||
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
from tensorflow.python.distribute import parameter_server_strategy
|
from tensorflow.python.distribute import parameter_server_strategy
|
||||||
from tensorflow.python.distribute import sharded_variable
|
from tensorflow.python.distribute import sharded_variable
|
||||||
from tensorflow.python.eager import remote
|
from tensorflow.python.eager import remote
|
||||||
@ -486,22 +487,19 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
|
|||||||
if self.extended._num_gpus_per_worker > 1: # pylint: disable=protected-access
|
if self.extended._num_gpus_per_worker > 1: # pylint: disable=protected-access
|
||||||
raise NotImplementedError("Multi-gpu is not supported yet.")
|
raise NotImplementedError("Multi-gpu is not supported yet.")
|
||||||
|
|
||||||
|
cluster_spec = cluster_resolver.cluster_spec()
|
||||||
|
|
||||||
# The following checks if the task types are allowed (chief, ps, worker).
|
# The following checks if the task types are allowed (chief, ps, worker).
|
||||||
disallowed_task_type_error_str = (
|
multi_worker_util._validate_cluster_spec( # pylint: disable=protected-access
|
||||||
"Disallowed task type found in "
|
cluster_spec,
|
||||||
"`tf.distribute.cluster_resolver.ClusterResolver` provided to "
|
cluster_resolver.task_type,
|
||||||
"`tf.distribute.experimental.ParameterServerStrategy`. Allowed types "
|
cluster_resolver.task_id)
|
||||||
"are {},".format(ALLOWED_TASK_TYPES))
|
|
||||||
if any([
|
if multi_worker_util.task_count(cluster_spec, "ps") < 1:
|
||||||
job not in ALLOWED_TASK_TYPES
|
raise ValueError("There must be at least one ps.")
|
||||||
for job in cluster_resolver.cluster_spec().jobs
|
|
||||||
]):
|
if multi_worker_util.task_count(cluster_spec, "worker") < 1:
|
||||||
raise ValueError("{} and the cluster spec is {}.".format(
|
raise ValueError("There must be at least one worker.")
|
||||||
disallowed_task_type_error_str, cluster_resolver.cluster_spec()))
|
|
||||||
if (cluster_resolver.task_type and
|
|
||||||
cluster_resolver.task_type not in ALLOWED_TASK_TYPES):
|
|
||||||
raise ValueError("{} and current task type is {}.".format(
|
|
||||||
disallowed_task_type_error_str, cluster_resolver.task_type))
|
|
||||||
|
|
||||||
|
|
||||||
class ParameterServerStrategyV2Extended(
|
class ParameterServerStrategyV2Extended(
|
||||||
|
@ -412,7 +412,47 @@ class ClusterTypeNameTest(test.TestCase):
|
|||||||
]
|
]
|
||||||
cluster_resolver = SimpleClusterResolver(
|
cluster_resolver = SimpleClusterResolver(
|
||||||
ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
|
ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
|
||||||
with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"):
|
with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
|
||||||
|
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
||||||
|
|
||||||
|
def testMoreThanOneChief(self):
|
||||||
|
cluster_def = multi_worker_test_base._create_cluster(
|
||||||
|
num_workers=1, num_ps=1)
|
||||||
|
chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
|
||||||
|
cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
|
||||||
|
cluster_resolver = SimpleClusterResolver(
|
||||||
|
ClusterSpec(cluster_def),
|
||||||
|
rpc_layer="grpc",
|
||||||
|
task_type="chief",
|
||||||
|
task_id=1)
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"There must be at most one 'chief' job."):
|
||||||
|
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
||||||
|
|
||||||
|
def testLessThanOneWorker(self):
|
||||||
|
cluster_def = multi_worker_test_base._create_cluster(
|
||||||
|
num_workers=0, num_ps=1)
|
||||||
|
cluster_def["chief"] = [
|
||||||
|
"localhost:%d" % multi_worker_test_base.pick_unused_port()
|
||||||
|
]
|
||||||
|
cluster_resolver = SimpleClusterResolver(
|
||||||
|
ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"There must be at least one worker."):
|
||||||
|
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
||||||
|
|
||||||
|
def testLessThanOnePs(self):
|
||||||
|
cluster_def = multi_worker_test_base._create_cluster(
|
||||||
|
num_workers=1, num_ps=0)
|
||||||
|
cluster_def["chief"] = [
|
||||||
|
"localhost:%d" % multi_worker_test_base.pick_unused_port()
|
||||||
|
]
|
||||||
|
cluster_resolver = SimpleClusterResolver(
|
||||||
|
ClusterSpec(cluster_def),
|
||||||
|
rpc_layer="grpc",
|
||||||
|
task_type="worker",
|
||||||
|
task_id=0)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "There must be at least one ps."):
|
||||||
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user