Make Distribution Strategy more friendly with evaluator job.
PiperOrigin-RevId: 238719371
This commit is contained in:
parent
32d8905c37
commit
c0886e70d1
@ -251,7 +251,7 @@ class DistributeCoordinatorIntegrationTest(
|
|||||||
|
|
||||||
def _get_strategy_object(self, strategy_cls):
|
def _get_strategy_object(self, strategy_cls):
|
||||||
if strategy_cls == mirrored_strategy.CoreMirroredStrategy:
|
if strategy_cls == mirrored_strategy.CoreMirroredStrategy:
|
||||||
return strategy_cls(mirrored_strategy.all_local_devices())
|
return strategy_cls()
|
||||||
else:
|
else:
|
||||||
return strategy_cls(num_gpus_per_worker=context.num_gpus())
|
return strategy_cls(num_gpus_per_worker=context.num_gpus())
|
||||||
|
|
||||||
@ -269,6 +269,7 @@ class DistributeCoordinatorIntegrationTest(
|
|||||||
mirrored_strategy.MirroredStrategy,
|
mirrored_strategy.MirroredStrategy,
|
||||||
mirrored_strategy.CoreMirroredStrategy,
|
mirrored_strategy.CoreMirroredStrategy,
|
||||||
parameter_server_strategy.ParameterServerStrategy,
|
parameter_server_strategy.ParameterServerStrategy,
|
||||||
|
collective_all_reduce_strategy.CollectiveAllReduceStrategy,
|
||||||
],
|
],
|
||||||
required_gpus=[0, 1]))
|
required_gpus=[0, 1]))
|
||||||
def test_complete_flow_standalone_client(self, train_distribute_cls,
|
def test_complete_flow_standalone_client(self, train_distribute_cls,
|
||||||
@ -371,9 +372,11 @@ class DistributeCoordinatorIntegrationTest(
|
|||||||
parameter_server_strategy.ParameterServerStrategy,
|
parameter_server_strategy.ParameterServerStrategy,
|
||||||
],
|
],
|
||||||
eval_distribute_cls=[
|
eval_distribute_cls=[
|
||||||
None, mirrored_strategy.MirroredStrategy,
|
None,
|
||||||
|
mirrored_strategy.MirroredStrategy,
|
||||||
mirrored_strategy.CoreMirroredStrategy,
|
mirrored_strategy.CoreMirroredStrategy,
|
||||||
parameter_server_strategy.ParameterServerStrategy,
|
parameter_server_strategy.ParameterServerStrategy,
|
||||||
|
collective_all_reduce_strategy.CollectiveAllReduceStrategy,
|
||||||
],
|
],
|
||||||
required_gpus=[0, 1]))
|
required_gpus=[0, 1]))
|
||||||
def test_complete_flow_independent_worker_between_graph(
|
def test_complete_flow_independent_worker_between_graph(
|
||||||
|
@ -72,7 +72,8 @@ def _create_cluster(num_workers,
|
|||||||
has_eval=False,
|
has_eval=False,
|
||||||
protocol='grpc',
|
protocol='grpc',
|
||||||
worker_config=None,
|
worker_config=None,
|
||||||
ps_config=None):
|
ps_config=None,
|
||||||
|
eval_config=None):
|
||||||
"""Creates and starts local servers and returns the cluster_spec dict."""
|
"""Creates and starts local servers and returns the cluster_spec dict."""
|
||||||
if _portpicker_import_error:
|
if _portpicker_import_error:
|
||||||
raise _portpicker_import_error # pylint: disable=raising-bad-type
|
raise _portpicker_import_error # pylint: disable=raising-bad-type
|
||||||
@ -124,7 +125,7 @@ def _create_cluster(num_workers,
|
|||||||
job_name='evaluator',
|
job_name='evaluator',
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
task_index=0,
|
task_index=0,
|
||||||
config=worker_config,
|
config=eval_config,
|
||||||
start=True)
|
start=True)
|
||||||
|
|
||||||
return cluster_dict
|
return cluster_dict
|
||||||
@ -153,6 +154,9 @@ def create_in_process_cluster(num_workers,
|
|||||||
ps_config = config_pb2.ConfigProto()
|
ps_config = config_pb2.ConfigProto()
|
||||||
ps_config.device_count['GPU'] = 0
|
ps_config.device_count['GPU'] = 0
|
||||||
|
|
||||||
|
eval_config = config_pb2.ConfigProto()
|
||||||
|
eval_config.experimental.collective_group_leader = ''
|
||||||
|
|
||||||
# Create in-process servers. Once an in-process tensorflow server is created,
|
# Create in-process servers. Once an in-process tensorflow server is created,
|
||||||
# there is no way to terminate it. So we create one cluster per test process.
|
# there is no way to terminate it. So we create one cluster per test process.
|
||||||
# We could've started the server in another process, we could then kill that
|
# We could've started the server in another process, we could then kill that
|
||||||
@ -169,6 +173,7 @@ def create_in_process_cluster(num_workers,
|
|||||||
has_eval=has_eval,
|
has_eval=has_eval,
|
||||||
worker_config=worker_config,
|
worker_config=worker_config,
|
||||||
ps_config=ps_config,
|
ps_config=ps_config,
|
||||||
|
eval_config=eval_config,
|
||||||
protocol='grpc')
|
protocol='grpc')
|
||||||
|
|
||||||
|
|
||||||
|
@ -165,16 +165,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
|||||||
task_id = cluster_resolver.task_id
|
task_id = cluster_resolver.task_id
|
||||||
if task_type is None or task_id is None:
|
if task_type is None or task_id is None:
|
||||||
raise ValueError("When `cluster_spec` is given, you must also specify "
|
raise ValueError("When `cluster_spec` is given, you must also specify "
|
||||||
"`task_type` and `task_id` in the `cluster_resolver`.")
|
"`task_type` and `task_id`.")
|
||||||
if task_type not in ("chief", "worker"):
|
|
||||||
raise ValueError(
|
|
||||||
"Unrecognized task_type: %r, valid task types are: \"chief\", "
|
|
||||||
"\"worker\"." % task_type)
|
|
||||||
|
|
||||||
self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
|
self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
|
||||||
if not self._num_workers:
|
if not self._num_workers:
|
||||||
raise ValueError("No `worker` or `chief` tasks can be found in "
|
raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
|
||||||
"`cluster_spec`.")
|
"in `cluster_spec`.")
|
||||||
|
|
||||||
self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
|
self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
|
||||||
task_id)
|
task_id)
|
||||||
@ -410,15 +406,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
|||||||
|
|
||||||
# Collective group leader is needed for collective ops to coordinate
|
# Collective group leader is needed for collective ops to coordinate
|
||||||
# workers.
|
# workers.
|
||||||
if "chief" in self._cluster_spec.jobs:
|
updated_config.experimental.collective_group_leader = (
|
||||||
updated_config.experimental.collective_group_leader = (
|
multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
|
||||||
"/job:chief/replica:0/task:0")
|
self._task_id))
|
||||||
else:
|
|
||||||
if "worker" not in self._cluster_spec.jobs:
|
|
||||||
raise ValueError(
|
|
||||||
"You must have `chief` or `worker` jobs in the `cluster_spec`.")
|
|
||||||
updated_config.experimental.collective_group_leader = (
|
|
||||||
"/job:worker/replica:0/task:0")
|
|
||||||
|
|
||||||
# The device filters prevent communication between workers.
|
# The device filters prevent communication between workers.
|
||||||
del updated_config.device_filters[:]
|
del updated_config.device_filters[:]
|
||||||
|
@ -367,6 +367,9 @@ def _split_cluster_for_evaluator(cluster_spec, task_type):
|
|||||||
# distribution strategies and as a result ops in the evaluator task may have
|
# distribution strategies and as a result ops in the evaluator task may have
|
||||||
# unspecified devices. Those ops may end up on other tasks if we don't split
|
# unspecified devices. Those ops may end up on other tasks if we don't split
|
||||||
# the cluster.
|
# the cluster.
|
||||||
|
# Note: if you bypass distribute coordinator and bring the cluster yourself,
|
||||||
|
# you can equivalently set device filters to split clusters. This is already
|
||||||
|
# done by distribution strategy's `update_config_proto` method.
|
||||||
new_cluster_spec = multi_worker_util.normalize_cluster_spec(
|
new_cluster_spec = multi_worker_util.normalize_cluster_spec(
|
||||||
cluster_spec).as_dict()
|
cluster_spec).as_dict()
|
||||||
if task_type == _TaskType.EVALUATOR:
|
if task_type == _TaskType.EVALUATOR:
|
||||||
|
@ -49,11 +49,13 @@ def normalize_cluster_spec(cluster_spec):
|
|||||||
def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||||
"""Validates `cluster_spec`.
|
"""Validates `cluster_spec`.
|
||||||
|
|
||||||
It checks
|
It checks:
|
||||||
1) whether there is such a task type as `task_type` in the
|
1) task type is one of "chief", "worker", "evaluator" or "ps".
|
||||||
`cluster_spec`.
|
2) whether there is such a task type as `task_type` in the `cluster_spec`.
|
||||||
2) whether there is at most one "chief" job.
|
3) whether there is at most one "chief" job.
|
||||||
3) whether the `task_id` is smaller than the number of `task_type`.
|
4) whether there is at most one "evaluator" job.
|
||||||
|
5) whether the `task_id` is smaller than the number of tasks for that
|
||||||
|
particular `task_type`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
|
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
|
||||||
@ -63,10 +65,20 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
|||||||
ValueError: if `cluster_spec` fails any check.
|
ValueError: if `cluster_spec` fails any check.
|
||||||
"""
|
"""
|
||||||
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
|
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
|
||||||
|
if task_type not in ("chief", "worker", "evaluator", "ps"):
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized task_type: %r, valid task types are: \"chief\", "
|
||||||
|
"\"worker\", \"evaluator\" and \"ps\"." % task_type)
|
||||||
|
|
||||||
if task_type and task_type not in cluster_spec:
|
if task_type and task_type not in cluster_spec:
|
||||||
raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
|
raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
|
||||||
|
|
||||||
if len(cluster_spec.get("chief", [])) > 1:
|
if len(cluster_spec.get("chief", [])) > 1:
|
||||||
raise ValueError("There must be at most one 'chief' job.")
|
raise ValueError("There must be at most one 'chief' job.")
|
||||||
|
|
||||||
|
if len(cluster_spec.get("evaluator", [])) > 1:
|
||||||
|
raise ValueError("There must be at most one 'evaluator' job.")
|
||||||
|
|
||||||
if task_id >= len(cluster_spec[task_type]):
|
if task_id >= len(cluster_spec[task_type]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
|
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
|
||||||
@ -75,6 +87,10 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
|||||||
def is_chief(cluster_spec, task_type, task_id):
|
def is_chief(cluster_spec, task_type, task_id):
|
||||||
"""Returns whether the given task is chief in the cluster.
|
"""Returns whether the given task is chief in the cluster.
|
||||||
|
|
||||||
|
Since there is at most one evaluator and the evaluator itself should be
|
||||||
|
independent of the training cluster, the evaluator job is also a chief job on
|
||||||
|
its own.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
|
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
|
||||||
cluster configurations.
|
cluster configurations.
|
||||||
@ -91,7 +107,7 @@ def is_chief(cluster_spec, task_type, task_id):
|
|||||||
_validate_cluster_spec(cluster_spec, task_type, task_id)
|
_validate_cluster_spec(cluster_spec, task_type, task_id)
|
||||||
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
|
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
|
||||||
|
|
||||||
if task_type == "chief":
|
if task_type == "chief" or task_type == "evaluator":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# If chief not in the cluster_spec, use the first worker as chief. This is
|
# If chief not in the cluster_spec, use the first worker as chief. This is
|
||||||
@ -101,6 +117,40 @@ def is_chief(cluster_spec, task_type, task_id):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def collective_leader(cluster_spec, task_type, task_id):
|
||||||
|
"""Return the job name for the leader of collective ops.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
|
||||||
|
cluster configurations.
|
||||||
|
task_type: the task type in the cluster.
|
||||||
|
task_id: the task id in the cluster.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a string indicating the leader job name, or an empty string if no leader
|
||||||
|
job needs to be set.
|
||||||
|
"""
|
||||||
|
cluster_spec = normalize_cluster_spec(cluster_spec)
|
||||||
|
|
||||||
|
# No need to set collective leader for local.
|
||||||
|
if not cluster_spec.as_dict():
|
||||||
|
return ""
|
||||||
|
|
||||||
|
_validate_cluster_spec(cluster_spec, task_type, task_id)
|
||||||
|
|
||||||
|
# Only one evaluator, so no need to set collective leader.
|
||||||
|
if task_type == "evaluator":
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Use chief if chief is in the cluster.
|
||||||
|
if "chief" in cluster_spec.jobs:
|
||||||
|
return "/job:chief/replica:0/task:0"
|
||||||
|
|
||||||
|
# Use worker 0 if no chief job.
|
||||||
|
assert "worker" in cluster_spec.jobs
|
||||||
|
return "/job:worker/replica:0/task:0"
|
||||||
|
|
||||||
|
|
||||||
def worker_count(cluster_spec, task_type):
|
def worker_count(cluster_spec, task_type):
|
||||||
"""Returns the number of workers in the cluster."""
|
"""Returns the number of workers in the cluster."""
|
||||||
_validate_cluster_spec(cluster_spec, task_type, task_id=0)
|
_validate_cluster_spec(cluster_spec, task_type, task_id=0)
|
||||||
|
@ -102,6 +102,14 @@ class IsChiefTest(test.TestCase):
|
|||||||
ValueError, "The `task_id` 2 exceeds the maximum id of worker."):
|
ValueError, "The `task_id` 2 exceeds the maximum id of worker."):
|
||||||
multi_worker_util.is_chief(cluster_spec, "worker", 2)
|
multi_worker_util.is_chief(cluster_spec, "worker", 2)
|
||||||
|
|
||||||
|
def testEvaluatorIsChief(self):
|
||||||
|
cluster_spec = {
|
||||||
|
"chief": ["127.0.0.1:1234"],
|
||||||
|
"worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
|
||||||
|
"evaluator": ["127.0.0.1:2019"]
|
||||||
|
}
|
||||||
|
self.assertTrue(multi_worker_util.is_chief(cluster_spec, "evaluator", 0))
|
||||||
|
|
||||||
|
|
||||||
class NumWorkersTest(test.TestCase):
|
class NumWorkersTest(test.TestCase):
|
||||||
|
|
||||||
@ -192,5 +200,42 @@ class IdInClusterTest(test.TestCase):
|
|||||||
multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)
|
multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)
|
||||||
|
|
||||||
|
|
||||||
|
class CollectiveLeaderTest(test.TestCase):
|
||||||
|
|
||||||
|
def testChiefAsLeader(self):
|
||||||
|
cluster_spec = {
|
||||||
|
"chief": ["127.0.0.1:1234"],
|
||||||
|
"worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
|
||||||
|
"ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
multi_worker_util.collective_leader(cluster_spec, "worker", 0),
|
||||||
|
"/job:chief/replica:0/task:0")
|
||||||
|
|
||||||
|
def testWorkerAsLeader(self):
|
||||||
|
cluster_spec = {
|
||||||
|
"worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
|
||||||
|
"ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
multi_worker_util.collective_leader(cluster_spec, "worker", 1),
|
||||||
|
"/job:worker/replica:0/task:0")
|
||||||
|
|
||||||
|
def testLeaderForEvaluator(self):
|
||||||
|
cluster_spec = {
|
||||||
|
"chief": ["127.0.0.1:1234"],
|
||||||
|
"worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
|
||||||
|
"ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
|
||||||
|
"evaluator": ["127.0.0.1:2019"]
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
multi_worker_util.collective_leader(cluster_spec, "evaluator", 0), "")
|
||||||
|
|
||||||
|
def testLocalLeader(self):
|
||||||
|
cluster_spec = {}
|
||||||
|
self.assertEqual(
|
||||||
|
multi_worker_util.collective_leader(cluster_spec, None, 0), "")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -502,11 +502,13 @@ class ParameterServerStrategyExtended(
|
|||||||
assert self._task_id is not None
|
assert self._task_id is not None
|
||||||
|
|
||||||
# The device filters prevent communication between workers.
|
# The device filters prevent communication between workers.
|
||||||
if self._task_type not in ["chief", "worker"]:
|
|
||||||
return updated_config
|
|
||||||
del updated_config.device_filters[:]
|
del updated_config.device_filters[:]
|
||||||
updated_config.device_filters.extend(
|
if self._task_type in ["chief", "worker"]:
|
||||||
["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
|
updated_config.device_filters.extend(
|
||||||
|
["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
|
||||||
|
elif self._task_type == "evaluator":
|
||||||
|
updated_config.device_filters.append(
|
||||||
|
"/job:%s/task:%d" % (self._task_type, self._task_id))
|
||||||
return updated_config
|
return updated_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
x
Reference in New Issue
Block a user