allow profiler_client.py v2 (our rpc client) to provide ProfilerOptions,
v1/eager profiler_client.py will use default options. add device_tracer_level to ProfilerOptions and allow default values (no need to specify all options and in specific order) PiperOrigin-RevId: 306308283 Change-Id: Ia9a6a7cfde37142be79e314627b9d63993b6283c
This commit is contained in:
parent
d73ef3f567
commit
9f7f6e9faa
@ -47,7 +47,7 @@ def start_tracing(service_addr,
|
||||
UnavailableError: If no trace event is collected.
|
||||
"""
|
||||
_pywrap_profiler.trace(service_addr, logdir, worker_list, include_dataset_ops,
|
||||
duration_ms, num_tracing_attempts)
|
||||
duration_ms, num_tracing_attempts, {})
|
||||
|
||||
|
||||
@deprecated('2020-07-01', 'use `tf.profiler.experimental.client.monitor`.')
|
||||
|
@ -47,6 +47,25 @@ tensorflow::ProfileRequest MakeProfileRequest(
|
||||
return request;
|
||||
}
|
||||
|
||||
tensorflow::ProfileOptions GetOptions(const py::dict& opts) {
|
||||
tensorflow::ProfileOptions options =
|
||||
tensorflow::ProfilerSession::DefaultOptions();
|
||||
for (const auto& kw : opts) {
|
||||
std::string key = py::cast<std::string>(kw.first);
|
||||
if (key == "host_tracer_level") {
|
||||
options.set_host_tracer_level(py::cast<int>(kw.second));
|
||||
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level();
|
||||
} else if (key == "device_tracer_level") {
|
||||
options.set_device_tracer_level(py::cast<int>(kw.second));
|
||||
VLOG(1) << "device_tracer_level set to " << options.device_tracer_level();
|
||||
} else if (key == "python_tracer_level") {
|
||||
options.set_python_tracer_level(py::cast<int>(kw.second));
|
||||
VLOG(1) << "python_tracer_level set to " << options.python_tracer_level();
|
||||
}
|
||||
}
|
||||
return options;
|
||||
}
|
||||
|
||||
class ProfilerSessionWrapper {
|
||||
public:
|
||||
void Start(const char* logdir, const py::dict& options) {
|
||||
@ -93,23 +112,6 @@ class ProfilerSessionWrapper {
|
||||
}
|
||||
|
||||
private:
|
||||
tensorflow::ProfileOptions GetOptions(const py::dict& opts) {
|
||||
tensorflow::ProfileOptions options =
|
||||
tensorflow::ProfilerSession::DefaultOptions();
|
||||
for (const auto& kw : opts) {
|
||||
std::string key = py::cast<std::string>(kw.first);
|
||||
if (key == "host_tracer_level") {
|
||||
options.set_host_tracer_level(py::cast<int>(kw.second));
|
||||
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level();
|
||||
} else if (key == "python_tracer_level") {
|
||||
options.set_python_tracer_level(py::cast<int>(kw.second));
|
||||
VLOG(1) << "enable_python_tracer set to "
|
||||
<< options.python_tracer_level();
|
||||
}
|
||||
}
|
||||
return options;
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::ProfilerSession> session_;
|
||||
tensorflow::string logdir_;
|
||||
};
|
||||
@ -134,11 +136,12 @@ PYBIND11_MODULE(_pywrap_profiler, m) {
|
||||
|
||||
m.def("trace", [](const char* service_addr, const char* logdir,
|
||||
const char* worker_list, bool include_dataset_ops,
|
||||
int duration_ms, int num_tracing_attempts) {
|
||||
int duration_ms, int num_tracing_attempts,
|
||||
py::dict options) {
|
||||
tensorflow::Status status =
|
||||
tensorflow::profiler::ValidateHostPortPair(service_addr);
|
||||
tensorflow::MaybeRaiseRegisteredFromStatus(status);
|
||||
tensorflow::ProfileOptions opts;
|
||||
tensorflow::ProfileOptions opts = GetOptions(options);
|
||||
opts.set_include_dataset_ops(include_dataset_ops);
|
||||
status =
|
||||
tensorflow::profiler::Trace(service_addr, logdir, worker_list,
|
||||
|
@ -30,7 +30,8 @@ def trace(service_addr,
|
||||
logdir,
|
||||
duration_ms,
|
||||
worker_list='',
|
||||
num_tracing_attempts=3):
|
||||
num_tracing_attempts=3,
|
||||
options=None):
|
||||
"""Sends grpc requests to profiler server to perform on-demand profiling.
|
||||
|
||||
This method will block caller thread until it receives tracing result. This
|
||||
@ -48,6 +49,8 @@ def trace(service_addr,
|
||||
the current session (TPU only).
|
||||
num_tracing_attempts: Optional. Automatically retry N times when no trace
|
||||
event is collected (default 3).
|
||||
options: profiler.experimental.ProfilerOptions namedtuple for miscellaneous
|
||||
profiler options.
|
||||
|
||||
Raises:
|
||||
UnavailableError: If no trace event is collected.
|
||||
@ -86,9 +89,10 @@ def trace(service_addr,
|
||||
Open your browser and go to localhost:6006/#profile to view profiling results.
|
||||
|
||||
"""
|
||||
opts = dict(options._asdict()) if options is not None else {}
|
||||
_pywrap_profiler.trace(
|
||||
_strip_prefix(service_addr, _GRPC_PREFIX), logdir, worker_list, True,
|
||||
duration_ms, num_tracing_attempts)
|
||||
duration_ms, num_tracing_attempts, opts)
|
||||
|
||||
|
||||
@tf_export('profiler.experimental.client.monitor', v1=[])
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.profiler import profiler_client
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
|
||||
|
||||
class ProfilerClientTest(test_util.TensorFlowTestCase):
|
||||
@ -32,6 +33,13 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(errors.UnavailableError):
|
||||
profiler_client.trace('localhost:6006', tempfile.mkdtemp(), 2000)
|
||||
|
||||
def testStartTracing_ProcessInvalidAddressWithOptions(self):
|
||||
with self.assertRaises(errors.UnavailableError):
|
||||
options = profiler.ProfilerOptions(
|
||||
host_tracer_level=3, device_tracer_level=0)
|
||||
profiler_client.trace(
|
||||
'localhost:6006', tempfile.mkdtemp(), 2000, options=options)
|
||||
|
||||
def testMonitor_ProcessInvalidAddress(self):
|
||||
with self.assertRaises(errors.UnavailableError):
|
||||
profiler_client.monitor('localhost:6006', 2000)
|
||||
|
@ -48,8 +48,9 @@ _profiler_lock = threading.Lock()
|
||||
|
||||
@tf_export('profiler.experimental.ProfilerOptions', v1=[])
|
||||
class ProfilerOptions(
|
||||
collections.namedtuple('ProfilerOptions',
|
||||
['host_tracer_level', 'python_tracer_level'])):
|
||||
collections.namedtuple(
|
||||
'ProfilerOptions',
|
||||
['host_tracer_level', 'python_tracer_level', 'device_tracer_level'])):
|
||||
"""Options to control profiler behaviors.
|
||||
|
||||
A `tf.profiler.ProfilerOptions` hold the knobs to control tf.profiler's
|
||||
@ -60,8 +61,18 @@ class ProfilerOptions(
|
||||
2 => info, 3 => verbose. [default to 2]
|
||||
python_tracer_level: for enable python function call tracing, 1 => enable.
|
||||
0 => disable [default to 0]
|
||||
device_tracer_level: for adjust device (TPU/GPU) tracer level, 0 => disable
|
||||
1 => enabled. We may introduce fine-tuned level in the
|
||||
future. [default to 1]
|
||||
"""
|
||||
pass
|
||||
|
||||
def __new__(cls,
|
||||
host_tracer_level=2,
|
||||
python_tracer_level=0,
|
||||
device_tracer_level=1):
|
||||
return super(ProfilerOptions,
|
||||
cls).__new__(cls, host_tracer_level, python_tracer_level,
|
||||
device_tracer_level)
|
||||
|
||||
|
||||
@tf_export('profiler.experimental.start', v1=[])
|
||||
|
@ -3,6 +3,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.profiler.profiler_v2.ProfilerOptions\'>"
|
||||
is_instance: "<class \'tensorflow.python.profiler.profiler_v2.ProfilerOptions\'>"
|
||||
is_instance: "<type \'tuple\'>"
|
||||
member {
|
||||
name: "device_tracer_level"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "host_tracer_level"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -6,6 +6,6 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "trace"
|
||||
argspec: "args=[\'service_addr\', \'logdir\', \'duration_ms\', \'worker_list\', \'num_tracing_attempts\'], varargs=None, keywords=None, defaults=[\'\', \'3\'], "
|
||||
argspec: "args=[\'service_addr\', \'logdir\', \'duration_ms\', \'worker_list\', \'num_tracing_attempts\', \'options\'], varargs=None, keywords=None, defaults=[\'\', \'3\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user