Make task_type and task_id standard properties in tf.distribute cluster resolvers.

PiperOrigin-RevId: 317736970
Change-Id: Ia9c76462afc4c2fcc42a149960b50b2cbcafd482
This commit is contained in:
Rick Chao 2020-06-22 14:33:38 -07:00 committed by TensorFlower Gardener
parent e6ada6a6b4
commit 61c2e69663
10 changed files with 212 additions and 5 deletions

View File

@ -71,11 +71,22 @@ class ClusterResolver(object):
workers. This will eventually allow us to automatically recover from workers. This will eventually allow us to automatically recover from
underlying machine failures and scale TensorFlow worker clusters up and down. underlying machine failures and scale TensorFlow worker clusters up and down.
Note to Implementors: In addition to these abstract methods, you must also Note to Implementors of `tf.distribute.cluster_resolver.ClusterResolver`
implement the task_type, task_id, and rpc_layer attributes. You may choose subclass: In addition to these abstract methods, when task_type, task_id, and
to implement them either as properties with getters or setters or directly rpc_layer attributes are applicable, you should also implement them either as
set the attributes. The task_type and task_id attributes are required by properties with getters or setters, or directly set the attributes
`tf.distribute.experimental.MultiWorkerMirroredStrategy`. `self._task_type`, `self._task_id`, or `self._rpc_layer` so the base class'
getters and setters are used. See
`tf.distribute.cluster_resolver.SimpleClusterResolver.__init__` for an
example.
In general, multi-client tf.distribute strategies such as
`tf.distribute.experimental.MultiWorkerMirroredStrategy` require task_type and
task_id properties to be available in the `ClusterResolver` they are using. On
the other hand, these concepts are not applicable in single-client strategies,
such as `tf.distribute.experimental.TPUStrategy`, because the program is only
expected to be run on one task, so there should not be a need to have code
branches according to task type and task id.
- task_type is the name of the server's current named job (e.g. 'worker', - task_type is the name of the server's current named job (e.g. 'worker',
'ps' in a distributed parameterized training job). 'ps' in a distributed parameterized training job).
@ -177,6 +188,106 @@ class ClusterResolver(object):
""" """
return '' return ''
@property
def task_type(self):
  """Returns the task type this `ClusterResolver` indicates.

  In TensorFlow distributed environment, each job may have an applicable
  task type. Valid task types in TensorFlow include
  'chief': a worker that is designated with more responsibility,
  'worker': a regular worker for training/evaluation,
  'ps': a parameter server, or
  'evaluator': an evaluator that evaluates the checkpoints for metrics.

  See [Multi-worker configuration](
  https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#multi-worker_configuration)
  for more information about 'chief' and 'worker' task type, which are most
  commonly used.

  Having access to such information is useful when the user needs to run
  specific code according to task types. For example,

  ```python
  cluster_spec = tf.train.ClusterSpec({
      "ps": ["localhost:2222", "localhost:2223"],
      "worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
  })

  # SimpleClusterResolver is used here for illustration; other cluster
  # resolvers may be used for other sources of task type/id.
  cluster_resolver = SimpleClusterResolver(cluster_spec, task_type="worker",
                                           task_id=1)

  ...

  if cluster_resolver.task_type == 'worker':
    # Perform something that's only applicable on workers. This block
    # will run on this particular instance since we've specified this task to
    # be a worker in above cluster resolver.
  elif cluster_resolver.task_type == 'ps':
    # Perform something that's only applicable on parameter servers. This
    # block will not run on this particular instance.
  ```

  Returns `None` if such information is not available or is not applicable
  in the current distributed environment, such as training with
  `tf.distribute.experimental.TPUStrategy`.

  For more information, please see
  `tf.distribute.cluster_resolver.ClusterResolver`'s class docstring.
  """
  # A subclass that never assigns `_task_type` yields None here rather than
  # raising AttributeError.
  return getattr(self, '_task_type', None)


@property
def task_id(self):
  """Returns the task id this `ClusterResolver` indicates.

  In TensorFlow distributed environment, each job may have an applicable
  task id, which is the index of the instance within its task type. This is
  useful when the user needs to run specific code according to task index.
  For example,

  ```python
  cluster_spec = tf.train.ClusterSpec({
      "ps": ["localhost:2222", "localhost:2223"],
      "worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
  })

  # SimpleClusterResolver is used here for illustration; other cluster
  # resolvers may be used for other sources of task type/id.
  cluster_resolver = SimpleClusterResolver(cluster_spec, task_type="worker",
                                           task_id=0)

  ...

  if cluster_resolver.task_type == 'worker' and cluster_resolver.task_id == 0:
    # Perform something that's only applicable on 'worker' type, id 0. This
    # block will run on this particular instance since we've specified this
    # task to be a 'worker', id 0 in above cluster resolver.
  else:
    # Perform something that's only applicable on other ids. This block will
    # not run on this particular instance.
  ```

  Returns `None` if such information is not available or is not applicable
  in the current distributed environment, such as training with
  `tf.distribute.experimental.TPUStrategy`.

  For more information, please see
  `tf.distribute.cluster_resolver.ClusterResolver`'s class docstring.
  """
  # A subclass that never assigns `_task_id` yields None here rather than
  # raising AttributeError.
  return getattr(self, '_task_id', None)


@task_type.setter
def task_type(self, task_type):
  """Setter of `task_type` property. See `task_type` property doc."""
  self._task_type = task_type


