Make Distribution Strategy friendlier to the evaluator job.
PiperOrigin-RevId: 238719371
This commit is contained in: parent 32d8905c37 · commit c0886e70d1
@@ -251,7 +251,7 @@ class DistributeCoordinatorIntegrationTest(

   def _get_strategy_object(self, strategy_cls):
     if strategy_cls == mirrored_strategy.CoreMirroredStrategy:
-      return strategy_cls(mirrored_strategy.all_local_devices())
+      return strategy_cls()
     else:
       return strategy_cls(num_gpus_per_worker=context.num_gpus())

@@ -269,6 +269,7 @@ class DistributeCoordinatorIntegrationTest(
               mirrored_strategy.MirroredStrategy,
               mirrored_strategy.CoreMirroredStrategy,
               parameter_server_strategy.ParameterServerStrategy,
+              collective_all_reduce_strategy.CollectiveAllReduceStrategy,
           ],
           required_gpus=[0, 1]))
   def test_complete_flow_standalone_client(self, train_distribute_cls,

@@ -371,9 +372,11 @@ class DistributeCoordinatorIntegrationTest(
               parameter_server_strategy.ParameterServerStrategy,
           ],
           eval_distribute_cls=[
-              None, mirrored_strategy.MirroredStrategy,
+              None,
+              mirrored_strategy.MirroredStrategy,
+              mirrored_strategy.CoreMirroredStrategy,
               parameter_server_strategy.ParameterServerStrategy,
+              collective_all_reduce_strategy.CollectiveAllReduceStrategy,
           ],
           required_gpus=[0, 1]))
   def test_complete_flow_independent_worker_between_graph(

@@ -72,7 +72,8 @@ def _create_cluster(num_workers,
                     has_eval=False,
                     protocol='grpc',
                     worker_config=None,
-                    ps_config=None):
+                    ps_config=None,
+                    eval_config=None):
   """Creates and starts local servers and returns the cluster_spec dict."""
   if _portpicker_import_error:
     raise _portpicker_import_error  # pylint: disable=raising-bad-type

@@ -124,7 +125,7 @@ def _create_cluster(num_workers,
         job_name='evaluator',
         protocol=protocol,
         task_index=0,
-        config=worker_config,
+        config=eval_config,
         start=True)

   return cluster_dict

@@ -153,6 +154,9 @@ def create_in_process_cluster(num_workers,
   ps_config = config_pb2.ConfigProto()
   ps_config.device_count['GPU'] = 0

+  eval_config = config_pb2.ConfigProto()
+  eval_config.experimental.collective_group_leader = ''
+
   # Create in-process servers. Once an in-process tensorflow server is created,
   # there is no way to terminate it. So we create one cluster per test process.
   # We could've started the server in another process, we could then kill that

@@ -169,6 +173,7 @@ def create_in_process_cluster(num_workers,
       has_eval=has_eval,
       worker_config=worker_config,
       ps_config=ps_config,
+      eval_config=eval_config,
       protocol='grpc')

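Taken together, the hunks above route a dedicated ConfigProto to the evaluator
server in the test cluster. A minimal sketch of the idea (config_pb2 and
server_lib are the real modules; the surrounding test-cluster bookkeeping is
elided, and the Server call is shown commented since cluster_spec is defined
elsewhere in this file):

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.training import server_lib

    # The evaluator gets its own ConfigProto. An empty collective group
    # leader keeps the single evaluator task out of the training cluster's
    # collective ring, so its ops never wait on workers.
    eval_config = config_pb2.ConfigProto()
    eval_config.experimental.collective_group_leader = ''

    # Mirroring the diff: the evaluator server now starts with eval_config
    # rather than worker_config.
    # server_lib.Server(cluster_spec, job_name='evaluator', task_index=0,
    #                   protocol='grpc', config=eval_config, start=True)
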
@@ -165,16 +165,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     task_id = cluster_resolver.task_id
     if task_type is None or task_id is None:
       raise ValueError("When `cluster_spec` is given, you must also specify "
-                       "`task_type` and `task_id` in the `cluster_resolver`.")
-    if task_type not in ("chief", "worker"):
-      raise ValueError(
-          "Unrecognized task_type: %r, valid task types are: \"chief\", "
-          "\"worker\"." % task_type)
+                       "`task_type` and `task_id`.")

     self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
     if not self._num_workers:
-      raise ValueError("No `worker` or `chief` tasks can be found in "
-                       "`cluster_spec`.")
+      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
+                       "in `cluster_spec`.")

     self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                 task_id)

@@ -410,15 +406,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):

     # Collective group leader is needed for collective ops to coordinate
     # workers.
-    if "chief" in self._cluster_spec.jobs:
-      updated_config.experimental.collective_group_leader = (
-          "/job:chief/replica:0/task:0")
-    else:
-      if "worker" not in self._cluster_spec.jobs:
-        raise ValueError(
-            "You must have `chief` or `worker` jobs in the `cluster_spec`.")
-      updated_config.experimental.collective_group_leader = (
-          "/job:worker/replica:0/task:0")
+    updated_config.experimental.collective_group_leader = (
+        multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
+                                            self._task_id))

     # The device filters prevent communication between workers.
     del updated_config.device_filters[:]

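With this change, update_config_proto no longer hard-codes the leader; the
logic moves into multi_worker_util.collective_leader, added later in this
commit. A small sketch of the resolution with hypothetical addresses (the
expected values match the CollectiveLeaderTest cases below):

    from tensorflow.python.distribute import multi_worker_util

    cluster = {"chief": ["10.0.0.1:2222"], "worker": ["10.0.0.2:2222"]}
    # A worker in a cluster with a chief follows the chief...
    assert multi_worker_util.collective_leader(cluster, "worker", 0) == (
        "/job:chief/replica:0/task:0")
    # ...while the lone evaluator opts out of collectives entirely.
    cluster["evaluator"] = ["10.0.0.3:2222"]
    assert multi_worker_util.collective_leader(cluster, "evaluator", 0) == ""
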
@@ -367,6 +367,9 @@ def _split_cluster_for_evaluator(cluster_spec, task_type):
   # distribution strategies and as a result ops in the evaluator task may have
   # unspecified devices. Those ops may end up on other tasks if we don't split
   # the cluster.
+  # Note: if you bypass distribute coordinator and bring up the cluster
+  # yourself, you can equivalently set device filters to split clusters. This
+  # is already done by distribution strategy's `update_config_proto` method.
   new_cluster_spec = multi_worker_util.normalize_cluster_spec(
       cluster_spec).as_dict()
   if task_type == _TaskType.EVALUATOR:

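Concretely, the device-filter alternative mentioned in the new comment looks
like this sketch (hypothetical single-evaluator setup): a task whose filters
name only itself cannot exchange ops with the training tasks.

    from tensorflow.core.protobuf import config_pb2

    # Restrict the evaluator's visible universe to itself, as
    # ParameterServerStrategy's update_config_proto does further below.
    session_config = config_pb2.ConfigProto()
    session_config.device_filters.append("/job:evaluator/task:0")
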
@@ -49,11 +49,13 @@ def normalize_cluster_spec(cluster_spec):
 def _validate_cluster_spec(cluster_spec, task_type, task_id):
   """Validates `cluster_spec`.

-  It checks
-  1) whether there is such a task type as `task_type` in the
-  `cluster_spec`.
-  2) whether there is at most one "chief" job.
-  3) whether the `task_id` is smaller than the number of `task_type`.
+  It checks:
+  1) task type is one of "chief", "worker", "evaluator" or "ps".
+  2) whether there is such a task type as `task_type` in the `cluster_spec`.
+  3) whether there is at most one "chief" job.
+  4) whether there is at most one "evaluator" job.
+  5) whether the `task_id` is smaller than the number of tasks for that
+     particular `task_type`.

   Args:
     cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.

@@ -63,10 +65,20 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
     ValueError: if `cluster_spec` fails any check.
   """
   cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
+  if task_type not in ("chief", "worker", "evaluator", "ps"):
+    raise ValueError(
+        "Unrecognized task_type: %r, valid task types are: \"chief\", "
+        "\"worker\", \"evaluator\" and \"ps\"." % task_type)
+
   if task_type and task_type not in cluster_spec:
     raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
+
   if len(cluster_spec.get("chief", [])) > 1:
     raise ValueError("There must be at most one 'chief' job.")
+
+  if len(cluster_spec.get("evaluator", [])) > 1:
+    raise ValueError("There must be at most one 'evaluator' job.")
+
   if task_id >= len(cluster_spec[task_type]):
     raise ValueError(
         "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))

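For orientation, the tightened checks in action (dict-style specs with
hypothetical hosts; is_chief, documented next, is one of the public entry
points that calls this validator):

    from tensorflow.python.distribute import multi_worker_util

    ok = {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222"],
        "evaluator": ["host3:2222"],
    }
    multi_worker_util.is_chief(ok, "evaluator", 0)  # validates, returns True

    # A second evaluator now fails fast in any validated entry point:
    # ValueError: There must be at most one 'evaluator' job.
    two_evaluators = dict(ok, evaluator=["host3:2222", "host4:2222"])
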
@@ -75,6 +87,10 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
 def is_chief(cluster_spec, task_type, task_id):
   """Returns whether the given task is chief in the cluster.

+  Since there is at most one evaluator and the evaluator itself should be
+  independent of the training cluster, the evaluator job is also a chief job on
+  its own.
+
   Args:
     cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
       cluster configurations.

@@ -91,7 +107,7 @@ def is_chief(cluster_spec, task_type, task_id):
   _validate_cluster_spec(cluster_spec, task_type, task_id)
   cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

-  if task_type == "chief":
+  if task_type == "chief" or task_type == "evaluator":
     return True

   # If chief not in the cluster_spec, use the first worker as chief. This is

@@ -101,6 +117,40 @@ def is_chief(cluster_spec, task_type, task_id):
   return False


+def collective_leader(cluster_spec, task_type, task_id):
+  """Returns the job name for the leader of collective ops.
+
+  Args:
+    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
+      cluster configurations.
+    task_type: the task type in the cluster.
+    task_id: the task id in the cluster.
+
+  Returns:
+    a string indicating the leader job name, or an empty string if there is no
+    need to set a leader job.
+  """
+  cluster_spec = normalize_cluster_spec(cluster_spec)
+
+  # No need to set collective leader for local.
+  if not cluster_spec.as_dict():
+    return ""
+
+  _validate_cluster_spec(cluster_spec, task_type, task_id)
+
+  # There is only one evaluator, so no need to set a collective leader.
+  if task_type == "evaluator":
+    return ""
+
+  # Use the chief if a chief is in the cluster.
+  if "chief" in cluster_spec.jobs:
+    return "/job:chief/replica:0/task:0"
+
+  # Otherwise use worker 0.
+  assert "worker" in cluster_spec.jobs
+  return "/job:worker/replica:0/task:0"
+
+
 def worker_count(cluster_spec, task_type):
   """Returns the number of workers in the cluster."""
   _validate_cluster_spec(cluster_spec, task_type, task_id=0)

@@ -102,6 +102,14 @@ class IsChiefTest(test.TestCase):
         ValueError, "The `task_id` 2 exceeds the maximum id of worker."):
       multi_worker_util.is_chief(cluster_spec, "worker", 2)

+  def testEvaluatorIsChief(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "evaluator": ["127.0.0.1:2019"]
+    }
+    self.assertTrue(multi_worker_util.is_chief(cluster_spec, "evaluator", 0))
+

 class NumWorkersTest(test.TestCase):

@@ -192,5 +200,42 @@ class IdInClusterTest(test.TestCase):
       multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)


+class CollectiveLeaderTest(test.TestCase):
+
+  def testChiefAsLeader(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.collective_leader(cluster_spec, "worker", 0),
+        "/job:chief/replica:0/task:0")
+
+  def testWorkerAsLeader(self):
+    cluster_spec = {
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.collective_leader(cluster_spec, "worker", 1),
+        "/job:worker/replica:0/task:0")
+
+  def testLeaderForEvaluator(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
+        "evaluator": ["127.0.0.1:2019"]
+    }
+    self.assertEqual(
+        multi_worker_util.collective_leader(cluster_spec, "evaluator", 0), "")
+
+  def testLocalLeader(self):
+    cluster_spec = {}
+    self.assertEqual(
+        multi_worker_util.collective_leader(cluster_spec, None, 0), "")
+
+
 if __name__ == "__main__":
   test.main()

@@ -502,11 +502,13 @@ class ParameterServerStrategyExtended(
     assert self._task_id is not None

     # The device filters prevent communication between workers.
-    if self._task_type not in ["chief", "worker"]:
-      return updated_config
     del updated_config.device_filters[:]
-    updated_config.device_filters.extend(
-        ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
+    if self._task_type in ["chief", "worker"]:
+      updated_config.device_filters.extend(
+          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
+    elif self._task_type == "evaluator":
+      updated_config.device_filters.append(
+          "/job:%s/task:%d" % (self._task_type, self._task_id))
     return updated_config

   @property
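The effect of the two new branches, sketched as a runnable toy (the helper
name is invented for illustration; only the filter logic mirrors the method
above):

    from tensorflow.core.protobuf import config_pb2

    def sketch_device_filters(task_type, task_id):
      """Mirrors the branch above: which jobs each task is allowed to see."""
      config = config_pb2.ConfigProto()
      if task_type in ["chief", "worker"]:
        # Training tasks see themselves plus every parameter server.
        config.device_filters.extend(
            ["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
      elif task_type == "evaluator":
        # The evaluator sees only itself, cutting it off from training.
        config.device_filters.append(
            "/job:%s/task:%d" % (task_type, task_id))
      return config

    assert list(sketch_device_filters("worker", 1).device_filters) == [
        "/job:worker/task:1", "/job:ps"]
    assert list(sketch_device_filters("evaluator", 0).device_filters) == [
        "/job:evaluator/task:0"]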