switch capture_tpu_profile to new api of profiler_client. because some options is dropped, therefore two flags are deprecated.
also allow it specify host trace level. PiperOrigin-RevId: 312739183 Change-Id: I4e4712441877e697956d539055e333baf8a8d7bd
This commit is contained in:
parent
7315b275c0
commit
d0a5894b58
tensorflow/python/tpu/profiler
@ -38,7 +38,8 @@ py_library(
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:versions",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:profiler_client",
|
||||
"//tensorflow/python/profiler:profiler_client",
|
||||
"//tensorflow/python/profiler:profiler_v2",
|
||||
"@absl_py//absl:app",
|
||||
"@absl_py//absl/flags",
|
||||
],
|
||||
|
@ -25,7 +25,8 @@ from absl import flags
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
||||
from tensorflow.python.eager import profiler_client
|
||||
from tensorflow.python.profiler import profiler_client
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.platform import gfile
|
||||
@ -65,9 +66,10 @@ flags.DEFINE_integer('duration_ms', 0,
|
||||
flags.DEFINE_integer(
|
||||
'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
|
||||
'event is collected.')
|
||||
flags.DEFINE_boolean('include_dataset_ops', True,
|
||||
'Set to false to profile longer TPU '
|
||||
'device traces.')
|
||||
flags.DEFINE_boolean('include_dataset_ops', True, 'Deprecated.')
|
||||
flags.DEFINE_integer(
|
||||
'host_tracer_level', 2, 'Adjust host tracer level to control the verbosity '
|
||||
' of the TraceMe event being collected.')
|
||||
|
||||
# Monitoring parameters
|
||||
flags.DEFINE_integer(
|
||||
@ -77,8 +79,7 @@ flags.DEFINE_integer(
|
||||
flags.DEFINE_integer(
|
||||
'num_queries', 100,
|
||||
'This script will run monitoring for num_queries before it stops.')
|
||||
flags.DEFINE_boolean('display_timestamp', False,
|
||||
'Set to true to display timestamp in monitoring results.')
|
||||
flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.')
|
||||
|
||||
|
||||
def get_workers_list(cluster_resolver):
|
||||
@ -111,8 +112,7 @@ def get_workers_list(cluster_resolver):
|
||||
return ','.join(workers_list)
|
||||
|
||||
|
||||
def monitoring_helper(service_addr, duration_ms, monitoring_level,
|
||||
display_timestamp, num_queries):
|
||||
def monitoring_helper(service_addr, duration_ms, monitoring_level, num_queries):
|
||||
"""Helper function to print monitoring results.
|
||||
|
||||
Helper function to print monitoring results for num_queries times.
|
||||
@ -122,15 +122,13 @@ def monitoring_helper(service_addr, duration_ms, monitoring_level,
|
||||
duration_ms: Duration of one monitoring sample in milliseconds.
|
||||
monitoring_level: An integer between 1 and 2. Level 2 is more verbose than
|
||||
level 1 and shows more metrics.
|
||||
display_timestamp: Set to true to display timestamp in monitoring.
|
||||
num_queries: Number of monitoring samples to collect.
|
||||
"""
|
||||
if monitoring_level <= 0 or monitoring_level > 2:
|
||||
sys.exit('Please choose a monitoring level between 1 and 2.')
|
||||
|
||||
for query in range(0, num_queries):
|
||||
res = profiler_client.monitor(service_addr, duration_ms, monitoring_level,
|
||||
display_timestamp)
|
||||
res = profiler_client.monitor(service_addr, duration_ms, monitoring_level)
|
||||
print('Cloud TPU Monitoring Results (Sample ', query, '):\n\n', res)
|
||||
|
||||
|
||||
@ -144,8 +142,8 @@ def main(unused_argv=None):
|
||||
print('TensorFlow version %s detected' % tf_version)
|
||||
print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)
|
||||
|
||||
if LooseVersion(tf_version) < LooseVersion('1.14.0'):
|
||||
sys.exit('You must install tensorflow >= 1.14.0 to use this plugin.')
|
||||
if LooseVersion(tf_version) < LooseVersion('2.2.0'):
|
||||
sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.')
|
||||
|
||||
if not FLAGS.service_addr and not FLAGS.tpu:
|
||||
sys.exit('You must specify either --service_addr or --tpu.')
|
||||
@ -184,7 +182,7 @@ def main(unused_argv=None):
|
||||
FLAGS.duration_ms, ' ms and show metrics for ', FLAGS.num_queries,
|
||||
' time(s).')
|
||||
monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
|
||||
FLAGS.display_timestamp, FLAGS.num_queries)
|
||||
FLAGS.num_queries)
|
||||
else:
|
||||
if not FLAGS.logdir:
|
||||
sys.exit('You must specify either --logdir or --monitoring_level.')
|
||||
@ -193,11 +191,16 @@ def main(unused_argv=None):
|
||||
gfile.MakeDirs(FLAGS.logdir)
|
||||
|
||||
try:
|
||||
profiler_client.start_tracing(service_addr,
|
||||
os.path.expanduser(FLAGS.logdir),
|
||||
duration_ms, workers_list,
|
||||
FLAGS.include_dataset_ops,
|
||||
FLAGS.num_tracing_attempts)
|
||||
if LooseVersion(tf_version) < LooseVersion('2.3.0'):
|
||||
profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir),
|
||||
duration_ms, workers_list,
|
||||
FLAGS.num_tracing_attempts)
|
||||
else:
|
||||
options = profiler.ProfilerOptions(
|
||||
host_tracer_level=FLAGS.host_tracer_level)
|
||||
profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir),
|
||||
duration_ms, workers_list,
|
||||
FLAGS.num_tracing_attempts, options)
|
||||
except errors.UnavailableError:
|
||||
sys.exit(0)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user