allow profiler_client.py v2 (our rpc client) to provide ProfilerOptions,

v1/eager profiler_client.py will use default options.
add device_tracer_level to ProfilerOptions and allow default values (no need to specify all options and in specific order)

PiperOrigin-RevId: 306308283
Change-Id: Ia9a6a7cfde37142be79e314627b9d63993b6283c
This commit is contained in:
A. Unique TensorFlower 2020-04-13 14:21:37 -07:00 committed by TensorFlower Gardener
parent d73ef3f567
commit 9f7f6e9faa
7 changed files with 56 additions and 26 deletions

View File

@ -47,7 +47,7 @@ def start_tracing(service_addr,
UnavailableError: If no trace event is collected.
"""
_pywrap_profiler.trace(service_addr, logdir, worker_list, include_dataset_ops,
duration_ms, num_tracing_attempts)
duration_ms, num_tracing_attempts, {})
@deprecated('2020-07-01', 'use `tf.profiler.experimental.client.monitor`.')

View File

@ -47,6 +47,25 @@ tensorflow::ProfileRequest MakeProfileRequest(
return request;
}
tensorflow::ProfileOptions GetOptions(const py::dict& opts) {
tensorflow::ProfileOptions options =
tensorflow::ProfilerSession::DefaultOptions();
for (const auto& kw : opts) {
std::string key = py::cast<std::string>(kw.first);
if (key == "host_tracer_level") {
options.set_host_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level();
} else if (key == "device_tracer_level") {
options.set_device_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "device_tracer_level set to " << options.device_tracer_level();
} else if (key == "python_tracer_level") {
options.set_python_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "python_tracer_level set to " << options.python_tracer_level();
}
}
return options;
}
class ProfilerSessionWrapper {
public:
void Start(const char* logdir, const py::dict& options) {
@ -93,23 +112,6 @@ class ProfilerSessionWrapper {
}
private:
tensorflow::ProfileOptions GetOptions(const py::dict& opts) {
tensorflow::ProfileOptions options =
tensorflow::ProfilerSession::DefaultOptions();
for (const auto& kw : opts) {
std::string key = py::cast<std::string>(kw.first);
if (key == "host_tracer_level") {
options.set_host_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "host_tracer_level set to " << options.host_tracer_level();
} else if (key == "python_tracer_level") {
options.set_python_tracer_level(py::cast<int>(kw.second));
VLOG(1) << "enable_python_tracer set to "
<< options.python_tracer_level();
}
}
return options;
}
std::unique_ptr<tensorflow::ProfilerSession> session_;
tensorflow::string logdir_;
};
@ -134,11 +136,12 @@ PYBIND11_MODULE(_pywrap_profiler, m) {
m.def("trace", [](const char* service_addr, const char* logdir,
const char* worker_list, bool include_dataset_ops,
int duration_ms, int num_tracing_attempts) {
int duration_ms, int num_tracing_attempts,
py::dict options) {
tensorflow::Status status =
tensorflow::profiler::ValidateHostPortPair(service_addr);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
tensorflow::ProfileOptions opts;
tensorflow::ProfileOptions opts = GetOptions(options);
opts.set_include_dataset_ops(include_dataset_ops);
status =
tensorflow::profiler::Trace(service_addr, logdir, worker_list,

View File

@ -30,7 +30,8 @@ def trace(service_addr,
logdir,
duration_ms,
worker_list='',
num_tracing_attempts=3):
num_tracing_attempts=3,
options=None):
"""Sends grpc requests to profiler server to perform on-demand profiling.
This method will block caller thread until it receives tracing result. This
@ -48,6 +49,8 @@ def trace(service_addr,
the current session (TPU only).
num_tracing_attempts: Optional. Automatically retry N times when no trace
event is collected (default 3).
options: profiler.experimental.ProfilerOptions namedtuple for miscellaneous
profiler options.
Raises:
UnavailableError: If no trace event is collected.
@ -86,9 +89,10 @@ def trace(service_addr,
Open your browser and go to localhost:6006/#profile to view profiling results.
"""
opts = dict(options._asdict()) if options is not None else {}
_pywrap_profiler.trace(
_strip_prefix(service_addr, _GRPC_PREFIX), logdir, worker_list, True,
duration_ms, num_tracing_attempts)
duration_ms, num_tracing_attempts, opts)
@tf_export('profiler.experimental.client.monitor', v1=[])

View File

@ -24,6 +24,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.profiler import profiler_client
from tensorflow.python.profiler import profiler_v2 as profiler
class ProfilerClientTest(test_util.TensorFlowTestCase):
@ -32,6 +33,13 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.UnavailableError):
profiler_client.trace('localhost:6006', tempfile.mkdtemp(), 2000)
def testStartTracing_ProcessInvalidAddressWithOptions(self):
with self.assertRaises(errors.UnavailableError):
options = profiler.ProfilerOptions(
host_tracer_level=3, device_tracer_level=0)
profiler_client.trace(
'localhost:6006', tempfile.mkdtemp(), 2000, options=options)
def testMonitor_ProcessInvalidAddress(self):
with self.assertRaises(errors.UnavailableError):
profiler_client.monitor('localhost:6006', 2000)

View File

@ -48,8 +48,9 @@ _profiler_lock = threading.Lock()
@tf_export('profiler.experimental.ProfilerOptions', v1=[])
class ProfilerOptions(
collections.namedtuple('ProfilerOptions',
['host_tracer_level', 'python_tracer_level'])):
collections.namedtuple(
'ProfilerOptions',
['host_tracer_level', 'python_tracer_level', 'device_tracer_level'])):
"""Options to control profiler behaviors.
A `tf.profiler.ProfilerOptions` hold the knobs to control tf.profiler's
@ -60,8 +61,18 @@ class ProfilerOptions(
2 => info, 3 => verbose. [default to 2]
python_tracer_level: for enable python function call tracing, 1 => enable.
0 => disable [default to 0]
device_tracer_level: for adjust device (TPU/GPU) tracer level, 0 => disable
1 => enabled. We may introduce fine-tuned level in the
future. [default to 1]
"""
pass
def __new__(cls,
host_tracer_level=2,
python_tracer_level=0,
device_tracer_level=1):
return super(ProfilerOptions,
cls).__new__(cls, host_tracer_level, python_tracer_level,
device_tracer_level)
@tf_export('profiler.experimental.start', v1=[])

View File

@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.profiler.profiler_v2.ProfilerOptions\'>"
is_instance: "<class \'tensorflow.python.profiler.profiler_v2.ProfilerOptions\'>"
is_instance: "<type \'tuple\'>"
member {
name: "device_tracer_level"
mtype: "<type \'property\'>"
}
member {
name: "host_tracer_level"
mtype: "<type \'property\'>"

View File

@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "trace"
argspec: "args=[\'service_addr\', \'logdir\', \'duration_ms\', \'worker_list\', \'num_tracing_attempts\'], varargs=None, keywords=None, defaults=[\'\', \'3\'], "
argspec: "args=[\'service_addr\', \'logdir\', \'duration_ms\', \'worker_list\', \'num_tracing_attempts\', \'options\'], varargs=None, keywords=None, defaults=[\'\', \'3\', \'None\'], "
}
}