Remove capture_tpu_profile c++ binary and add a python wrapper.
PiperOrigin-RevId: 258892451
This commit is contained in:
parent
a355cfad30
commit
7b8cd8dce9
@ -1,27 +0,0 @@
|
|||||||
# Build rules for the standalone capture_tpu_profile command-line tool.
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")

package(licenses = ["notice"])  # Apache 2.0

# Header-only target that exposes version.h to dependents.
cc_library(
    name = "version",
    hdrs = ["version.h"],
    visibility = ["//visibility:public"],
)

# The capture_tpu_profile binary; tagged so that it is not built on Windows.
tf_cc_binary(
    name = "capture_tpu_profile",
    srcs = ["capture_tpu_profile.cc"],
    tags = ["no_windows"],
    visibility = ["//visibility:public"],
    deps = [
        ":version",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform/cloud:gcs_file_system",
        "//tensorflow/core/profiler/rpc/client:capture_profile",
    ],
)
|
|
@ -1,146 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// Usage: capture_tpu_profile --service_addr="localhost:8466" --logdir=/tmp/log
|
|
||||||
//
|
|
||||||
// Initiates a TPU profiling on the TPUProfiler service at service_addr,
|
|
||||||
// receives and dumps the profile data to a tensorboard log directory.
|
|
||||||
|
|
||||||
#include "tensorflow/contrib/tpu/profiler/version.h"
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
|
||||||
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
|
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Issues `num_queries` monitoring requests to the profiler service at
// `service_addr` through profiler::client::Monitor and writes each response to
// stdout under a numbered "Sample" heading.
//
// `duration_ms`, `monitoring_level` and `display_timestamp` are forwarded to
// Monitor unchanged. Returns the first non-OK status reported by Monitor, or
// Status::OK() once every query has completed.
Status MonitoringHelper(const string& service_addr, int duration_ms,
                        int monitoring_level, bool display_timestamp,
                        int num_queries) {
  // Samples are numbered from 1 in the printed heading.
  int sample = 1;
  while (sample <= num_queries) {
    string metrics;
    TF_RETURN_IF_ERROR(tensorflow::profiler::client::Monitor(
        service_addr, duration_ms, monitoring_level, display_timestamp,
        &metrics));
    std::cout << "Cloud TPU Monitoring Results (Sample " << sample << "):\n\n"
              << metrics << std::flush;
    ++sample;
  }
  return Status::OK();
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
tensorflow::string FLAGS_service_addr;
|
|
||||||
tensorflow::string FLAGS_logdir;
|
|
||||||
tensorflow::string FLAGS_workers_list;
|
|
||||||
int FLAGS_duration_ms = 0;
|
|
||||||
int FLAGS_num_tracing_attempts = 3;
|
|
||||||
bool FLAGS_include_dataset_ops = true;
|
|
||||||
int FLAGS_monitoring_level = 0;
|
|
||||||
bool FLAGS_display_timestamp = false;
|
|
||||||
int FLAGS_num_queries = 100;
|
|
||||||
std::vector<tensorflow::Flag> flag_list = {
|
|
||||||
tensorflow::Flag("service_addr", &FLAGS_service_addr,
|
|
||||||
"Address of TPU profiler service e.g. localhost:8466"),
|
|
||||||
tensorflow::Flag("workers_list", &FLAGS_workers_list,
|
|
||||||
"The list of worker TPUs that we are about to profile "
|
|
||||||
"in the current session."),
|
|
||||||
tensorflow::Flag("logdir", &FLAGS_logdir,
|
|
||||||
"Path of TensorBoard log directory e.g. /tmp/tb_log, "
|
|
||||||
"gs://tb_bucket"),
|
|
||||||
tensorflow::Flag(
|
|
||||||
"duration_ms", &FLAGS_duration_ms,
|
|
||||||
"Duration of tracing or monitoring in ms. Default is 2000ms for "
|
|
||||||
"tracing and 1000ms for monitoring."),
|
|
||||||
tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts,
|
|
||||||
"Automatically retry N times when no trace event "
|
|
||||||
"is collected. Default is 3."),
|
|
||||||
tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops,
|
|
||||||
"Set to false to profile longer TPU device traces."),
|
|
||||||
tensorflow::Flag("monitoring_level", &FLAGS_monitoring_level,
|
|
||||||
"Choose a monitoring level between 1 and 2 to monitor "
|
|
||||||
"your TPU job continuously. Level 2 is more verbose "
|
|
||||||
"than level 1 and shows more metrics."),
|
|
||||||
tensorflow::Flag("display_timestamp", &FLAGS_display_timestamp,
|
|
||||||
"Set to true to display timestamp in monitoring "
|
|
||||||
"results."),
|
|
||||||
tensorflow::Flag("num_queries", &FLAGS_num_queries,
|
|
||||||
"This script will run monitoring for num_queries before "
|
|
||||||
"it stops.")};
|
|
||||||
|
|
||||||
std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
|
||||||
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
|
||||||
if (!parse_ok || FLAGS_service_addr.empty() ||
|
|
||||||
(FLAGS_logdir.empty() && FLAGS_monitoring_level == 0)) {
|
|
||||||
// Fail if flags are not parsed correctly or service_addr not provided.
|
|
||||||
// Also, fail if neither logdir is provided (required for tracing) nor
|
|
||||||
// monitoring level is provided (required for monitoring).
|
|
||||||
std::cout << usage.c_str() << std::endl;
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
if (FLAGS_monitoring_level < 0 || FLAGS_monitoring_level > 2) {
|
|
||||||
// Invalid monitoring level.
|
|
||||||
std::cout << usage.c_str() << std::endl;
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
tensorflow::Status status;
|
|
||||||
status =
|
|
||||||
tensorflow::profiler::client::ValidateHostPortPair(FLAGS_service_addr);
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cout << status.error_message() << std::endl;
|
|
||||||
std::cout << usage.c_str() << std::endl;
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
|
||||||
|
|
||||||
// Sets the minimum duration_ms, tracing attempts and num queries.
|
|
||||||
int duration_ms = std::max(FLAGS_duration_ms, 0);
|
|
||||||
if (duration_ms == 0) {
|
|
||||||
// If profiling duration was not set by user or set to a negative value, we
|
|
||||||
// set it to default values of 2000ms for tracing and 1000ms for monitoring.
|
|
||||||
duration_ms = FLAGS_monitoring_level == 0 ? 2000 : 1000;
|
|
||||||
}
|
|
||||||
int num_tracing_attempts = std::max(FLAGS_num_tracing_attempts, 1);
|
|
||||||
int num_queries = std::max(FLAGS_num_queries, 1);
|
|
||||||
|
|
||||||
if (FLAGS_monitoring_level != 0) {
|
|
||||||
std::cout << "Since monitoring level is provided, profile "
|
|
||||||
<< FLAGS_service_addr << " for " << duration_ms
|
|
||||||
<< "ms and show metrics for " << num_queries << " time(s)."
|
|
||||||
<< std::endl;
|
|
||||||
status = tensorflow::MonitoringHelper(FLAGS_service_addr, duration_ms,
|
|
||||||
FLAGS_monitoring_level,
|
|
||||||
FLAGS_display_timestamp, num_queries);
|
|
||||||
} else {
|
|
||||||
status = tensorflow::profiler::client::StartTracing(
|
|
||||||
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,
|
|
||||||
FLAGS_include_dataset_ops, duration_ms, num_tracing_attempts);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!status.ok() && status.code() != tensorflow::error::Code::UNAVAILABLE) {
|
|
||||||
std::cout << status.error_message() << std::endl;
|
|
||||||
std::cout << usage.c_str() << std::endl;
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
@ -1,2 +0,0 @@
|
|||||||
This package contains a Python wrapper around a pre-built C++ binary that is
|
|
||||||
used to profile Cloud TPU.
|
|
@ -1,137 +0,0 @@
|
|||||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# =============================================================================
|
|
||||||
"""Wraps capture_tpu_profile binary."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from absl import flags
|
|
||||||
from distutils.version import LooseVersion
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
# Cloud TPU Cluster Resolvers
flags.DEFINE_string(
    'gcp_project', None,
    'Project name for the Cloud TPU-enabled project. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')
# The fallback described here is about the zone, not the project.
flags.DEFINE_string(
    'tpu_zone',
    None,
    help='GCE zone where the Cloud TPU is located in. If not specified, we '
    'will attempt to automatically detect the GCE zone from metadata.')
flags.DEFINE_string(
    'tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must '
    'specify either this flag or --service_addr.')

# Tool specific parameters
flags.DEFINE_string(
    'service_addr', None, 'Address of TPU profiler service e.g. '
    'localhost:8466, you must specify either this flag or --tpu.')
# Adjacent string literals are concatenated, so each fragment carries its own
# separating space.
flags.DEFINE_string(
    'workers_list', None, 'The list of worker TPUs that we are about to profile'
    ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or '
    '--service_addr to profile a subset of tpu nodes. You can also use only '
    '--tpu and leave this flag unspecified to profile all the tpus.')
flags.DEFINE_string(
    'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, '
    'gs://tb_bucket')
flags.DEFINE_integer('duration_ms', 0,
                     'Duration of tracing or monitoring in ms.')
flags.DEFINE_integer(
    'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
    'event is collected.')
flags.DEFINE_boolean('include_dataset_ops', True,
                     'Set to false to profile longer TPU '
                     'device traces.')

# Monitoring parameters
flags.DEFINE_integer(
    'monitoring_level', 0, 'Choose a monitoring level between '
    '1 and 2 to monitor your TPU job continuously.')
flags.DEFINE_integer(
    'num_queries', 100,
    'This script will run monitoring for num_queries before it stops.')

FLAGS = flags.FLAGS
# Path of the wrapped binary, relative to the directory containing this module.
EXECUTABLE = 'data/capture_tpu_profile'
# Cluster-spec job name whose tasks are listed as profiling workers.
JOB_NAME = 'worker'
|
|
||||||
|
|
||||||
|
|
||||||
def get_workers_list(cluster_resolver):
  """Returns the hosts of all JOB_NAME tasks as a comma-separated string.

  Args:
    cluster_resolver: Object whose `cluster_spec()` result provides
      `task_indices(job)` and `task_address(job, index)`.

  Returns:
    A string such as '10.0.1.2,10.0.1.3': each task address with everything
    from the first ':' onwards removed, joined by commas.
  """
  spec = cluster_resolver.cluster_spec()
  hosts = []
  for task_index in spec.task_indices(JOB_NAME):
    address = spec.task_address(JOB_NAME, task_index)
    # Keep only the part before the first ':' (the port is dropped).
    host, _, _ = address.partition(':')
    hosts.append(host)
  return ','.join(hosts)
|
|
||||||
|
|
||||||
|
|
||||||
def run_main():
  """Runs `main` through `tf.app.run`."""
  tf.app.run(main=main)
|
|
||||||
|
|
||||||
|
|
||||||
def main(unused_argv=None):
  """Builds the capture_tpu_profile command line from FLAGS and runs it.

  The profiler service address comes from --service_addr or, failing that,
  from a TPUClusterResolver built from --tpu, --tpu_zone and --gcp_project.
  The worker list comes from --workers_list or, failing that, from the
  resolved cluster. On TensorFlow versions older than 1.9 the worker list is
  left empty.

  Args:
    unused_argv: Ignored.

  Returns:
    The exit status of the capture_tpu_profile binary, so that a failure of
    the binary is passed back to the caller instead of being discarded.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf_version = tf.__version__
  print('TensorFlow version %s detected' % tf_version)

  if not FLAGS.service_addr and not FLAGS.tpu:
    sys.exit('You must specify either --service_addr or --tpu.')

  tpu_cluster_resolver = None
  if FLAGS.service_addr:
    if FLAGS.tpu:
      tf.logging.warn('Both --service_addr and --tpu are set. Ignoring '
                      '--tpu and using --service_addr.')
    service_addr = FLAGS.service_addr
  else:
    tpu_cluster_resolver = (
        tf.contrib.cluster_resolver.TPUClusterResolver(
            [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
    service_addr = tpu_cluster_resolver.get_master()
  # Strip the gRPC scheme and rewrite port 8470 to 8466, the port used in the
  # --service_addr example.
  service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466')

  workers_list = ''
  if LooseVersion(tf_version) < LooseVersion('1.9'):
    tf.logging.warn('Attempt to profile with legacy support under TensorFlow '
                    'version %s' % tf_version)
  else:
    if FLAGS.workers_list is not None:
      workers_list = FLAGS.workers_list
    elif tpu_cluster_resolver is not None:
      workers_list = get_workers_list(tpu_cluster_resolver)

  # A log directory is required unless a monitoring level was requested.
  if not FLAGS.logdir and not FLAGS.monitoring_level:
    sys.exit('logdir must be provided.')
  executable_path = os.path.join(os.path.dirname(__file__), EXECUTABLE)
  cmd = [executable_path]
  if FLAGS.logdir is not None:
    logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir))
    cmd.append('--logdir=' + logdir)
  cmd.append('--service_addr=' + service_addr)
  cmd.append('--workers_list=' + workers_list)
  cmd.append('--duration_ms=' + str(FLAGS.duration_ms))
  cmd.append('--num_tracing_attempts=' + str(FLAGS.num_tracing_attempts))
  # The boolean is passed in lower case ('true' / 'false').
  cmd.append('--include_dataset_ops=' + str(FLAGS.include_dataset_ops).lower())
  cmd.append('--monitoring_level=' + str(FLAGS.monitoring_level))
  cmd.append('--num_queries=' + str(FLAGS.num_queries))
  # Return the binary's exit status rather than dropping it.
  return subprocess.call(cmd)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
run_main()
|
|
@ -1,21 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
=============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
|
|
||||||
#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
|
|
||||||
|
|
||||||
// Version string of the Cloud TPU profiler; capture_tpu_profile.cc prints it
// in its welcome banner.
#define TPU_PROFILER_VERSION "1.12.0"
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
|
|
@ -1,12 +1,10 @@
|
|||||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
|
||||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
|
||||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
|
||||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
|
||||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
|
||||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
|
||||||
tensorflow/contrib/mpi/BUILD
|
tensorflow/contrib/mpi/BUILD
|
||||||
tensorflow/stream_executor/build_defs.bzl
|
tensorflow/stream_executor/build_defs.bzl
|
||||||
tensorflow/python/autograph/core/config.py
|
tensorflow/python/autograph/core/config.py
|
||||||
|
tensorflow/python/tpu/profiler/pip_package/setup.py
|
||||||
|
tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh
|
||||||
|
tensorflow/python/tpu/profiler/pip_package/BUILD
|
||||||
|
tensorflow/python/tpu/profiler/pip_package/README
|
||||||
tensorflow/tools/ci_build/remote/BUILD
|
tensorflow/tools/ci_build/remote/BUILD
|
||||||
tensorflow/tools/pip_package/README
|
tensorflow/tools/pip_package/README
|
||||||
tensorflow/tools/pip_package/MANIFEST.in
|
tensorflow/tools/pip_package/MANIFEST.in
|
||||||
|
@ -26,11 +26,12 @@ py_library(
|
|||||||
deps = ["//tensorflow/core/profiler:profiler_analysis_proto_py"],
|
deps = ["//tensorflow/core/profiler:profiler_analysis_proto_py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_binary(
|
py_library(
|
||||||
name = "capture_tpu_profile",
|
name = "capture_tpu_profile_lib",
|
||||||
srcs = ["capture_tpu_profile.py"],
|
srcs = [
|
||||||
main = "capture_tpu_profile.py",
|
"capture_tpu_profile.py",
|
||||||
python_version = "PY2",
|
"version.py",
|
||||||
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:client",
|
"//tensorflow/python:client",
|
||||||
@ -43,3 +44,14 @@ py_binary(
|
|||||||
"@absl_py//absl/flags",
|
"@absl_py//absl/flags",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "capture_tpu_profile_bin",
|
||||||
|
srcs = ["capture_tpu_profile.py"],
|
||||||
|
main = "capture_tpu_profile.py",
|
||||||
|
python_version = "PY2",
|
||||||
|
deps = [
|
||||||
|
":capture_tpu_profile_lib",
|
||||||
|
"@absl_py//absl/flags",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.eager import profiler_client
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import versions
|
from tensorflow.python.framework import versions
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.tpu.profiler import version as profiler_version
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
@ -139,9 +140,10 @@ def main(unused_argv=None):
|
|||||||
logging.set_verbosity(logging.INFO)
|
logging.set_verbosity(logging.INFO)
|
||||||
tf_version = versions.__version__
|
tf_version = versions.__version__
|
||||||
print('TensorFlow version %s detected' % tf_version)
|
print('TensorFlow version %s detected' % tf_version)
|
||||||
|
print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)
|
||||||
|
|
||||||
if LooseVersion(tf_version) < LooseVersion('1.14.1'):
|
if LooseVersion(tf_version) < LooseVersion('1.14.0'):
|
||||||
sys.exit('You must install tensorflow >= 1.14.1 to use this plugin.')
|
sys.exit('You must install tensorflow >= 1.14.0 to use this plugin.')
|
||||||
|
|
||||||
if not FLAGS.service_addr and not FLAGS.tpu:
|
if not FLAGS.service_addr and not FLAGS.tpu:
|
||||||
sys.exit('You must specify either --service_addr or --tpu.')
|
sys.exit('You must specify either --service_addr or --tpu.')
|
||||||
@ -179,10 +181,14 @@ def main(unused_argv=None):
|
|||||||
else:
|
else:
|
||||||
if not FLAGS.logdir:
|
if not FLAGS.logdir:
|
||||||
sys.exit('logdir must be provided')
|
sys.exit('logdir must be provided')
|
||||||
|
try:
|
||||||
profiler_client.start_tracing(service_addr,
|
profiler_client.start_tracing(service_addr,
|
||||||
os.path.expanduser(FLAGS.logdir), duration_ms,
|
os.path.expanduser(FLAGS.logdir),
|
||||||
workers_list, FLAGS.include_dataset_ops,
|
duration_ms, workers_list,
|
||||||
|
FLAGS.include_dataset_ops,
|
||||||
FLAGS.num_tracing_attempts)
|
FLAGS.num_tracing_attempts)
|
||||||
|
except errors.UnavailableError:
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -10,6 +10,6 @@ sh_binary(
|
|||||||
srcs = ["build_pip_package.sh"],
|
srcs = ["build_pip_package.sh"],
|
||||||
data = [
|
data = [
|
||||||
"setup.py",
|
"setup.py",
|
||||||
"//tensorflow/contrib/tpu/profiler:capture_tpu_profile",
|
"//tensorflow/python/tpu/profiler:capture_tpu_profile_bin",
|
||||||
],
|
],
|
||||||
)
|
)
|
2
tensorflow/python/tpu/profiler/pip_package/README
Normal file
2
tensorflow/python/tpu/profiler/pip_package/README
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
This package contains a Python wrapper around TensorFlow Python APIs that are
|
||||||
|
used to profile Cloud TPU.
|
@ -17,9 +17,15 @@
|
|||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
BINARY="bazel-bin/tensorflow/contrib/tpu/profiler/capture_tpu_profile"
|
if [ "$(uname)" = "Darwin" ]; then
|
||||||
|
sedi="sed -i ''"
|
||||||
|
else
|
||||||
|
sedi="sed -i"
|
||||||
|
fi
|
||||||
|
|
||||||
PACKAGE_NAME="cloud_tpu_profiler"
|
PACKAGE_NAME="cloud_tpu_profiler"
|
||||||
PIP_PACKAGE="tensorflow/contrib/tpu/profiler/pip_package"
|
PIP_PACKAGE="tensorflow/python/tpu/profiler/pip_package"
|
||||||
|
RUNFILES="bazel-bin/tensorflow/python/tpu/profiler/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow/python/tpu/profiler"
|
||||||
|
|
||||||
function main() {
|
function main() {
|
||||||
if [ $# -lt 1 ] ; then
|
if [ $# -lt 1 ] ; then
|
||||||
@ -32,16 +38,16 @@ function main() {
|
|||||||
|
|
||||||
echo $(date) : "=== Using tmpdir: ${TMPDIR}"
|
echo $(date) : "=== Using tmpdir: ${TMPDIR}"
|
||||||
|
|
||||||
if [ ! -f "${BINARY}" ]; then
|
|
||||||
echo "Could not find ${BINARY}. Did you run from the root of the build tree?"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
cp ${PIP_PACKAGE}/README ${TMPDIR}
|
cp ${PIP_PACKAGE}/README ${TMPDIR}
|
||||||
cp ${PIP_PACKAGE}/setup.py ${TMPDIR}
|
cp ${PIP_PACKAGE}/setup.py ${TMPDIR}
|
||||||
cp -R ${PIP_PACKAGE}/${PACKAGE_NAME} ${TMPDIR}
|
mkdir ${TMPDIR}/${PACKAGE_NAME}
|
||||||
mkdir ${TMPDIR}/${PACKAGE_NAME}/data
|
cp -a ${RUNFILES}/. ${TMPDIR}/${PACKAGE_NAME}/
|
||||||
cp ${BINARY} ${TMPDIR}/${PACKAGE_NAME}/data
|
|
||||||
|
# Fix the import statements to reflect the copied over path.
|
||||||
|
find ${TMPDIR}/${PACKAGE_NAME} -name \*.py |
|
||||||
|
xargs $sedi -e '
|
||||||
|
s/^from tensorflow.python.tpu.profiler/from '${PACKAGE_NAME}'/
|
||||||
|
'
|
||||||
echo $(ls $TMPDIR)
|
echo $(ls $TMPDIR)
|
||||||
|
|
||||||
pushd ${TMPDIR}
|
pushd ${TMPDIR}
|
@ -20,32 +20,25 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
_VERSION = '1.12.0'
|
from cloud_tpu_profiler.version import __version__
|
||||||
|
|
||||||
CONSOLE_SCRIPTS = [
|
CONSOLE_SCRIPTS = [
|
||||||
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
|
'capture_tpu_profile=cloud_tpu_profiler.capture_tpu_profile:run_main',
|
||||||
]
|
]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='cloud_tpu_profiler',
|
name='cloud_tpu_profiler',
|
||||||
version=_VERSION.replace('-', ''),
|
version=__version__.replace('-', ''),
|
||||||
description='Trace and profile Cloud TPU performance',
|
description='Trace and profile Cloud TPU performance',
|
||||||
long_description='Tools for capture TPU profile',
|
long_description='Tools for capture TPU profile',
|
||||||
url='https://www.tensorflow.org/tfrc/',
|
url='https://www.tensorflow.org/tfrc/',
|
||||||
author='Google Inc.',
|
author='Google Inc.',
|
||||||
author_email='packages@tensorflow.org',
|
author_email='packages@tensorflow.org',
|
||||||
packages=['cloud_tpu_profiler'],
|
packages=['cloud_tpu_profiler'],
|
||||||
package_data={
|
|
||||||
'cloud_tpu_profiler': ['data/*'],
|
|
||||||
},
|
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': CONSOLE_SCRIPTS,
|
'console_scripts': CONSOLE_SCRIPTS,
|
||||||
},
|
},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
# How mature is this project? Common values are
|
|
||||||
# 3 - Alpha
|
|
||||||
# 4 - Beta
|
|
||||||
# 5 - Production/Stable
|
|
||||||
'Development Status :: 5 - Production/Stable',
|
'Development Status :: 5 - Production/Stable',
|
||||||
'Intended Audience :: Developers',
|
'Intended Audience :: Developers',
|
||||||
'Intended Audience :: Education',
|
'Intended Audience :: Education',
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -12,8 +12,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
"""Cloud TPU profiler."""
|
"""Cloud TPU profiler version information."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Cloud TPU profiler uses semantic versioning, see http://semver.org/.
|
||||||
|
# A version string consists of
|
||||||
|
# major_version.minor_version.patch_version-build_metadata.
|
||||||
|
__version__ = "1.14.1-a0"
|
Loading…
Reference in New Issue
Block a user