Allow evaluator not in cluster_spec, to be consistent with legacy Estimator.
PiperOrigin-RevId: 281833366 Change-Id: Ic580172ba5ec038e246028031ec277b18f31ea56
This commit is contained in:
parent
75af7b4750
commit
e37d4d2b68
tensorflow/python/distribute
@ -53,7 +53,10 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||
It checks:
|
||||
0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
|
||||
1) task type is one of "chief", "worker" or "evaluator".
|
||||
2) whether there is such a task type as `task_type` in the `cluster_spec`.
|
||||
2) whether there is such a task type as `task_type` in the `cluster_spec`. The
|
||||
only exception is `evaluator`. In other words, it is still a valid
|
||||
configuration when `task_type` is `evaluator` but it doesn't appear in
|
||||
`cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator.
|
||||
3) whether there is at most one "chief" job.
|
||||
4) whether there is at most one "evaluator" job.
|
||||
5) whether the `task_id` is smaller than the number of tasks for that
|
||||
@ -76,7 +79,7 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||
"Unrecognized task_type: %r, valid task types are: \"chief\", "
|
||||
"\"worker\", \"evaluator\" and \"ps\"." % task_type)
|
||||
|
||||
if task_type and task_type not in cluster_spec:
|
||||
if task_type and task_type not in cluster_spec and task_type != "evaluator":
|
||||
raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
|
||||
|
||||
if len(cluster_spec.get("chief", [])) > 1:
|
||||
@ -85,7 +88,8 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||
if len(cluster_spec.get("evaluator", [])) > 1:
|
||||
raise ValueError("There must be at most one 'evaluator' job.")
|
||||
|
||||
if task_id >= len(cluster_spec[task_type]):
|
||||
# The `evaluator` job is allowed to be missing in `cluster_spec`.
|
||||
if task_type in cluster_spec and task_id >= len(cluster_spec[task_type]):
|
||||
raise ValueError(
|
||||
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
|
||||
|
||||
|
@ -237,5 +237,30 @@ class CollectiveLeaderTest(test.TestCase):
|
||||
multi_worker_util.collective_leader(cluster_spec, None, 0), "")
|
||||
|
||||
|
||||
# Most of the validation logic is tested by above tests except for some.
|
||||
class ClusterSpecValidationTest(test.TestCase):
|
||||
|
||||
def testEvaluatorNotInCluster(self):
|
||||
cluster_spec = {
|
||||
"chief": ["127.0.0.1:1234"],
|
||||
"worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
|
||||
"ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
|
||||
}
|
||||
multi_worker_util._validate_cluster_spec(cluster_spec, "chief", 0)
|
||||
multi_worker_util._validate_cluster_spec(cluster_spec, "worker", 0)
|
||||
multi_worker_util._validate_cluster_spec(cluster_spec, "ps", 0)
|
||||
multi_worker_util._validate_cluster_spec(cluster_spec, "evaluator", 0)
|
||||
|
||||
def testWorkerNotInCluster(self):
|
||||
cluster_spec = {
|
||||
"chief": ["127.0.0.1:1234"],
|
||||
"ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
|
||||
}
|
||||
multi_worker_util._validate_cluster_spec(cluster_spec, "evaluator", 0)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "`task_type` 'worker' not found in cluster_spec."):
|
||||
multi_worker_util._validate_cluster_spec(cluster_spec, "worker", 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user