add a context manager and function decorator for monitoring time.
PiperOrigin-RevId: 313215897 Change-Id: I42aff9a8b95079a3c8929d32d747e778eba3c6dd
This commit is contained in:
parent
956278ab3d
commit
0e80859784
|
@ -19,6 +19,8 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import time
|
||||
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.python import pywrap_tfe
|
||||
|
@ -428,3 +430,46 @@ class Sampler(Metric):
|
|||
def get_cell(self, *labels):
|
||||
"""Retrieves the cell."""
|
||||
return SamplerCell(super(Sampler, self).get_cell(*labels))
|
||||
|
||||
|
||||
class MonitoredTimer(object):
|
||||
"""A context manager to measure the walltime and increment a Counter cell."""
|
||||
|
||||
def __init__(self, cell):
|
||||
"""Creates a new MonitoredTimer.
|
||||
|
||||
Args:
|
||||
cell: the cell associated with the time metric that will be inremented.
|
||||
"""
|
||||
self.cell = cell
|
||||
|
||||
def __enter__(self):
|
||||
self.t = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
del exception_type, exception_value, traceback
|
||||
micro_seconds = (time.time() - self.t) * 1000000
|
||||
self.cell.increase_by(int(micro_seconds))
|
||||
|
||||
|
||||
def monitored_timer(cell):
|
||||
"""A function decorator for adding MonitoredTimer support.
|
||||
|
||||
Arguments:
|
||||
cell: the cell associated with the time metric that will be inremented.
|
||||
Returns:
|
||||
A decorator that measure the function runtime and increment the specified
|
||||
counter cell.
|
||||
"""
|
||||
|
||||
def actual_decorator(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with MonitoredTimer(cell):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return actual_decorator
|
||||
|
|
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import errors
|
||||
|
@ -100,6 +102,26 @@ class MonitoringTest(test_util.TensorFlowTestCase):
|
|||
self.assertEqual(histogram_proto1.num, 2.0)
|
||||
self.assertEqual(histogram_proto1.sum, 6.0)
|
||||
|
||||
def test_context_manager(self):
|
||||
counter = monitoring.Counter('test/ctxmgr', 'test context manager', 'slot')
|
||||
with monitoring.MonitoredTimer(counter.get_cell('short')):
|
||||
time.sleep(0.001)
|
||||
with monitoring.MonitoredTimer(counter.get_cell('long')):
|
||||
time.sleep(0.02)
|
||||
self.assertGreater(
|
||||
counter.get_cell('long').value(),
|
||||
counter.get_cell('short').value())
|
||||
|
||||
def test_function_decorator(self):
|
||||
counter = monitoring.Counter('test/funcdecorator', 'test func decorator')
|
||||
|
||||
@monitoring.monitored_timer(counter.get_cell())
|
||||
def timed_function(seconds):
|
||||
time.sleep(seconds)
|
||||
|
||||
timed_function(0.001)
|
||||
self.assertGreater(counter.get_cell().value(), 1000)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue