diff --git a/tensorflow/python/eager/monitoring.py b/tensorflow/python/eager/monitoring.py index 26d4d8a55b3..74d98558192 100644 --- a/tensorflow/python/eager/monitoring.py +++ b/tensorflow/python/eager/monitoring.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import collections +import functools +import time from tensorflow.core.framework import summary_pb2 from tensorflow.python import pywrap_tfe @@ -428,3 +430,46 @@ class Sampler(Metric): def get_cell(self, *labels): """Retrieves the cell.""" return SamplerCell(super(Sampler, self).get_cell(*labels)) + + +class MonitoredTimer(object): + """A context manager to measure the walltime and increment a Counter cell.""" + + def __init__(self, cell): + """Creates a new MonitoredTimer. + + Args: + cell: the cell associated with the time metric that will be inremented. + """ + self.cell = cell + + def __enter__(self): + self.t = time.time() + return self + + def __exit__(self, exception_type, exception_value, traceback): + del exception_type, exception_value, traceback + micro_seconds = (time.time() - self.t) * 1000000 + self.cell.increase_by(int(micro_seconds)) + + +def monitored_timer(cell): + """A function decorator for adding MonitoredTimer support. + + Arguments: + cell: the cell associated with the time metric that will be inremented. + Returns: + A decorator that measure the function runtime and increment the specified + counter cell. + """ + + def actual_decorator(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with MonitoredTimer(cell): + return func(*args, **kwargs) + + return wrapper + + return actual_decorator diff --git a/tensorflow/python/eager/monitoring_test.py b/tensorflow/python/eager/monitoring_test.py index 3f601735ef2..7cb8c0c2cd1 100644 --- a/tensorflow/python/eager/monitoring_test.py +++ b/tensorflow/python/eager/monitoring_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from tensorflow.python.eager import monitoring from tensorflow.python.eager import test from tensorflow.python.framework import errors @@ -100,6 +102,26 @@ class MonitoringTest(test_util.TensorFlowTestCase): self.assertEqual(histogram_proto1.num, 2.0) self.assertEqual(histogram_proto1.sum, 6.0) + def test_context_manager(self): + counter = monitoring.Counter('test/ctxmgr', 'test context manager', 'slot') + with monitoring.MonitoredTimer(counter.get_cell('short')): + time.sleep(0.001) + with monitoring.MonitoredTimer(counter.get_cell('long')): + time.sleep(0.02) + self.assertGreater( + counter.get_cell('long').value(), + counter.get_cell('short').value()) + + def test_function_decorator(self): + counter = monitoring.Counter('test/funcdecorator', 'test func decorator') + + @monitoring.monitored_timer(counter.get_cell()) + def timed_function(seconds): + time.sleep(seconds) + + timed_function(0.001) + self.assertGreater(counter.get_cell().value(), 1000) + if __name__ == '__main__': test.main()