Make task_type and task_id standard properties in tf.distribute cluster resolvers.
PiperOrigin-RevId: 317736970 Change-Id: Ia9c76462afc4c2fcc42a149960b50b2cbcafd482
This commit is contained in:
parent
e6ada6a6b4
commit
61c2e69663
|
@ -71,11 +71,22 @@ class ClusterResolver(object):
|
|||
workers. This will eventually allow us to automatically recover from
|
||||
underlying machine failures and scale TensorFlow worker clusters up and down.
|
||||
|
||||
Note to Implementors: In addition to these abstract methods, you must also
|
||||
implement the task_type, task_id, and rpc_layer attributes. You may choose
|
||||
to implement them either as properties with getters or setters or directly
|
||||
set the attributes. The task_type and task_id attributes are required by
|
||||
`tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||
Note to Implementors of `tf.distribute.cluster_resolver.ClusterResolver`
|
||||
subclass: In addition to these abstract methods, when task_type, task_id, and
|
||||
rpc_layer attributes are applicable, you should also implement them either as
|
||||
properties with getters or setters, or directly set the attributes
|
||||
`self._task_type`, `self._task_id`, or `self._rpc_layer` so the base class'
|
||||
getters and setters are used. See
|
||||
`tf.distribute.cluster_resolver.SimpleClusterResolver.__init__` for an
|
||||
example.
|
||||
|
||||
In general, multi-client tf.distribute strategies such as
|
||||
`tf.distribute.experimental.MultiWorkerMirroredStrategy` require task_type and
|
||||
task_id properties to be available in the `ClusterResolver` they are using. On
|
||||
the other hand, these concepts are not applicable in single-client strategies,
|
||||
such as `tf.distribute.experimental.TPUStrategy`, because the program is only
|
||||
expected to be run on one task, so there should not be a need to have code
|
||||
branches according to task type and task id.
|
||||
|
||||
- task_type is the name of the server's current named job (e.g. 'worker',
|
||||
'ps' in a distributed parameterized training job).
|
||||
|
@ -177,6 +188,106 @@ class ClusterResolver(object):
|
|||
"""
|
||||
return ''
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
"""Returns the task type this `ClusterResolver` indicates.
|
||||
|
||||
In TensorFlow distributed environment, each job may have an applicable
|
||||
task type. Valid task types in TensorFlow include
|
||||
'chief': a worker that is designated with more responsibility,
|
||||
'worker': a regular worker for training/evaluation,
|
||||
'ps': a parameter server, or
|
||||
'evaluator': an evaluator that evaluates the checkpoints for metrics.
|
||||
|
||||
See [Multi-worker configuration](
|
||||
https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#multi-worker_configuration)
|
||||
for more information about 'chief' and 'worker' task type, which are most
|
||||
commonly used.
|
||||
|
||||
Having access to such information is useful when user needs to run specific
|
||||
code according to task types. For example,
|
||||
|
||||
```python
|
||||
cluster_spec = tf.train.ClusterSpec({
|
||||
"ps": ["localhost:2222", "localhost:2223"],
|
||||
"worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
|
||||
})
|
||||
|
||||
# SimpleClusterResolver is used here for illustration; other cluster
|
||||
# resolvers may be used for other source of task type/id.
|
||||
simple_resolver = SimpleClusterResolver(cluster_spec, task_type="worker",
|
||||
task_id=1)
|
||||
|
||||
...
|
||||
|
||||
if cluster_resolver.task_type == 'worker':
|
||||
# Perform something that's only applicable on workers. This block
|
||||
# will run on this particular instance since we've specified this task to
|
||||
# be a worker in above cluster resolver.
|
||||
elif cluster_resolver.task_type == 'ps':
|
||||
# Perform something that's only applicable on parameter servers. This
|
||||
# block will not run on this particular instance.
|
||||
```
|
||||
|
||||
Returns `None` if such information is not available or is not applicable
|
||||
in the current distributed environment, such as training with
|
||||
`tf.distribute.experimental.TPUStrategy`.
|
||||
|
||||
For more information, please see
|
||||
`tf.distribute.cluster_resolver.ClusterResolver`'s class doc.
|
||||
"""
|
||||
return getattr(self, '_task_type', None)
|
||||
|
||||
@property
|
||||
def task_id(self):
|
||||
"""Returns the task id this `ClusterResolver` indicates.
|
||||
|
||||
In TensorFlow distributed environment, each job may have an applicable
|
||||
task id, which is the index of the instance within its task type. This is
|
||||
useful when user needs to run specific code according to task index. For
|
||||
example,
|
||||
|
||||
```python
|
||||
cluster_spec = tf.train.ClusterSpec({
|
||||
"ps": ["localhost:2222", "localhost:2223"],
|
||||
"worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
|
||||
})
|
||||
|
||||
# SimpleClusterResolver is used here for illustration; other cluster
|
||||
# resolvers may be used for other source of task type/id.
|
||||
simple_resolver = SimpleClusterResolver(cluster_spec, task_type="worker",
|
||||
task_id=0)
|
||||
|
||||
...
|
||||
|
||||
if cluster_resolver.task_type == 'worker' and cluster_resolver.task_id == 0:
|
||||
# Perform something that's only applicable on 'worker' type, id 0. This
|
||||
# block will run on this particular instance since we've specified this
|
||||
# task to be a 'worker', id 0 in above cluster resolver.
|
||||
else:
|
||||
# Perform something that's only applicable on other ids. This block will
|
||||
# not run on this particular instance.
|
||||
```
|
||||
|
||||
Returns `None` if such information is not available or is not applicable
|
||||
in the current distributed environment, such as training with
|
||||
`tf.distribute.cluster_resolver.TPUClusterResolver`.
|
||||
|
||||
For more information, please see
|
||||
`tf.distribute.cluster_resolver.ClusterResolver`'s class docstring.
|
||||
"""
|
||||
return getattr(self, '_task_id', None)
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
"""Setter of `task_type` property. See `task_type` property doc."""
|
||||
self._task_type = task_type
|
||||
|
||||
@task_id.setter
|
||||
def task_id(self, task_id):
|
||||
"""Setter of `task_id` property. See `task_type` property doc."""
|
||||
self._task_id = task_id
|
||||
|
||||
|
||||
@tf_export('distribute.cluster_resolver.SimpleClusterResolver')
|
||||
class SimpleClusterResolver(ClusterResolver):
|
||||
|
|
|
@ -310,5 +310,37 @@ class GCEClusterResolverTest(test.TestCase):
|
|||
"""
|
||||
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
||||
|
||||
def testSettingTaskTypeRaiseError(self):
|
||||
name_to_ip = [
|
||||
{
|
||||
'name': 'instance1',
|
||||
'ip': '10.1.2.3'
|
||||
},
|
||||
{
|
||||
'name': 'instance2',
|
||||
'ip': '10.2.3.4'
|
||||
},
|
||||
{
|
||||
'name': 'instance3',
|
||||
'ip': '10.3.4.5'
|
||||
},
|
||||
]
|
||||
|
||||
gce_cluster_resolver = GCEClusterResolver(
|
||||
project='test-project',
|
||||
zone='us-east1-d',
|
||||
instance_group='test-instance-group',
|
||||
task_type='testworker',
|
||||
port=8470,
|
||||
credentials=None,
|
||||
service=self.gen_standard_mock_service_client(name_to_ip))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, 'You cannot reset the task_type '
|
||||
'of the GCEClusterResolver after it has '
|
||||
'been created.'):
|
||||
gce_cluster_resolver.task_type = 'foobar'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
|
@ -6,6 +6,14 @@ tf_class {
|
|||
name: "environment"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
|
|
|
@ -7,6 +7,14 @@ tf_class {
|
|||
name: "environment"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\'], "
|
||||
|
|
|
@ -7,6 +7,14 @@ tf_class {
|
|||
name: "environment"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'jobs\', \'port_base\', \'gpus_per_node\', \'gpus_per_task\', \'tasks_per_node\', \'auto_set_gpu\', \'rpc_layer\'], varargs=None, keywords=None, defaults=[\'None\', \'8888\', \'None\', \'None\', \'None\', \'True\', \'grpc\'], "
|
||||
|
|
|
@ -7,6 +7,14 @@ tf_class {
|
|||
name: "environment"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'tpu\', \'zone\', \'project\', \'job_name\', \'coordinator_name\', \'coordinator_address\', \'credentials\', \'service\', \'discovery_url\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'worker\', \'None\', \'None\', \'default\', \'None\', \'None\'], "
|
||||
|
|
|
@ -6,6 +6,14 @@ tf_class {
|
|||
name: "environment"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
|
|
|
@ -7,6 +7,14 @@ tf_class {
|
|||
name: "environment"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\'], "
|
||||
|
|
|
@ -7,6 +7,14 @@ tf_class {
|
|||
name: "environment"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'jobs\', \'port_base\', \'gpus_per_node\', \'gpus_per_task\', \'tasks_per_node\', \'auto_set_gpu\', \'rpc_layer\'], varargs=None, keywords=None, defaults=[\'None\', \'8888\', \'None\', \'None\', \'None\', \'True\', \'grpc\'], "
|
||||
|
|
|
@ -7,6 +7,14 @@ tf_class {
|
|||
name: "environment"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'tpu\', \'zone\', \'project\', \'job_name\', \'coordinator_name\', \'coordinator_address\', \'credentials\', \'service\', \'discovery_url\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'worker\', \'None\', \'None\', \'default\', \'None\', \'None\'], "
|
||||
|
|
Loading…
Reference in New Issue