From faf44b5391e0e9925efa66f3fc7521955962c091 Mon Sep 17 00:00:00 2001
From: Rick Chao <rchao@google.com>
Date: Tue, 10 Nov 2020 18:31:02 -0800
Subject: [PATCH] PSv2: Check that there is no more than one chief, and at
 least one ps/worker. Combine the validation logic with multi_worker_util.

PiperOrigin-RevId: 341740027
Change-Id: I7e3125f8eaefb12c96f37b7fa3a54afbfc1e4334
---
 .../python/distribute/multi_worker_util.py    | 51 ++++++++++++-------
 .../parameter_server_strategy_v2.py           | 28 +++++-----
 .../parameter_server_strategy_v2_test.py      | 42 ++++++++++++++-
 3 files changed, 87 insertions(+), 34 deletions(-)

diff --git a/tensorflow/python/distribute/multi_worker_util.py b/tensorflow/python/distribute/multi_worker_util.py
index 4d89b2fab08..943605fac20 100644
--- a/tensorflow/python/distribute/multi_worker_util.py
+++ b/tensorflow/python/distribute/multi_worker_util.py
@@ -46,13 +46,21 @@ def normalize_cluster_spec(cluster_spec):
   return cluster_spec
 
 
-# TODO(yuefengz): add more validations.
-def _validate_cluster_spec(cluster_spec, task_type, task_id):
+def task_count(cluster_spec, task_type):
+  try:
+    return cluster_spec.num_tasks(task_type)
+  except ValueError:
+    return 0
+
+
+def _validate_cluster_spec(cluster_spec,
+                           task_type,
+                           task_id):
   """Validates `cluster_spec`.
 
   It checks:
-  0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
-  1) task type is one of "chief", "worker" or "evaluator".
+  1) task type is one of "chief", "worker", "ps", "evaluator", or not provided
+     (None).
   2) whether there is such a task type as `task_type` in the `cluster_spec`. The
      only exception is `evaluator`. In other words, it is still a valid
      configuration when `task_type` is `evaluator` but it doesn't appear in
@@ -65,31 +73,38 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
   Args:
     cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
     task_type: string indicating the type of the task.
-    task_id: task_id: the id of the `task_type` in this cluster.
-  Throws:
+    task_id: the id of the `task_type` in this cluster.
+
+  Raises:
     ValueError: if `cluster_spec` fails any check.
   """
-  if cluster_spec is None or task_type is None or task_id is None:
-    raise ValueError(
-        "None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")
+  allowed_task_types = ("chief", "worker", "evaluator", "ps", None)
 
-  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
-  if task_type not in ("chief", "worker", "evaluator", "ps"):
-    raise ValueError(
-        "Unrecognized task_type: %r, valid task types are: \"chief\", "
-        "\"worker\", \"evaluator\" and \"ps\"." % task_type)
+  cluster_spec = normalize_cluster_spec(cluster_spec)
 
-  if task_type and task_type not in cluster_spec and task_type != "evaluator":
+  if any([job not in allowed_task_types for job in cluster_spec.jobs]):
+    raise ValueError("Disallowed task type found in cluster spec. Allowed "
+                     "types are {} and the cluster spec is {}.".format(
+                         allowed_task_types, cluster_spec))
+
+  if task_type not in allowed_task_types:
+    raise ValueError(
+        "Unrecognized task_type: {}, valid task types are: {}".format(
+            task_type, allowed_task_types))
+
+  if (task_type and task_type not in cluster_spec.jobs and
+      task_type != "evaluator"):
     raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
 
-  if len(cluster_spec.get("chief", [])) > 1:
+  if task_count(cluster_spec, "chief") > 1:
     raise ValueError("There must be at most one 'chief' job.")
 
-  if len(cluster_spec.get("evaluator", [])) > 1:
+  if task_count(cluster_spec, "evaluator") > 1:
     raise ValueError("There must be at most one 'evaluator' job.")
 
   # The `evaluator` job is allowed to be missing in `cluster_spec`.
