Make Distribution Strategy friendlier to the evaluator job.

PiperOrigin-RevId: 238719371
Authored by Yuefeng Zhou on 2019-03-15 15:13:01 -07:00; committed by TensorFlower Gardener
parent 32d8905c37
commit c0886e70d1
7 changed files with 128 additions and 30 deletions
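
For orientation, this is the shape of cluster the change targets: a training cluster plus a single, independent evaluator task. The sketch below is illustrative only; the addresses are placeholders and the layout mirrors the cluster_spec dicts used in the tests further down.

# Illustrative cluster layout; addresses are placeholders.
cluster_spec = {
    "chief": ["10.0.0.1:2222"],
    "worker": ["10.0.0.2:2222", "10.0.0.3:2222"],
    "ps": ["10.0.0.4:2222"],
    # The evaluator is a single task that runs outside the training cluster.
    "evaluator": ["10.0.0.5:2222"],
}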


@@ -251,7 +251,7 @@ class DistributeCoordinatorIntegrationTest(
   def _get_strategy_object(self, strategy_cls):
     if strategy_cls == mirrored_strategy.CoreMirroredStrategy:
-      return strategy_cls(mirrored_strategy.all_local_devices())
+      return strategy_cls()
     else:
       return strategy_cls(num_gpus_per_worker=context.num_gpus())
@@ -269,6 +269,7 @@ class DistributeCoordinatorIntegrationTest(
           mirrored_strategy.MirroredStrategy,
           mirrored_strategy.CoreMirroredStrategy,
           parameter_server_strategy.ParameterServerStrategy,
+          collective_all_reduce_strategy.CollectiveAllReduceStrategy,
       ],
       required_gpus=[0, 1]))
   def test_complete_flow_standalone_client(self, train_distribute_cls,
@@ -371,9 +372,11 @@ class DistributeCoordinatorIntegrationTest(
           parameter_server_strategy.ParameterServerStrategy,
       ],
       eval_distribute_cls=[
-          None, mirrored_strategy.MirroredStrategy,
+          None,
+          mirrored_strategy.MirroredStrategy,
           mirrored_strategy.CoreMirroredStrategy,
           parameter_server_strategy.ParameterServerStrategy,
+          collective_all_reduce_strategy.CollectiveAllReduceStrategy,
       ],
       required_gpus=[0, 1]))
   def test_complete_flow_independent_worker_between_graph(


@@ -72,7 +72,8 @@ def _create_cluster(num_workers,
                     has_eval=False,
                     protocol='grpc',
                     worker_config=None,
-                    ps_config=None):
+                    ps_config=None,
+                    eval_config=None):
   """Creates and starts local servers and returns the cluster_spec dict."""
   if _portpicker_import_error:
     raise _portpicker_import_error  # pylint: disable=raising-bad-type
@@ -124,7 +125,7 @@ def _create_cluster(num_workers,
         job_name='evaluator',
         protocol=protocol,
         task_index=0,
-        config=worker_config,
+        config=eval_config,
         start=True)

   return cluster_dict
@@ -153,6 +154,9 @@ def create_in_process_cluster(num_workers,
   ps_config = config_pb2.ConfigProto()
   ps_config.device_count['GPU'] = 0

+  eval_config = config_pb2.ConfigProto()
+  eval_config.experimental.collective_group_leader = ''
+
   # Create in-process servers. Once an in-process tensorflow server is created,
   # there is no way to terminate it. So we create one cluster per test process.
   # We could've started the server in another process, we could then kill that
@@ -169,6 +173,7 @@ def create_in_process_cluster(num_workers,
       has_eval=has_eval,
       worker_config=worker_config,
       ps_config=ps_config,
+      eval_config=eval_config,
       protocol='grpc')
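
A minimal sketch of the evaluator-specific session config this test harness builds, assuming a TensorFlow build where `ConfigProto.experimental.collective_group_leader` exists (the diff above relies on it): the group leader is cleared so the lone evaluator never waits on a leader in the training cluster.

from tensorflow.core.protobuf import config_pb2

# Sketch of the evaluator session config: the collective group leader is left
# empty because the single evaluator task forms its own "cluster".
eval_config = config_pb2.ConfigProto()
eval_config.experimental.collective_group_leader = ''
assert eval_config.experimental.collective_group_leader == ''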


@@ -165,16 +165,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     task_id = cluster_resolver.task_id
     if task_type is None or task_id is None:
       raise ValueError("When `cluster_spec` is given, you must also specify "
-                       "`task_type` and `task_id` in the `cluster_resolver`.")
-    if task_type not in ("chief", "worker"):
-      raise ValueError(
-          "Unrecognized task_type: %r, valid task types are: \"chief\", "
-          "\"worker\"." % task_type)
+                       "`task_type` and `task_id`.")

     self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
     if not self._num_workers:
-      raise ValueError("No `worker` or `chief` tasks can be found in "
-                       "`cluster_spec`.")
+      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
+                       "in `cluster_spec`.")

     self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                 task_id)
@@ -410,15 +406,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     # Collective group leader is needed for collective ops to coordinate
     # workers.
-    if "chief" in self._cluster_spec.jobs:
-      updated_config.experimental.collective_group_leader = (
-          "/job:chief/replica:0/task:0")
-    else:
-      if "worker" not in self._cluster_spec.jobs:
-        raise ValueError(
-            "You must have `chief` or `worker` jobs in the `cluster_spec`.")
-      updated_config.experimental.collective_group_leader = (
-          "/job:worker/replica:0/task:0")
+    updated_config.experimental.collective_group_leader = (
+        multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
+                                            self._task_id))

     # The device filters prevent communication between workers.
     del updated_config.device_filters[:]


@@ -367,6 +367,9 @@ def _split_cluster_for_evaluator(cluster_spec, task_type):
   # distribution strategies and as a result ops in the evaluator task may have
   # unspecified devices. Those ops may end up on other tasks if we don't split
   # the cluster.
+  # Note: if you bypass distribute coordinator and bring the cluster yourself,
+  # you can equivalently set device filters to split clusters. This is already
+  # done by distribution strategy's `update_config_proto` method.
   new_cluster_spec = multi_worker_util.normalize_cluster_spec(
       cluster_spec).as_dict()
   if task_type == _TaskType.EVALUATOR:
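
The note added above says that users who bring their own cluster can isolate the evaluator by setting device filters themselves. A hedged sketch of that alternative (the job/task values are illustrative):

from tensorflow.core.protobuf import config_pb2

# Restrict an evaluator session to its own devices so unplaced ops cannot
# land on training-cluster tasks. Job/task values are illustrative.
session_config = config_pb2.ConfigProto()
del session_config.device_filters[:]
session_config.device_filters.append("/job:evaluator/task:0")

Distribution strategies' `update_config_proto` does the equivalent for their own task types, as the ParameterServerStrategy hunk at the end of this diff shows.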


@@ -49,11 +49,13 @@ def normalize_cluster_spec(cluster_spec):
 def _validate_cluster_spec(cluster_spec, task_type, task_id):
   """Validates `cluster_spec`.

-  It checks
-  1) whether there is such a task type as `task_type` in the
-  `cluster_spec`.
-  2) whether there is at most one "chief" job.
-  3) whether the `task_id` is smaller than the number of `task_type`.
+  It checks:
+  1) task type is one of "chief", "worker" or "evaluator".
+  2) whether there is such a task type as `task_type` in the `cluster_spec`.
+  3) whether there is at most one "chief" job.
+  4) whether there is at most one "evaluator" job.
+  5) whether the `task_id` is smaller than the number of tasks for that
+     particular `task_type`.

   Args:
     cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
@@ -63,10 +65,20 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
     ValueError: if `cluster_spec` fails any check.
   """
   cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
+
+  if task_type not in ("chief", "worker", "evaluator", "ps"):
+    raise ValueError(
+        "Unrecognized task_type: %r, valid task types are: \"chief\", "
+        "\"worker\", \"evaluator\" and \"ps\"." % task_type)
+
   if task_type and task_type not in cluster_spec:
     raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
   if len(cluster_spec.get("chief", [])) > 1:
     raise ValueError("There must be at most one 'chief' job.")
+
+  if len(cluster_spec.get("evaluator", [])) > 1:
+    raise ValueError("There must be at most one 'evaluator' job.")
+
   if task_id >= len(cluster_spec[task_type]):
     raise ValueError(
         "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
@@ -75,6 +87,10 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
 def is_chief(cluster_spec, task_type, task_id):
   """Returns whether the given task is chief in the cluster.

+  Since there is at most one evaluator and the evaluator itself should be
+  independent of the training cluster, the evaluator job is also a chief job on
+  its own.
+
   Args:
     cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
       cluster configurations.
@@ -91,7 +107,7 @@ def is_chief(cluster_spec, task_type, task_id):
   _validate_cluster_spec(cluster_spec, task_type, task_id)
   cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

-  if task_type == "chief":
+  if task_type == "chief" or task_type == "evaluator":
     return True

   # If chief not in the cluster_spec, use the first worker as chief. This is
@@ -101,6 +117,40 @@ def is_chief(cluster_spec, task_type, task_id):
   return False


+def collective_leader(cluster_spec, task_type, task_id):
+  """Return the job name for the leader of collective ops.
+
+  Args:
+    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
+      cluster configurations.
+    task_type: the task type in the cluster.
+    task_id: the task id in the cluster.
+
+  Returns:
+    a string indicating the leader job name or empty string if no need to set
+    leader job.
+  """
+  cluster_spec = normalize_cluster_spec(cluster_spec)
+
+  # No need to set collective leader for local.
+  if not cluster_spec.as_dict():
+    return ""
+
+  _validate_cluster_spec(cluster_spec, task_type, task_id)
+
+  # Only one evaluator, so no need to set collective leader.
+  if task_type == "evaluator":
+    return ""
+
+  # Use chief if chief is in the cluster.
+  if "chief" in cluster_spec.jobs:
+    return "/job:chief/replica:0/task:0"
+
+  # Use worker 0 if no chief job.
+  assert "worker" in cluster_spec.jobs
+  return "/job:worker/replica:0/task:0"
+
+
 def worker_count(cluster_spec, task_type):
   """Returns the number of workers in the cluster."""
   _validate_cluster_spec(cluster_spec, task_type, task_id=0)
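
A usage sketch of the new helper, assuming the internal module path `tensorflow.python.distribute.multi_worker_util` in a build that contains this change; the addresses are placeholders and the expected outputs restate the test cases that follow.

from tensorflow.python.distribute import multi_worker_util

# With no "chief" job, worker 0 leads the collective group.
print(multi_worker_util.collective_leader(
    {"worker": ["10.0.0.2:2222", "10.0.0.3:2222"]}, "worker", 1))
# -> /job:worker/replica:0/task:0

# The evaluator never needs a leader, so the helper returns "".
print(multi_worker_util.collective_leader(
    {"worker": ["10.0.0.2:2222"], "evaluator": ["10.0.0.5:2222"]},
    "evaluator", 0))
# -> (empty string)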


@@ -102,6 +102,14 @@ class IsChiefTest(test.TestCase):
         ValueError, "The `task_id` 2 exceeds the maximum id of worker."):
       multi_worker_util.is_chief(cluster_spec, "worker", 2)

+  def testEvaluatorIsChief(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "evaluator": ["127.0.0.1:2019"]
+    }
+    self.assertTrue(multi_worker_util.is_chief(cluster_spec, "evaluator", 0))
+

 class NumWorkersTest(test.TestCase):
@@ -192,5 +200,42 @@ class IdInClusterTest(test.TestCase):
       multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)


+class CollectiveLeaderTest(test.TestCase):
+
+  def testChiefAsLeader(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.collective_leader(cluster_spec, "worker", 0),
+        "/job:chief/replica:0/task:0")
+
+  def testWorkerAsLeader(self):
+    cluster_spec = {
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.collective_leader(cluster_spec, "worker", 1),
+        "/job:worker/replica:0/task:0")
+
+  def testLeaderForEvaluator(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
+        "evaluator": ["127.0.0.1:2019"]
+    }
+    self.assertEqual(
+        multi_worker_util.collective_leader(cluster_spec, "evaluator", 0), "")
+
+  def testLocalLeader(self):
+    cluster_spec = {}
+    self.assertEqual(
+        multi_worker_util.collective_leader(cluster_spec, None, 0), "")
+
+
 if __name__ == "__main__":
   test.main()


@@ -502,11 +502,13 @@ class ParameterServerStrategyExtended(
     assert self._task_id is not None

     # The device filters prevent communication between workers.
-    if self._task_type not in ["chief", "worker"]:
-      return updated_config
     del updated_config.device_filters[:]
-    updated_config.device_filters.extend(
-        ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
+    if self._task_type in ["chief", "worker"]:
+      updated_config.device_filters.extend(
+          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
+    elif self._task_type == "evaluator":
+      updated_config.device_filters.append(
+          "/job:%s/task:%d" % (self._task_type, self._task_id))
     return updated_config

   @property
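
To make the effect of the branch above concrete, here is a standalone re-statement of the device-filter logic. It is a sketch only, not the strategy API; the helper name `device_filters_for` is made up for illustration, and it simply mirrors the code path in `update_config_proto`.

from tensorflow.core.protobuf import config_pb2

def device_filters_for(task_type, task_id):
  # Mirrors the branch above: chief/worker tasks talk to themselves and the ps
  # jobs; the evaluator only talks to itself; other task types get no filters.
  config = config_pb2.ConfigProto()
  del config.device_filters[:]
  if task_type in ["chief", "worker"]:
    config.device_filters.extend(
        ["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
  elif task_type == "evaluator":
    config.device_filters.append("/job:%s/task:%d" % (task_type, task_id))
  return list(config.device_filters)

print(device_filters_for("worker", 0))     # ['/job:worker/task:0', '/job:ps']
print(device_filters_for("evaluator", 0))  # ['/job:evaluator/task:0']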