Pass TPU worker host:port pairs to trace() from capture_tpu_profile
.
PiperOrigin-RevId: 340943636 Change-Id: I8347dcdf63d38ee0368ad521082441555e2b3222
This commit is contained in:
parent
6a115c01e2
commit
7540f9ff5a
tensorflow/python
@ -104,13 +104,15 @@ def trace(service_addr,
|
||||
|
||||
```python
|
||||
# Send gRPC request to a TPU pod to collect a trace of your model on
|
||||
# multipleTPUs. A profiler service has been started in all the TPU workers
|
||||
# at theport 8466.
|
||||
# multiple TPUs. A profiler service has been started in all the TPU workers
|
||||
# at the port 8466.
|
||||
# E.g. your TPU IP addresses are 10.0.0.2, 10.0.0.3, 10.0.0.4, and you want
|
||||
# to profile for 2 seconds.
|
||||
tf.profiler.experimental.client.trace('grpc://10.0.0.2:8466',
|
||||
'gs://your_tb_dir',
|
||||
2000, '10.0.0.2,10.0.0.3,10.0.0.4')
|
||||
tf.profiler.experimental.client.trace(
|
||||
'grpc://10.0.0.2:8466',
|
||||
'gs://your_tb_dir',
|
||||
2000,
|
||||
'10.0.0.2:8466,10.0.0.3:8466,10.0.0.4:8466')
|
||||
```
|
||||
|
||||
Launch TensorBoard and point it to the same logdir you provided to this API.
|
||||
|
@ -55,8 +55,8 @@ flags.DEFINE_string(
|
||||
'localhost:8466, you must specify either this flag or --tpu.')
|
||||
flags.DEFINE_string(
|
||||
'workers_list', None, 'The list of worker TPUs that we are about to profile'
|
||||
' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or '
|
||||
'--service_addr to profile a subset of tpu nodes. You can also use only'
|
||||
' e.g. 10.0.1.2:8466, 10.0.1.3:8466. You can specify this flag with --tpu '
|
||||
'or --service_addr to profile a subset of tpu nodes. You can also use only'
|
||||
'--tpu and leave this flag unspecified to profile all the tpus.')
|
||||
flags.DEFINE_string(
|
||||
'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, '
|
||||
@ -83,17 +83,17 @@ flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.')
|
||||
|
||||
|
||||
def get_workers_list(cluster_resolver):
|
||||
"""Returns a comma separated list of TPU worker IP addresses.
|
||||
"""Returns a comma separated list of TPU worker host:port pairs.
|
||||
|
||||
Gets cluster_spec from cluster_resolver. Use the worker's task indices to
|
||||
obtain and return a list of ip addresses.
|
||||
obtain and return a list of host:port pairs.
|
||||
|
||||
Args:
|
||||
cluster_resolver: TensorFlow TPUClusterResolver instance.
|
||||
|
||||
Returns:
|
||||
A string of comma separated list of IP addresses. For example:
|
||||
'10.2.0.1,10.2.0.2,10.2.0.3,10.2.0.4'
|
||||
A string of comma separated list of host:port pairs. For example:
|
||||
'10.2.0.1:8466,10.2.0.2:8466,10.2.0.3:8466,10.2.0.4:8466'
|
||||
|
||||
Raises:
|
||||
UnavailableError: cluster_resolver doesn't contain a valid cluster_spec.
|
||||
@ -106,7 +106,7 @@ def get_workers_list(cluster_resolver):
|
||||
'Cluster spec not found, your client must run in GCE environment.')
|
||||
task_indices = cluster_spec.task_indices(worker_job_name)
|
||||
workers_list = [
|
||||
cluster_spec.task_address(worker_job_name, i).split(':')[0]
|
||||
cluster_spec.task_address(worker_job_name, i).replace(':8470', ':8466')
|
||||
for i in task_indices
|
||||
]
|
||||
return ','.join(workers_list)
|
||||
|
Loading…
Reference in New Issue
Block a user