add a context manager and function decorator for monitoring time.

PiperOrigin-RevId: 313215897
Change-Id: I42aff9a8b95079a3c8929d32d747e778eba3c6dd
This commit is contained in:
A. Unique TensorFlower 2020-05-26 10:14:32 -07:00 committed by TensorFlower Gardener
parent 956278ab3d
commit 0e80859784
2 changed files with 67 additions and 0 deletions

View File

@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
import collections
import functools
import time
from tensorflow.core.framework import summary_pb2
from tensorflow.python import pywrap_tfe
@ -428,3 +430,46 @@ class Sampler(Metric):
def get_cell(self, *labels):
"""Retrieves the cell."""
return SamplerCell(super(Sampler, self).get_cell(*labels))
class MonitoredTimer(object):
"""A context manager to measure the walltime and increment a Counter cell."""
def __init__(self, cell):
"""Creates a new MonitoredTimer.
Args:
cell: the cell associated with the time metric that will be inremented.
"""
self.cell = cell
def __enter__(self):
self.t = time.time()
return self
def __exit__(self, exception_type, exception_value, traceback):
del exception_type, exception_value, traceback
micro_seconds = (time.time() - self.t) * 1000000
self.cell.increase_by(int(micro_seconds))
def monitored_timer(cell):
"""A function decorator for adding MonitoredTimer support.
Arguments:
cell: the cell associated with the time metric that will be inremented.
Returns:
A decorator that measure the function runtime and increment the specified
counter cell.
"""
def actual_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with MonitoredTimer(cell):
return func(*args, **kwargs)
return wrapper
return actual_decorator

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from tensorflow.python.eager import monitoring
from tensorflow.python.eager import test
from tensorflow.python.framework import errors
@ -100,6 +102,26 @@ class MonitoringTest(test_util.TensorFlowTestCase):
self.assertEqual(histogram_proto1.num, 2.0)
self.assertEqual(histogram_proto1.sum, 6.0)
def test_context_manager(self):
counter = monitoring.Counter('test/ctxmgr', 'test context manager', 'slot')
with monitoring.MonitoredTimer(counter.get_cell('short')):
time.sleep(0.001)
with monitoring.MonitoredTimer(counter.get_cell('long')):
time.sleep(0.02)
self.assertGreater(
counter.get_cell('long').value(),
counter.get_cell('short').value())
def test_function_decorator(self):
counter = monitoring.Counter('test/funcdecorator', 'test func decorator')
@monitoring.monitored_timer(counter.get_cell())
def timed_function(seconds):
time.sleep(seconds)
timed_function(0.001)
self.assertGreater(counter.get_cell().value(), 1000)
if __name__ == '__main__':
test.main()