Switch capture_tpu_profile to the new API of profiler_client. Because some options were dropped, two flags are now deprecated.
Also allow it to specify the host tracer level. PiperOrigin-RevId: 312739183 Change-Id: I4e4712441877e697956d539055e333baf8a8d7bd
This commit is contained in:
parent
7315b275c0
commit
d0a5894b58
@ -38,7 +38,8 @@ py_library(
|
|||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:versions",
|
"//tensorflow/python:versions",
|
||||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||||
"//tensorflow/python/eager:profiler_client",
|
"//tensorflow/python/profiler:profiler_client",
|
||||||
|
"//tensorflow/python/profiler:profiler_v2",
|
||||||
"@absl_py//absl:app",
|
"@absl_py//absl:app",
|
||||||
"@absl_py//absl/flags",
|
"@absl_py//absl/flags",
|
||||||
],
|
],
|
||||||
|
@ -25,7 +25,8 @@ from absl import flags
|
|||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
||||||
from tensorflow.python.eager import profiler_client
|
from tensorflow.python.profiler import profiler_client
|
||||||
|
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import versions
|
from tensorflow.python.framework import versions
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
@ -65,9 +66,10 @@ flags.DEFINE_integer('duration_ms', 0,
|
|||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer(
|
||||||
'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
|
'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
|
||||||
'event is collected.')
|
'event is collected.')
|
||||||
flags.DEFINE_boolean('include_dataset_ops', True,
|
flags.DEFINE_boolean('include_dataset_ops', True, 'Deprecated.')
|
||||||
'Set to false to profile longer TPU '
|
flags.DEFINE_integer(
|
||||||
'device traces.')
|
'host_tracer_level', 2, 'Adjust host tracer level to control the verbosity '
|
||||||
|
' of the TraceMe event being collected.')
|
||||||
|
|
||||||
# Monitoring parameters
|
# Monitoring parameters
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer(
|
||||||
@ -77,8 +79,7 @@ flags.DEFINE_integer(
|
|||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer(
|
||||||
'num_queries', 100,
|
'num_queries', 100,
|
||||||
'This script will run monitoring for num_queries before it stops.')
|
'This script will run monitoring for num_queries before it stops.')
|
||||||
flags.DEFINE_boolean('display_timestamp', False,
|
flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.')
|
||||||
'Set to true to display timestamp in monitoring results.')
|
|
||||||
|
|
||||||
|
|
||||||
def get_workers_list(cluster_resolver):
|
def get_workers_list(cluster_resolver):
|
||||||
@ -111,8 +112,7 @@ def get_workers_list(cluster_resolver):
|
|||||||
return ','.join(workers_list)
|
return ','.join(workers_list)
|
||||||
|
|
||||||
|
|
||||||
def monitoring_helper(service_addr, duration_ms, monitoring_level,
|
def monitoring_helper(service_addr, duration_ms, monitoring_level, num_queries):
|
||||||
display_timestamp, num_queries):
|
|
||||||
"""Helper function to print monitoring results.
|
"""Helper function to print monitoring results.
|
||||||
|
|
||||||
Helper function to print monitoring results for num_queries times.
|
Helper function to print monitoring results for num_queries times.
|
||||||
@ -122,15 +122,13 @@ def monitoring_helper(service_addr, duration_ms, monitoring_level,
|
|||||||
duration_ms: Duration of one monitoring sample in milliseconds.
|
duration_ms: Duration of one monitoring sample in milliseconds.
|
||||||
monitoring_level: An integer between 1 and 2. Level 2 is more verbose than
|
monitoring_level: An integer between 1 and 2. Level 2 is more verbose than
|
||||||
level 1 and shows more metrics.
|
level 1 and shows more metrics.
|
||||||
display_timestamp: Set to true to display timestamp in monitoring.
|
|
||||||
num_queries: Number of monitoring samples to collect.
|
num_queries: Number of monitoring samples to collect.
|
||||||
"""
|
"""
|
||||||
if monitoring_level <= 0 or monitoring_level > 2:
|
if monitoring_level <= 0 or monitoring_level > 2:
|
||||||
sys.exit('Please choose a monitoring level between 1 and 2.')
|
sys.exit('Please choose a monitoring level between 1 and 2.')
|
||||||
|
|
||||||
for query in range(0, num_queries):
|
for query in range(0, num_queries):
|
||||||
res = profiler_client.monitor(service_addr, duration_ms, monitoring_level,
|
res = profiler_client.monitor(service_addr, duration_ms, monitoring_level)
|
||||||
display_timestamp)
|
|
||||||
print('Cloud TPU Monitoring Results (Sample ', query, '):\n\n', res)
|
print('Cloud TPU Monitoring Results (Sample ', query, '):\n\n', res)
|
||||||
|
|
||||||
|
|
||||||
@ -144,8 +142,8 @@ def main(unused_argv=None):
|
|||||||
print('TensorFlow version %s detected' % tf_version)
|
print('TensorFlow version %s detected' % tf_version)
|
||||||
print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)
|
print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)
|
||||||
|
|
||||||
if LooseVersion(tf_version) < LooseVersion('1.14.0'):
|
if LooseVersion(tf_version) < LooseVersion('2.2.0'):
|
||||||
sys.exit('You must install tensorflow >= 1.14.0 to use this plugin.')
|
sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.')
|
||||||
|
|
||||||
if not FLAGS.service_addr and not FLAGS.tpu:
|
if not FLAGS.service_addr and not FLAGS.tpu:
|
||||||
sys.exit('You must specify either --service_addr or --tpu.')
|
sys.exit('You must specify either --service_addr or --tpu.')
|
||||||
@ -184,7 +182,7 @@ def main(unused_argv=None):
|
|||||||
FLAGS.duration_ms, ' ms and show metrics for ', FLAGS.num_queries,
|
FLAGS.duration_ms, ' ms and show metrics for ', FLAGS.num_queries,
|
||||||
' time(s).')
|
' time(s).')
|
||||||
monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
|
monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
|
||||||
FLAGS.display_timestamp, FLAGS.num_queries)
|
FLAGS.num_queries)
|
||||||
else:
|
else:
|
||||||
if not FLAGS.logdir:
|
if not FLAGS.logdir:
|
||||||
sys.exit('You must specify either --logdir or --monitoring_level.')
|
sys.exit('You must specify either --logdir or --monitoring_level.')
|
||||||
@ -193,11 +191,16 @@ def main(unused_argv=None):
|
|||||||
gfile.MakeDirs(FLAGS.logdir)
|
gfile.MakeDirs(FLAGS.logdir)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
profiler_client.start_tracing(service_addr,
|
if LooseVersion(tf_version) < LooseVersion('2.3.0'):
|
||||||
os.path.expanduser(FLAGS.logdir),
|
profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir),
|
||||||
duration_ms, workers_list,
|
duration_ms, workers_list,
|
||||||
FLAGS.include_dataset_ops,
|
FLAGS.num_tracing_attempts)
|
||||||
FLAGS.num_tracing_attempts)
|
else:
|
||||||
|
options = profiler.ProfilerOptions(
|
||||||
|
host_tracer_level=FLAGS.host_tracer_level)
|
||||||
|
profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir),
|
||||||
|
duration_ms, workers_list,
|
||||||
|
FLAGS.num_tracing_attempts, options)
|
||||||
except errors.UnavailableError:
|
except errors.UnavailableError:
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user