switch capture_tpu_profile to new api of profiler_client. because some options is dropped, therefore two flags are deprecated.

also allow it specify host trace level.

PiperOrigin-RevId: 312739183
Change-Id: I4e4712441877e697956d539055e333baf8a8d7bd
This commit is contained in:
A. Unique TensorFlower 2020-05-21 14:15:43 -07:00 committed by TensorFlower Gardener
parent 7315b275c0
commit d0a5894b58
2 changed files with 24 additions and 20 deletions
tensorflow/python/tpu/profiler

View File

@ -38,7 +38,8 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:versions",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:profiler_client",
"//tensorflow/python/profiler:profiler_client",
"//tensorflow/python/profiler:profiler_v2",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],

View File

@ -25,7 +25,8 @@ from absl import flags
from distutils.version import LooseVersion
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.eager import profiler_client
from tensorflow.python.profiler import profiler_client
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.framework import errors
from tensorflow.python.framework import versions
from tensorflow.python.platform import gfile
@ -65,9 +66,10 @@ flags.DEFINE_integer('duration_ms', 0,
flags.DEFINE_integer(
'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
'event is collected.')
flags.DEFINE_boolean('include_dataset_ops', True,
'Set to false to profile longer TPU '
'device traces.')
flags.DEFINE_boolean('include_dataset_ops', True, 'Deprecated.')
flags.DEFINE_integer(
'host_tracer_level', 2, 'Adjust host tracer level to control the verbosity '
' of the TraceMe event being collected.')
# Monitoring parameters
flags.DEFINE_integer(
@ -77,8 +79,7 @@ flags.DEFINE_integer(
flags.DEFINE_integer(
'num_queries', 100,
'This script will run monitoring for num_queries before it stops.')
flags.DEFINE_boolean('display_timestamp', False,
'Set to true to display timestamp in monitoring results.')
flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.')
def get_workers_list(cluster_resolver):
@ -111,8 +112,7 @@ def get_workers_list(cluster_resolver):
return ','.join(workers_list)
def monitoring_helper(service_addr, duration_ms, monitoring_level,
display_timestamp, num_queries):
def monitoring_helper(service_addr, duration_ms, monitoring_level, num_queries):
"""Helper function to print monitoring results.
Helper function to print monitoring results for num_queries times.
@ -122,15 +122,13 @@ def monitoring_helper(service_addr, duration_ms, monitoring_level,
duration_ms: Duration of one monitoring sample in milliseconds.
monitoring_level: An integer between 1 and 2. Level 2 is more verbose than
level 1 and shows more metrics.
display_timestamp: Set to true to display timestamp in monitoring.
num_queries: Number of monitoring samples to collect.
"""
if monitoring_level <= 0 or monitoring_level > 2:
sys.exit('Please choose a monitoring level between 1 and 2.')
for query in range(0, num_queries):
res = profiler_client.monitor(service_addr, duration_ms, monitoring_level,
display_timestamp)
res = profiler_client.monitor(service_addr, duration_ms, monitoring_level)
print('Cloud TPU Monitoring Results (Sample ', query, '):\n\n', res)
@ -144,8 +142,8 @@ def main(unused_argv=None):
print('TensorFlow version %s detected' % tf_version)
print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)
if LooseVersion(tf_version) < LooseVersion('1.14.0'):
sys.exit('You must install tensorflow >= 1.14.0 to use this plugin.')
if LooseVersion(tf_version) < LooseVersion('2.2.0'):
sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.')
if not FLAGS.service_addr and not FLAGS.tpu:
sys.exit('You must specify either --service_addr or --tpu.')
@ -184,7 +182,7 @@ def main(unused_argv=None):
FLAGS.duration_ms, ' ms and show metrics for ', FLAGS.num_queries,
' time(s).')
monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
FLAGS.display_timestamp, FLAGS.num_queries)
FLAGS.num_queries)
else:
if not FLAGS.logdir:
sys.exit('You must specify either --logdir or --monitoring_level.')
@ -193,11 +191,16 @@ def main(unused_argv=None):
gfile.MakeDirs(FLAGS.logdir)
try:
profiler_client.start_tracing(service_addr,
os.path.expanduser(FLAGS.logdir),
duration_ms, workers_list,
FLAGS.include_dataset_ops,
FLAGS.num_tracing_attempts)
if LooseVersion(tf_version) < LooseVersion('2.3.0'):
profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir),
duration_ms, workers_list,
FLAGS.num_tracing_attempts)
else:
options = profiler.ProfilerOptions(
host_tracer_level=FLAGS.host_tracer_level)
profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir),
duration_ms, workers_list,
FLAGS.num_tracing_attempts, options)
except errors.UnavailableError:
sys.exit(0)