add a context manager and function decorator for monitoring time.
PiperOrigin-RevId: 313215897 Change-Id: I42aff9a8b95079a3c8929d32d747e778eba3c6dd
This commit is contained in:
parent
956278ab3d
commit
0e80859784
|
@ -19,6 +19,8 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import functools
|
||||||
|
import time
|
||||||
|
|
||||||
from tensorflow.core.framework import summary_pb2
|
from tensorflow.core.framework import summary_pb2
|
||||||
from tensorflow.python import pywrap_tfe
|
from tensorflow.python import pywrap_tfe
|
||||||
|
@ -428,3 +430,46 @@ class Sampler(Metric):
|
||||||
def get_cell(self, *labels):
|
def get_cell(self, *labels):
|
||||||
"""Retrieves the cell."""
|
"""Retrieves the cell."""
|
||||||
return SamplerCell(super(Sampler, self).get_cell(*labels))
|
return SamplerCell(super(Sampler, self).get_cell(*labels))
|
||||||
|
|
||||||
|
|
||||||
|
class MonitoredTimer(object):
|
||||||
|
"""A context manager to measure the walltime and increment a Counter cell."""
|
||||||
|
|
||||||
|
def __init__(self, cell):
|
||||||
|
"""Creates a new MonitoredTimer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: the cell associated with the time metric that will be inremented.
|
||||||
|
"""
|
||||||
|
self.cell = cell
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.t = time.time()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exception_type, exception_value, traceback):
|
||||||
|
del exception_type, exception_value, traceback
|
||||||
|
micro_seconds = (time.time() - self.t) * 1000000
|
||||||
|
self.cell.increase_by(int(micro_seconds))
|
||||||
|
|
||||||
|
|
||||||
|
def monitored_timer(cell):
|
||||||
|
"""A function decorator for adding MonitoredTimer support.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
cell: the cell associated with the time metric that will be inremented.
|
||||||
|
Returns:
|
||||||
|
A decorator that measure the function runtime and increment the specified
|
||||||
|
counter cell.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def actual_decorator(func):
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
with MonitoredTimer(cell):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return actual_decorator
|
||||||
|
|
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
from tensorflow.python.eager import monitoring
|
from tensorflow.python.eager import monitoring
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
@ -100,6 +102,26 @@ class MonitoringTest(test_util.TensorFlowTestCase):
|
||||||
self.assertEqual(histogram_proto1.num, 2.0)
|
self.assertEqual(histogram_proto1.num, 2.0)
|
||||||
self.assertEqual(histogram_proto1.sum, 6.0)
|
self.assertEqual(histogram_proto1.sum, 6.0)
|
||||||
|
|
||||||
|
def test_context_manager(self):
|
||||||
|
counter = monitoring.Counter('test/ctxmgr', 'test context manager', 'slot')
|
||||||
|
with monitoring.MonitoredTimer(counter.get_cell('short')):
|
||||||
|
time.sleep(0.001)
|
||||||
|
with monitoring.MonitoredTimer(counter.get_cell('long')):
|
||||||
|
time.sleep(0.02)
|
||||||
|
self.assertGreater(
|
||||||
|
counter.get_cell('long').value(),
|
||||||
|
counter.get_cell('short').value())
|
||||||
|
|
||||||
|
def test_function_decorator(self):
|
||||||
|
counter = monitoring.Counter('test/funcdecorator', 'test func decorator')
|
||||||
|
|
||||||
|
@monitoring.monitored_timer(counter.get_cell())
|
||||||
|
def timed_function(seconds):
|
||||||
|
time.sleep(seconds)
|
||||||
|
|
||||||
|
timed_function(0.001)
|
||||||
|
self.assertGreater(counter.get_cell().value(), 1000)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
Loading…
Reference in New Issue