Allow evaluator not in cluster_spec, to be consistent with legacy Estimator.

PiperOrigin-RevId: 281833366
Change-Id: Ic580172ba5ec038e246028031ec277b18f31ea56
This commit is contained in:
Yuefeng Zhou 2019-11-21 14:30:26 -08:00 committed by TensorFlower Gardener
parent 75af7b4750
commit e37d4d2b68
2 changed files with 32 additions and 3 deletions

View File

@ -53,7 +53,10 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
It checks:
0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
1) task type is one of "chief", "worker", "evaluator" or "ps".
2) whether there is such a task type as `task_type` in the `cluster_spec`.
2) whether there is such a task type as `task_type` in the `cluster_spec`. The
only exception is `evaluator`. In other words, it is still a valid
configuration when `task_type` is `evaluator` but it doesn't appear in
`cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator.
3) whether there is at most one "chief" job.
4) whether there is at most one "evaluator" job.
5) whether the `task_id` is smaller than the number of tasks for that
@ -76,7 +79,7 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
"Unrecognized task_type: %r, valid task types are: \"chief\", "
"\"worker\", \"evaluator\" and \"ps\"." % task_type)
if task_type and task_type not in cluster_spec:
if task_type and task_type not in cluster_spec and task_type != "evaluator":
raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
if len(cluster_spec.get("chief", [])) > 1:
@ -85,7 +88,8 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
if len(cluster_spec.get("evaluator", [])) > 1:
raise ValueError("There must be at most one 'evaluator' job.")
if task_id >= len(cluster_spec[task_type]):
# The `evaluator` job is allowed to be missing in `cluster_spec`.
if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
raise ValueError(
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))

View File

@ -237,5 +237,30 @@ class CollectiveLeaderTest(test.TestCase):
multi_worker_util.collective_leader(cluster_spec, None, 0), "")
# Most of the validation logic is already covered by the tests above; the tests
# below cover the remaining cases.
class ClusterSpecValidationTest(test.TestCase):
  """Tests `_validate_cluster_spec` with jobs missing from the cluster spec."""

  def testEvaluatorNotInCluster(self):
    """Every task type validates, even an evaluator absent from the spec."""
    spec_without_evaluator = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    # None of these calls is expected to raise. "evaluator" comes last because
    # it is the only task type here that has no entry in the spec.
    for task_type in ("chief", "worker", "ps", "evaluator"):
      multi_worker_util._validate_cluster_spec(
          spec_without_evaluator, task_type, 0)

  def testWorkerNotInCluster(self):
    """A missing evaluator is accepted, but a missing worker raises."""
    spec_without_worker = {
        "chief": ["127.0.0.1:1234"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    validate = multi_worker_util._validate_cluster_spec

    # The evaluator has no entry in the spec, yet validation must pass.
    validate(spec_without_worker, "evaluator", 0)

    # Any other task type that is missing from the spec is still an error.
    expected_error = "`task_type` 'worker' not found in cluster_spec."
    with self.assertRaisesRegexp(ValueError, expected_error):
      validate(spec_without_worker, "worker", 0)
if __name__ == "__main__":
test.main()