Add tf.profiler.experimental.Trace to the public API.

PiperOrigin-RevId: 300414493
Change-Id: I63e3b877faf57b17d177b682096d1340d0f456b4
This commit is contained in:
Jiho Choi 2020-03-11 15:01:46 -07:00 committed by TensorFlower Gardener
parent 5a7c6fb633
commit c820d8a11f
7 changed files with 102 additions and 33 deletions

View File

@ -207,6 +207,7 @@ py_library(
"//tensorflow/python/profiler",
"//tensorflow/python/profiler:profiler_client",
"//tensorflow/python/profiler:profiler_v2",
"//tensorflow/python/profiler:trace",
"//tensorflow/python/saved_model",
"//tensorflow/python/tools:module_util",
"//tensorflow/python/tools/api/generator:create_python_api",

View File

@ -103,6 +103,7 @@ from tensorflow.python.ops.signal import signal
from tensorflow.python.profiler import profiler
from tensorflow.python.profiler import profiler_client
from tensorflow.python.profiler import profiler_v2
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import saved_model
from tensorflow.python.summary import summary
from tensorflow.python.tpu import api

View File

@ -217,13 +217,25 @@ py_test(
],
)
py_library(
name = "trace",
srcs = ["trace.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:util",
"//tensorflow/python/profiler/internal:_pywrap_traceme",
"@six_archive//:six",
],
)
py_library(
name = "traceme",
srcs = ["traceme.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python/profiler/internal:_pywrap_traceme",
"@six_archive//:six",
":trace",
],
)

View File

@ -0,0 +1,71 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trace allows the profiler to trace Python events."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.profiler.internal import _pywrap_traceme
from tensorflow.python.util.tf_export import tf_export
@tf_export('profiler.experimental.Trace', v1=[])
class Trace(object):
"""Context manager that generates a trace event in the profiler.
A trace event will start when entering the context, and stop and save the
result to the profiler when exiting the context. Open TensorBoard Profile tab
and choose trace viewer to view the trace event in the timeline.
Trace events are created only when the profiler is enabled. More information
on how to use the profiler can be found at
https://tensorflow.org/guide/profiler
Example usage:
```python
tf.profiler.experimental.start('logdir')
for step in range(num_steps):
# Creates a trace event for each training step with the step number.
with tf.profiler.experimental.Trace("Train", step_num=step):
train_fn()
tf.profiler.experimental.stop()
```
"""
def __init__(self, name, **kwargs):
"""Creates a trace event in the profiler.
Args:
name: The name of the trace event.
**kwargs: Keyword arguments added to the trace event.
"""
if _pywrap_traceme.TraceMe.IsEnabled():
if kwargs:
name += '#' + ','.join(key + '=' + str(value)
for key, value in six.iteritems(kwargs)) + '#'
self._traceme = _pywrap_traceme.TraceMe(name)
else:
self._traceme = None
def __enter__(self):
if self._traceme:
self._traceme.Enter()
def __exit__(self, exc_type, exc_val, exc_tb):
if self._traceme:
self._traceme.Exit()

View File

@ -12,41 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TraceMe allows the profiler to trace python events.
Usage:
with profiler.TraceMe('name'):
...
"""
"""TraceMe allows the profiler to trace Python events."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.profiler.internal import _pywrap_traceme
class TraceMe(object):
"""Context manager that generates a trace event in the profiler."""
def __init__(self, name, **kwargs):
if _pywrap_traceme.TraceMe.IsEnabled():
if kwargs:
name += '#' + ','.join(key + '=' + str(value)
for key, value in six.iteritems(kwargs)) + '#'
self._traceme = _pywrap_traceme.TraceMe(name)
else:
self._traceme = None
def __enter__(self):
if self._traceme:
self._traceme.Enter()
def __exit__(self, exc_type, exc_val, exc_tb):
if self._traceme:
self._traceme.Exit()
from tensorflow.python.profiler.trace import Trace as TraceMe
def traceme_wrapper(func):
@ -58,4 +30,3 @@ def traceme_wrapper(func):
with TraceMe(name):
return func(*args, **kwargs)
return wrapper

View File

@ -0,0 +1,9 @@
path: "tensorflow.profiler.experimental.Trace"
tf_class {
is_instance: "<class \'tensorflow.python.profiler.trace.Trace\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=kwargs, defaults=None"
}
}

View File

@ -4,6 +4,10 @@ tf_module {
name: "Profile"
mtype: "<type \'type\'>"
}
member {
name: "Trace"
mtype: "<type \'type\'>"
}
member {
name: "client"
mtype: "<type \'module\'>"