From d0a5894b58be100c698a2f49d3371a7c5e273d2f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 21 May 2020 14:15:43 -0700 Subject: [PATCH] Switch capture_tpu_profile to the new API of profiler_client. Because some options were dropped, two flags are deprecated. Also allow it to specify the host tracer level. PiperOrigin-RevId: 312739183 Change-Id: I4e4712441877e697956d539055e333baf8a8d7bd --- tensorflow/python/tpu/profiler/BUILD | 3 +- .../tpu/profiler/capture_tpu_profile.py | 41 ++++++++++--------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/tensorflow/python/tpu/profiler/BUILD b/tensorflow/python/tpu/profiler/BUILD index b505262c6a2..84ffb4234c0 100644 --- a/tensorflow/python/tpu/profiler/BUILD +++ b/tensorflow/python/tpu/profiler/BUILD @@ -38,7 +38,8 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:versions", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", - "//tensorflow/python/eager:profiler_client", + "//tensorflow/python/profiler:profiler_client", + "//tensorflow/python/profiler:profiler_v2", "@absl_py//absl:app", "@absl_py//absl/flags", ], diff --git a/tensorflow/python/tpu/profiler/capture_tpu_profile.py b/tensorflow/python/tpu/profiler/capture_tpu_profile.py index f0d22027e4e..0068dc402c0 100644 --- a/tensorflow/python/tpu/profiler/capture_tpu_profile.py +++ b/tensorflow/python/tpu/profiler/capture_tpu_profile.py @@ -25,7 +25,8 @@ from absl import flags from distutils.version import LooseVersion from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver -from tensorflow.python.eager import profiler_client +from tensorflow.python.profiler import profiler_client +from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.framework import errors from tensorflow.python.framework import versions from tensorflow.python.platform import gfile @@ -65,9 +66,10 @@ flags.DEFINE_integer('duration_ms', 0, flags.DEFINE_integer( 
'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' 'event is collected.') -flags.DEFINE_boolean('include_dataset_ops', True, - 'Set to false to profile longer TPU ' - 'device traces.') +flags.DEFINE_boolean('include_dataset_ops', True, 'Deprecated.') +flags.DEFINE_integer( + 'host_tracer_level', 2, 'Adjust host tracer level to control the verbosity ' + ' of the TraceMe event being collected.') # Monitoring parameters flags.DEFINE_integer( @@ -77,8 +79,7 @@ flags.DEFINE_integer( flags.DEFINE_integer( 'num_queries', 100, 'This script will run monitoring for num_queries before it stops.') -flags.DEFINE_boolean('display_timestamp', False, - 'Set to true to display timestamp in monitoring results.') +flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.') def get_workers_list(cluster_resolver): @@ -111,8 +112,7 @@ def get_workers_list(cluster_resolver): return ','.join(workers_list) -def monitoring_helper(service_addr, duration_ms, monitoring_level, - display_timestamp, num_queries): +def monitoring_helper(service_addr, duration_ms, monitoring_level, num_queries): """Helper function to print monitoring results. Helper function to print monitoring results for num_queries times. @@ -122,15 +122,13 @@ def monitoring_helper(service_addr, duration_ms, monitoring_level, duration_ms: Duration of one monitoring sample in milliseconds. monitoring_level: An integer between 1 and 2. Level 2 is more verbose than level 1 and shows more metrics. - display_timestamp: Set to true to display timestamp in monitoring. num_queries: Number of monitoring samples to collect. 
""" if monitoring_level <= 0 or monitoring_level > 2: sys.exit('Please choose a monitoring level between 1 and 2.') for query in range(0, num_queries): - res = profiler_client.monitor(service_addr, duration_ms, monitoring_level, - display_timestamp) + res = profiler_client.monitor(service_addr, duration_ms, monitoring_level) print('Cloud TPU Monitoring Results (Sample ', query, '):\n\n', res) @@ -144,8 +142,8 @@ def main(unused_argv=None): print('TensorFlow version %s detected' % tf_version) print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__) - if LooseVersion(tf_version) < LooseVersion('1.14.0'): - sys.exit('You must install tensorflow >= 1.14.0 to use this plugin.') + if LooseVersion(tf_version) < LooseVersion('2.2.0'): + sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.') if not FLAGS.service_addr and not FLAGS.tpu: sys.exit('You must specify either --service_addr or --tpu.') @@ -184,7 +182,7 @@ def main(unused_argv=None): FLAGS.duration_ms, ' ms and show metrics for ', FLAGS.num_queries, ' time(s).') monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level, - FLAGS.display_timestamp, FLAGS.num_queries) + FLAGS.num_queries) else: if not FLAGS.logdir: sys.exit('You must specify either --logdir or --monitoring_level.') @@ -193,11 +191,16 @@ def main(unused_argv=None): gfile.MakeDirs(FLAGS.logdir) try: - profiler_client.start_tracing(service_addr, - os.path.expanduser(FLAGS.logdir), - duration_ms, workers_list, - FLAGS.include_dataset_ops, - FLAGS.num_tracing_attempts) + if LooseVersion(tf_version) < LooseVersion('2.3.0'): + profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir), + duration_ms, workers_list, + FLAGS.num_tracing_attempts) + else: + options = profiler.ProfilerOptions( + host_tracer_level=FLAGS.host_tracer_level) + profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir), + duration_ms, workers_list, + FLAGS.num_tracing_attempts, options) except errors.UnavailableError: 
sys.exit(0)