Pass TPU worker host:port pairs to trace() from capture_tpu_profile.

PiperOrigin-RevId: 340943636
Change-Id: I8347dcdf63d38ee0368ad521082441555e2b3222
This commit is contained in:
Will Cromar 2020-11-05 15:50:49 -08:00 committed by TensorFlower Gardener
parent 6a115c01e2
commit 7540f9ff5a
2 changed files with 14 additions and 12 deletions
tensorflow/python

View File

@ -104,13 +104,15 @@ def trace(service_addr,
```python
# Send gRPC request to a TPU pod to collect a trace of your model on
# multipleTPUs. A profiler service has been started in all the TPU workers
# at theport 8466.
# multiple TPUs. A profiler service has been started in all the TPU workers
# at the port 8466.
# E.g. your TPU IP addresses are 10.0.0.2, 10.0.0.3, 10.0.0.4, and you want
# to profile for 2 seconds.
tf.profiler.experimental.client.trace('grpc://10.0.0.2:8466',
'gs://your_tb_dir',
2000, '10.0.0.2,10.0.0.3,10.0.0.4')
tf.profiler.experimental.client.trace(
'grpc://10.0.0.2:8466',
'gs://your_tb_dir',
2000,
'10.0.0.2:8466,10.0.0.3:8466,10.0.0.4:8466')
```
Launch TensorBoard and point it to the same logdir you provided to this API.

View File

@ -55,8 +55,8 @@ flags.DEFINE_string(
'localhost:8466, you must specify either this flag or --tpu.')
flags.DEFINE_string(
'workers_list', None, 'The list of worker TPUs that we are about to profile'
' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or '
'--service_addr to profile a subset of tpu nodes. You can also use only'
' e.g. 10.0.1.2:8466, 10.0.1.3:8466. You can specify this flag with --tpu '
'or --service_addr to profile a subset of tpu nodes. You can also use only'
'--tpu and leave this flag unspecified to profile all the tpus.')
flags.DEFINE_string(
'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, '
@ -83,17 +83,17 @@ flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.')
def get_workers_list(cluster_resolver):
"""Returns a comma separated list of TPU worker IP addresses.
"""Returns a comma separated list of TPU worker host:port pairs.
Gets cluster_spec from cluster_resolver. Use the worker's task indices to
obtain and return a list of ip addresses.
obtain and return a list of host:port pairs.
Args:
cluster_resolver: TensorFlow TPUClusterResolver instance.
Returns:
A string of comma separated list of IP addresses. For example:
'10.2.0.1,10.2.0.2,10.2.0.3,10.2.0.4'
A string of comma separated list of host:port pairs. For example:
'10.2.0.1:8466,10.2.0.2:8466,10.2.0.3:8466,10.2.0.4:8466'
Raises:
UnavailableError: cluster_resolver doesn't contain a valid cluster_spec.
@ -106,7 +106,7 @@ def get_workers_list(cluster_resolver):
'Cluster spec not found, your client must run in GCE environment.')
task_indices = cluster_spec.task_indices(worker_job_name)
workers_list = [
cluster_spec.task_address(worker_job_name, i).split(':')[0]
cluster_spec.task_address(worker_job_name, i).replace(':8470', ':8466')
for i in task_indices
]
return ','.join(workers_list)