diff --git a/tensorflow/python/distribute/multi_worker_util.py b/tensorflow/python/distribute/multi_worker_util.py index 4d89b2fab08..943605fac20 100644 --- a/tensorflow/python/distribute/multi_worker_util.py +++ b/tensorflow/python/distribute/multi_worker_util.py @@ -46,13 +46,21 @@ def normalize_cluster_spec(cluster_spec): return cluster_spec -# TODO(yuefengz): add more validations. -def _validate_cluster_spec(cluster_spec, task_type, task_id): +def task_count(cluster_spec, task_type): + try: + return cluster_spec.num_tasks(task_type) + except ValueError: + return 0 + + +def _validate_cluster_spec(cluster_spec, + task_type, + task_id): """Validates `cluster_spec`. It checks: - 0) None of `cluster_spec`, `task_type`, and `task_id` is `None`. - 1) task type is one of "chief", "worker" or "evaluator". + 1) task type is one of "chief", "worker", "ps", "evaluator", or not provided + (None). 2) whether there is such a task type as `task_type` in the `cluster_spec`. The only exception is `evaluator`. In other words, it is still a valid configuration when `task_type` is `evaluator` but it doesn't appear in @@ -65,31 +73,38 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id): Args: cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated. task_type: string indicating the type of the task. - task_id: task_id: the id of the `task_type` in this cluster. - Throws: + task_id: the id of the `task_type` in this cluster. + + Raises: ValueError: if `cluster_spec` fails any check. """ - if cluster_spec is None or task_type is None or task_id is None: - raise ValueError( - "None of `cluster_spec`, `task_type`, and `task_id` should be `None`.") + allowed_task_types = ("chief", "worker", "evaluator", "ps", None) - cluster_spec = normalize_cluster_spec(cluster_spec).as_dict() - if task_type not in ("chief", "worker", "evaluator", "ps"): - raise ValueError( - "Unrecognized task_type: %r, valid task types are: \"chief\", " - "\"worker\", \"evaluator\" and \"ps\"." % task_type) + cluster_spec = normalize_cluster_spec(cluster_spec) - if task_type and task_type not in cluster_spec and task_type != "evaluator": + if any([job not in allowed_task_types for job in cluster_spec.jobs]): + raise ValueError("Disallowed task type found in cluster spec. Allowed " + "types are {} and the cluster spec is {}.".format( + allowed_task_types, cluster_spec)) + + if task_type not in allowed_task_types: + raise ValueError( + "Unrecognized task_type: {}, valid task types are: {}".format( + task_type, allowed_task_types)) + + if (task_type and task_type not in cluster_spec.jobs and + task_type != "evaluator"): raise ValueError("`task_type` %r not found in cluster_spec." % task_type) - if len(cluster_spec.get("chief", [])) > 1: + if task_count(cluster_spec, "chief") > 1: raise ValueError("There must be at most one 'chief' job.") - if len(cluster_spec.get("evaluator", [])) > 1: + if task_count(cluster_spec, "evaluator") > 1: raise ValueError("There must be at most one 'evaluator' job.") # The `evaluator` job is allowed to be missing in `cluster_spec`. - if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]): + if task_type in cluster_spec.jobs and task_id >= task_count( + cluster_spec, task_type): raise ValueError( "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type)) diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index 452f89a9425..9cff2d789e6 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -26,6 +26,7 @@ import os from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_utils +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import parameter_server_strategy from tensorflow.python.distribute import sharded_variable from tensorflow.python.eager import remote @@ -486,22 +487,19 @@ class ParameterServerStrategyV2(distribute_lib.Strategy): if self.extended._num_gpus_per_worker > 1: # pylint: disable=protected-access raise NotImplementedError("Multi-gpu is not supported yet.") + cluster_spec = cluster_resolver.cluster_spec() + # The following checks if the task types are allowed (chief, ps, worker). - disallowed_task_type_error_str = ( - "Disallowed task type found in " - "`tf.distribute.cluster_resolver.ClusterResolver` provided to " - "`tf.distribute.experimental.ParameterServerStrategy`. Allowed types " - "are {},".format(ALLOWED_TASK_TYPES)) - if any([ - job not in ALLOWED_TASK_TYPES - for job in cluster_resolver.cluster_spec().jobs - ]): - raise ValueError("{} and the cluster spec is {}.".format( - disallowed_task_type_error_str, cluster_resolver.cluster_spec())) - if (cluster_resolver.task_type and - cluster_resolver.task_type not in ALLOWED_TASK_TYPES): - raise ValueError("{} and current task type is {}.".format( - disallowed_task_type_error_str, cluster_resolver.task_type)) + multi_worker_util._validate_cluster_spec( # pylint: disable=protected-access + cluster_spec, + cluster_resolver.task_type, + cluster_resolver.task_id) + + if multi_worker_util.task_count(cluster_spec, "ps") < 1: + raise ValueError("There must be at least one ps.") + + if multi_worker_util.task_count(cluster_spec, "worker") < 1: + raise ValueError("There must be at least one worker.") class ParameterServerStrategyV2Extended( diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2_test.py b/tensorflow/python/distribute/parameter_server_strategy_v2_test.py index 4e2ad3e70fe..7e682d07c08 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2_test.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2_test.py @@ -412,7 +412,47 @@ class ClusterTypeNameTest(test.TestCase): ] cluster_resolver = SimpleClusterResolver( ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar") - with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"): + with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"): + parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver) + + def testMoreThanOneChief(self): + cluster_def = multi_worker_test_base._create_cluster( + num_workers=1, num_ps=1) + chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)] + cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports] + cluster_resolver = SimpleClusterResolver( + ClusterSpec(cluster_def), + rpc_layer="grpc", + task_type="chief", + task_id=1) + with self.assertRaisesRegexp(ValueError, + "There must be at most one 'chief' job."): + parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver) + + def testLessThanOneWorker(self): + cluster_def = multi_worker_test_base._create_cluster( + num_workers=0, num_ps=1) + cluster_def["chief"] = [ + "localhost:%d" % multi_worker_test_base.pick_unused_port() + ] + cluster_resolver = SimpleClusterResolver( + ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0) + with self.assertRaisesRegexp(ValueError, + "There must be at least one worker."): + parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver) + + def testLessThanOnePs(self): + cluster_def = multi_worker_test_base._create_cluster( + num_workers=1, num_ps=0) + cluster_def["chief"] = [ + "localhost:%d" % multi_worker_test_base.pick_unused_port() + ] + cluster_resolver = SimpleClusterResolver( + ClusterSpec(cluster_def), + rpc_layer="grpc", + task_type="worker", + task_id=0) + with self.assertRaisesRegexp(ValueError, "There must be at least one ps."): parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)