Add tf.profiler.experimental.Trace to the public API.
PiperOrigin-RevId: 300414493 Change-Id: I63e3b877faf57b17d177b682096d1340d0f456b4
This commit is contained in:
parent
5a7c6fb633
commit
c820d8a11f
tensorflow
python
tools/api/golden/v2
@ -207,6 +207,7 @@ py_library(
|
||||
"//tensorflow/python/profiler",
|
||||
"//tensorflow/python/profiler:profiler_client",
|
||||
"//tensorflow/python/profiler:profiler_v2",
|
||||
"//tensorflow/python/profiler:trace",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/tools:module_util",
|
||||
"//tensorflow/python/tools/api/generator:create_python_api",
|
||||
|
@ -103,6 +103,7 @@ from tensorflow.python.ops.signal import signal
|
||||
from tensorflow.python.profiler import profiler
|
||||
from tensorflow.python.profiler import profiler_client
|
||||
from tensorflow.python.profiler import profiler_v2
|
||||
from tensorflow.python.profiler import trace
|
||||
from tensorflow.python.saved_model import saved_model
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.tpu import api
|
||||
|
@ -217,13 +217,25 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "trace",
|
||||
srcs = ["trace.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/profiler/internal:_pywrap_traceme",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "traceme",
|
||||
srcs = ["traceme.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python/profiler/internal:_pywrap_traceme",
|
||||
"@six_archive//:six",
|
||||
":trace",
|
||||
],
|
||||
)
|
||||
|
||||
|
71
tensorflow/python/profiler/trace.py
Normal file
71
tensorflow/python/profiler/trace.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Trace allows the profiler to trace Python events."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.profiler.internal import _pywrap_traceme
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export('profiler.experimental.Trace', v1=[])
|
||||
class Trace(object):
|
||||
"""Context manager that generates a trace event in the profiler.
|
||||
|
||||
A trace event will start when entering the context, and stop and save the
|
||||
result to the profiler when exiting the context. Open TensorBoard Profile tab
|
||||
and choose trace viewer to view the trace event in the timeline.
|
||||
|
||||
Trace events are created only when the profiler is enabled. More information
|
||||
on how to use the profiler can be found at
|
||||
https://tensorflow.org/guide/profiler
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
tf.profiler.experimental.start('logdir')
|
||||
for step in range(num_steps):
|
||||
# Creates a trace event for each training step with the step number.
|
||||
with tf.profiler.experimental.Trace("Train", step_num=step):
|
||||
train_fn()
|
||||
tf.profiler.experimental.stop()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, name, **kwargs):
|
||||
"""Creates a trace event in the profiler.
|
||||
|
||||
Args:
|
||||
name: The name of the trace event.
|
||||
**kwargs: Keyword arguments added to the trace event.
|
||||
"""
|
||||
if _pywrap_traceme.TraceMe.IsEnabled():
|
||||
if kwargs:
|
||||
name += '#' + ','.join(key + '=' + str(value)
|
||||
for key, value in six.iteritems(kwargs)) + '#'
|
||||
self._traceme = _pywrap_traceme.TraceMe(name)
|
||||
else:
|
||||
self._traceme = None
|
||||
|
||||
def __enter__(self):
|
||||
if self._traceme:
|
||||
self._traceme.Enter()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._traceme:
|
||||
self._traceme.Exit()
|
@ -12,41 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TraceMe allows the profiler to trace python events.
|
||||
|
||||
Usage:
|
||||
with profiler.TraceMe('name'):
|
||||
...
|
||||
"""
|
||||
"""TraceMe allows the profiler to trace Python events."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.profiler.internal import _pywrap_traceme
|
||||
|
||||
|
||||
class TraceMe(object):
|
||||
"""Context manager that generates a trace event in the profiler."""
|
||||
|
||||
def __init__(self, name, **kwargs):
|
||||
if _pywrap_traceme.TraceMe.IsEnabled():
|
||||
if kwargs:
|
||||
name += '#' + ','.join(key + '=' + str(value)
|
||||
for key, value in six.iteritems(kwargs)) + '#'
|
||||
self._traceme = _pywrap_traceme.TraceMe(name)
|
||||
else:
|
||||
self._traceme = None
|
||||
|
||||
def __enter__(self):
|
||||
if self._traceme:
|
||||
self._traceme.Enter()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._traceme:
|
||||
self._traceme.Exit()
|
||||
from tensorflow.python.profiler.trace import Trace as TraceMe
|
||||
|
||||
|
||||
def traceme_wrapper(func):
|
||||
@ -58,4 +30,3 @@ def traceme_wrapper(func):
|
||||
with TraceMe(name):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
@ -0,0 +1,9 @@
|
||||
path: "tensorflow.profiler.experimental.Trace"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.profiler.trace.Trace\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
}
|
@ -4,6 +4,10 @@ tf_module {
|
||||
name: "Profile"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Trace"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "client"
|
||||
mtype: "<type \'module\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user