@task_id.setter
def task_id(self, task_id):
  """Setter of `task_id` property. See `task_id` property doc."""
  self._task_id = task_id
@tf_export('distribute.cluster_resolver.SimpleClusterResolver') @tf_export('distribute.cluster_resolver.SimpleClusterResolver')
class SimpleClusterResolver(ClusterResolver): class SimpleClusterResolver(ClusterResolver):

View File

@ -310,5 +310,37 @@ class GCEClusterResolverTest(test.TestCase):
""" """
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
def testSettingTaskTypeRaiseError(self):
  """Checks that assigning `task_type` after construction raises RuntimeError."""
  name_to_ip = [
      {'name': 'instance1', 'ip': '10.1.2.3'},
      {'name': 'instance2', 'ip': '10.2.3.4'},
      {'name': 'instance3', 'ip': '10.3.4.5'},
  ]

  gce_cluster_resolver = GCEClusterResolver(
      project='test-project',
      zone='us-east1-d',
      instance_group='test-instance-group',
      task_type='testworker',
      port=8470,
      credentials=None,
      service=self.gen_standard_mock_service_client(name_to_ip))

  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which no longer exists on `unittest.TestCase` in Python 3.12+.
  with self.assertRaisesRegex(
      RuntimeError,
      'You cannot reset the task_type of the GCEClusterResolver after it has '
      'been created.'):
    gce_cluster_resolver.task_type = 'foobar'
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -6,6 +6,14 @@ tf_class {
name: "environment" name: "environment"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "task_id"
mtype: "<type \'property\'>"
}
member {
name: "task_type"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
} }

View File

@ -7,6 +7,14 @@ tf_class {
name: "environment" name: "environment"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "task_id"
mtype: "<type \'property\'>"
}
member {
name: "task_type"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\'], " argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\'], "

View File

@ -7,6 +7,14 @@ tf_class {
name: "environment" name: "environment"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "task_id"
mtype: "<type \'property\'>"
}
member {
name: "task_type"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'jobs\', \'port_base\', \'gpus_per_node\', \'gpus_per_task\', \'tasks_per_node\', \'auto_set_gpu\', \'rpc_layer\'], varargs=None, keywords=None, defaults=[\'None\', \'8888\', \'None\', \'None\', \'None\', \'True\', \'grpc\'], " argspec: "args=[\'self\', \'jobs\', \'port_base\', \'gpus_per_node\', \'gpus_per_task\', \'tasks_per_node\', \'auto_set_gpu\', \'rpc_layer\'], varargs=None, keywords=None, defaults=[\'None\', \'8888\', \'None\', \'None\', \'None\', \'True\', \'grpc\'], "

View File

@ -7,6 +7,14 @@ tf_class {
name: "environment" name: "environment"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "task_id"
mtype: "<type \'property\'>"
}
member {
name: "task_type"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'tpu\', \'zone\', \'project\', \'job_name\', \'coordinator_name\', \'coordinator_address\', \'credentials\', \'service\', \'discovery_url\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'worker\', \'None\', \'None\', \'default\', \'None\', \'None\'], " argspec: "args=[\'self\', \'tpu\', \'zone\', \'project\', \'job_name\', \'coordinator_name\', \'coordinator_address\', \'credentials\', \'service\', \'discovery_url\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'worker\', \'None\', \'None\', \'default\', \'None\', \'None\'], "

View File

@ -6,6 +6,14 @@ tf_class {
name: "environment" name: "environment"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "task_id"
mtype: "<type \'property\'>"
}
member {
name: "task_type"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
} }

View File

@ -7,6 +7,14 @@ tf_class {
name: "environment" name: "environment"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "task_id"
mtype: "<type \'property\'>"
}
member {
name: "task_type"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\'], " argspec: "args=[\'self\', \'job_to_label_mapping\', \'tf_server_port\', \'rpc_layer\', \'override_client\'], varargs=None, keywords=None, defaults=[\'None\', \'8470\', \'grpc\', \'None\'], "

View File

@ -7,6 +7,14 @@ tf_class {
name: "environment" name: "environment"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "task_id"
mtype: "<type \'property\'>"
}
member {
name: "task_type"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'jobs\', \'port_base\', \'gpus_per_node\', \'gpus_per_task\', \'tasks_per_node\', \'auto_set_gpu\', \'rpc_layer\'], varargs=None, keywords=None, defaults=[\'None\', \'8888\', \'None\', \'None\', \'None\', \'True\', \'grpc\'], " argspec: "args=[\'self\', \'jobs\', \'port_base\', \'gpus_per_node\', \'gpus_per_task\', \'tasks_per_node\', \'auto_set_gpu\', \'rpc_layer\'], varargs=None, keywords=None, defaults=[\'None\', \'8888\', \'None\', \'None\', \'None\', \'True\', \'grpc\'], "

View File

@ -7,6 +7,14 @@ tf_class {
name: "environment" name: "environment"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "task_id"
mtype: "<type \'property\'>"
}
member {
name: "task_type"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'tpu\', \'zone\', \'project\', \'job_name\', \'coordinator_name\', \'coordinator_address\', \'credentials\', \'service\', \'discovery_url\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'worker\', \'None\', \'None\', \'default\', \'None\', \'None\'], " argspec: "args=[\'self\', \'tpu\', \'zone\', \'project\', \'job_name\', \'coordinator_name\', \'coordinator_address\', \'credentials\', \'service\', \'discovery_url\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'worker\', \'None\', \'None\', \'default\', \'None\', \'None\'], "