-  if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
+  if task_type in cluster_spec.jobs and task_id >= task_count(
+      cluster_spec, task_type):
     raise ValueError(
         "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
 
diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py
index 452f89a9425..9cff2d789e6 100644
--- a/tensorflow/python/distribute/parameter_server_strategy_v2.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py
@@ -26,6 +26,7 @@ import os
 
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.eager import remote
@@ -486,22 +487,19 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
     if self.extended._num_gpus_per_worker > 1:  # pylint: disable=protected-access
       raise NotImplementedError("Multi-gpu is not supported yet.")
 
+    cluster_spec = cluster_resolver.cluster_spec()
+
     # The following checks if the task types are allowed (chief, ps, worker).
-    disallowed_task_type_error_str = (
-        "Disallowed task type found in "
-        "`tf.distribute.cluster_resolver.ClusterResolver` provided to "
-        "`tf.distribute.experimental.ParameterServerStrategy`. Allowed types "
-        "are {},".format(ALLOWED_TASK_TYPES))
-    if any([
-        job not in ALLOWED_TASK_TYPES
-        for job in cluster_resolver.cluster_spec().jobs
-    ]):
-      raise ValueError("{} and the cluster spec is {}.".format(
-          disallowed_task_type_error_str, cluster_resolver.cluster_spec()))
-    if (cluster_resolver.task_type and
-        cluster_resolver.task_type not in ALLOWED_TASK_TYPES):
-      raise ValueError("{} and current task type is {}.".format(
-          disallowed_task_type_error_str, cluster_resolver.task_type))
+    multi_worker_util._validate_cluster_spec(  # pylint: disable=protected-access
+        cluster_spec,
+        cluster_resolver.task_type,
+        cluster_resolver.task_id)
+
+    if multi_worker_util.task_count(cluster_spec, "ps") < 1:
+      raise ValueError("There must be at least one ps.")
+
+    if multi_worker_util.task_count(cluster_spec, "worker") < 1:
+      raise ValueError("There must be at least one worker.")
 
 
 class ParameterServerStrategyV2Extended(
diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2_test.py b/tensorflow/python/distribute/parameter_server_strategy_v2_test.py
index 4e2ad3e70fe..7e682d07c08 100644
--- a/tensorflow/python/distribute/parameter_server_strategy_v2_test.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_v2_test.py
@@ -412,7 +412,47 @@ class ClusterTypeNameTest(test.TestCase):
     ]
     cluster_resolver = SimpleClusterResolver(
         ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
-    with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"):
+    with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
+      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
+
+  def testMoreThanOneChief(self):
+    cluster_def = multi_worker_test_base._create_cluster(
+        num_workers=1, num_ps=1)
+    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
+    cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
+    cluster_resolver = SimpleClusterResolver(
+        ClusterSpec(cluster_def),
+        rpc_layer="grpc",
+        task_type="chief",
+        task_id=1)
+    with self.assertRaisesRegexp(ValueError,
+                                 "There must be at most one 'chief' job."):
+      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
+
+  def testLessThanOneWorker(self):
+    cluster_def = multi_worker_test_base._create_cluster(
+        num_workers=0, num_ps=1)
+    cluster_def["chief"] = [
+        "localhost:%d" % multi_worker_test_base.pick_unused_port()
+    ]
+    cluster_resolver = SimpleClusterResolver(
+        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
+    with self.assertRaisesRegexp(ValueError,
+                                 "There must be at least one worker."):
+      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
+
+  def testLessThanOnePs(self):
+    cluster_def = multi_worker_test_base._create_cluster(
+        num_workers=1, num_ps=0)
+    cluster_def["chief"] = [
+        "localhost:%d" % multi_worker_test_base.pick_unused_port()
+    ]
+    cluster_resolver = SimpleClusterResolver(
+        ClusterSpec(cluster_def),
+        rpc_layer="grpc",
+        task_type="worker",
+        task_id=0)
+    with self.assertRaisesRegexp(ValueError, "There must be at least one ps."):
       parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)