diff --git a/tensorflow/python/distribute/multi_worker_util.py b/tensorflow/python/distribute/multi_worker_util.py
index c804ed9b8bc..4d89b2fab08 100644
--- a/tensorflow/python/distribute/multi_worker_util.py
+++ b/tensorflow/python/distribute/multi_worker_util.py
@@ -53,7 +53,10 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
   It checks:
   0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
   1) task type is one of "chief", "worker" or "evaluator".
-  2) whether there is such a task type as `task_type` in the `cluster_spec`.
+  2) whether there is such a task type as `task_type` in the `cluster_spec`. The
+     only exception is `evaluator`. In other words, it is still a valid
+     configuration when `task_type` is `evaluator` but it doesn't appear in
+     `cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator.
   3) whether there is at most one "chief" job.
   4) whether there is at most one "evaluator" job.
   5) whether the `task_id` is smaller than the number of tasks for that
@@ -76,7 +79,7 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
         "Unrecognized task_type: %r, valid task types are: \"chief\", "
         "\"worker\", \"evaluator\" and \"ps\"." % task_type)
 
-  if task_type and task_type not in cluster_spec:
+  if task_type and task_type not in cluster_spec and task_type != "evaluator":
     raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
 
   if len(cluster_spec.get("chief", [])) > 1:
@@ -85,7 +88,8 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
   if len(cluster_spec.get("evaluator", [])) > 1:
     raise ValueError("There must be at most one 'evaluator' job.")
 
-  if task_id >= len(cluster_spec[task_type]):
+  # The `evaluator` job is allowed to be missing in `cluster_spec`.
+  if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
     raise ValueError(
         "The `task_id` %d exceeds the maximum id of %s."
         % (task_id, task_type))
diff --git a/tensorflow/python/distribute/multi_worker_util_test.py b/tensorflow/python/distribute/multi_worker_util_test.py
index dbe57b24e08..6a51e71ded7 100644
--- a/tensorflow/python/distribute/multi_worker_util_test.py
+++ b/tensorflow/python/distribute/multi_worker_util_test.py
@@ -237,5 +237,30 @@ class CollectiveLeaderTest(test.TestCase):
         multi_worker_util.collective_leader(cluster_spec, None, 0), "")
 
 
+# Most of the validation logic is tested by above tests except for some.
+class ClusterSpecValidationTest(test.TestCase):
+
+  def testEvaluatorNotInCluster(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    multi_worker_util._validate_cluster_spec(cluster_spec, "chief", 0)
+    multi_worker_util._validate_cluster_spec(cluster_spec, "worker", 0)
+    multi_worker_util._validate_cluster_spec(cluster_spec, "ps", 0)
+    multi_worker_util._validate_cluster_spec(cluster_spec, "evaluator", 0)
+
+  def testWorkerNotInCluster(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    multi_worker_util._validate_cluster_spec(cluster_spec, "evaluator", 0)
+    with self.assertRaisesRegexp(
+        ValueError, "`task_type` 'worker' not found in cluster_spec."):
+      multi_worker_util._validate_cluster_spec(cluster_spec, "worker", 0)
+
+
 if __name__ == "__main__":
   test.main()