From 4ec6f2b07c08ddab479541cad0c61f169c1f816f Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Mon, 23 Oct 2017 09:59:21 -0700
Subject: [PATCH] Switching contrib.summaries API to be context-manager-centric

PiperOrigin-RevId: 173129793
---
 tensorflow/contrib/summary/summary_ops.py    | 33 +++++++-
 .../contrib/summary/summary_ops_test.py      | 79 ++++++++++---------
 tensorflow/python/eager/context.py           |  6 +-
 3 files changed, 77 insertions(+), 41 deletions(-)

diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index ba3619bfc90..30a9398ee54 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.layers import utils
 from tensorflow.python.ops import summary_op_util
 from tensorflow.python.training import training_util
+from tensorflow.python.util import tf_contextlib
 
 # Name for a collection which is expected to have at most a single boolean
 # Tensor. If this tensor is True the summary ops will record summaries.
@@ -46,22 +47,50 @@ def should_record_summaries():
 
 
 # TODO(apassos) consider how to handle local step here.
+@tf_contextlib.contextmanager
 def record_summaries_every_n_global_steps(n):
   """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [training_util.get_global_step() % n == 0]
+  yield
+  collection_ref[:] = old
 
 
+@tf_contextlib.contextmanager
 def always_record_summaries():
   """Sets the should_record_summaries Tensor to always true."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [True]
+  yield
+  collection_ref[:] = old
 
 
+@tf_contextlib.contextmanager
 def never_record_summaries():
   """Sets the should_record_summaries Tensor to always false."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [False]
+  yield
+  collection_ref[:] = old
+
+
+class SummaryWriter(object):
+
+  def __init__(self, resource):
+    self._resource = resource
+
+  def set_as_default(self):
+    context.context().summary_writer_resource = self._resource
+
+  @tf_contextlib.contextmanager
+  def as_default(self):
+    old = context.context().summary_writer_resource
+    context.context().summary_writer_resource = self._resource
+    yield
+    context.context().summary_writer_resource = old
 
 
 def create_summary_file_writer(logdir,
@@ -77,9 +106,11 @@ def create_summary_file_writer(logdir,
   if filename_suffix is None:
     filename_suffix = constant_op.constant("")
   resource = gen_summary_ops.summary_writer(shared_name=name)
+  # TODO(apassos) ensure the initialization op runs when in graph mode; consider
+  # calling session.run here.
   gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
                                              flush_secs, filename_suffix)
-  context.context().summary_writer_resource = resource
+  return SummaryWriter(resource)
 
 
 def _nothing():
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 2cd4fce5b39..405a92a7263 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -41,60 +41,65 @@ class TargetTest(test_util.TensorFlowTestCase):
 
   def testShouldRecordSummary(self):
     self.assertFalse(summary_ops.should_record_summaries())
-    summary_ops.always_record_summaries()
-    self.assertTrue(summary_ops.should_record_summaries())
+    with summary_ops.always_record_summaries():
+      self.assertTrue(summary_ops.should_record_summaries())
 
   def testSummaryOps(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
-    summary_ops.always_record_summaries()
-    summary_ops.generic('tensor', 1, '')
-    summary_ops.scalar('scalar', 2.0)
-    summary_ops.histogram('histogram', [1.0])
-    summary_ops.image('image', [[[[1.0]]]])
-    summary_ops.audio('audio', [[1.0]], 1.0, 1)
-    # The working condition of the ops is tested in the C++ test so we just
-    # test here that we're calling them correctly.
-    self.assertTrue(gfile.Exists(logdir))
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t0').as_default(), summary_ops.always_record_summaries():
+      summary_ops.generic('tensor', 1, '')
+      summary_ops.scalar('scalar', 2.0)
+      summary_ops.histogram('histogram', [1.0])
+      summary_ops.image('image', [[[[1.0]]]])
+      summary_ops.audio('audio', [[1.0]], 1.0, 1)
+      # The working condition of the ops is tested in the C++ test so we just
+      # test here that we're calling them correctly.
+      self.assertTrue(gfile.Exists(logdir))
 
   def testDefunSummarys(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1')
-    summary_ops.always_record_summaries()
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t1').as_default(), summary_ops.always_record_summaries():
 
-    @function.defun
-    def write():
-      summary_ops.scalar('scalar', 2.0)
+      @function.defun
+      def write():
+        summary_ops.scalar('scalar', 2.0)
 
-    write()
+      write()
 
-    self.assertTrue(gfile.Exists(logdir))
-    files = gfile.ListDirectory(logdir)
-    self.assertEqual(len(files), 1)
-    records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
-    self.assertEqual(len(records), 2)
-    event = event_pb2.Event()
-    event.ParseFromString(records[1])
-    self.assertEqual(event.summary.value[0].simple_value, 2.0)
+      self.assertTrue(gfile.Exists(logdir))
+      files = gfile.ListDirectory(logdir)
+      self.assertEqual(len(files), 1)
+      records = list(
+          tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+      self.assertEqual(len(records), 2)
+      event = event_pb2.Event()
+      event.ParseFromString(records[1])
+      self.assertEqual(event.summary.value[0].simple_value, 2.0)
 
   def testSummaryName(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2')
-    summary_ops.always_record_summaries()
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t2').as_default(), summary_ops.always_record_summaries():
 
-    summary_ops.scalar('scalar', 2.0)
+      summary_ops.scalar('scalar', 2.0)
 
-    self.assertTrue(gfile.Exists(logdir))
-    files = gfile.ListDirectory(logdir)
-    self.assertEqual(len(files), 1)
-    records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
-    self.assertEqual(len(records), 2)
-    event = event_pb2.Event()
-    event.ParseFromString(records[1])
-    self.assertEqual(event.summary.value[0].tag, 'scalar')
+      self.assertTrue(gfile.Exists(logdir))
+      files = gfile.ListDirectory(logdir)
+      self.assertEqual(len(files), 1)
+      records = list(
+          tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+      self.assertEqual(len(records), 2)
+      event = event_pb2.Event()
+      event.ParseFromString(records[1])
+      self.assertEqual(event.summary.value[0].tag, 'scalar')
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index c5eedb7c9cf..92f4e15c054 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -58,6 +58,7 @@ class _EagerContext(threading.local):
     self.mode = _default_mode
     self.scope_name = ""
     self.recording_summaries = False
+    self.summary_writer_resource = None
     self.scalar_cache = {}
 
 
@@ -86,7 +87,6 @@ class Context(object):
     self._eager_context = _EagerContext()
     self._context_handle = None
     self._context_devices = None
-    self._summary_writer_resource = None
     self._post_execution_callbacks = []
     self._config = config
     self._seed = None
@@ -213,12 +213,12 @@ class Context(object):
   @property
   def summary_writer_resource(self):
     """Returns summary writer resource."""
-    return self._summary_writer_resource
+    return self._eager_context.summary_writer_resource
 
   @summary_writer_resource.setter
   def summary_writer_resource(self, resource):
     """Sets summary writer resource."""
-    self._summary_writer_resource = resource
+    self._eager_context.summary_writer_resource = resource
 
   @property
   def device_name(self):
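
A minimal usage sketch of the new context-manager-centric API, mirroring what
the updated tests exercise. This is illustrative only: it assumes eager
execution is enabled via the tf.contrib.eager APIs of that era, and the
writer name and tag names ('example_writer', 'loss') are placeholders.

    import tempfile

    from tensorflow.contrib.summary import summary_ops
    from tensorflow.python.training import training_util

    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()

    # create_summary_file_writer now returns a SummaryWriter object instead of
    # installing the writer resource globally; scope it with as_default(), or
    # install it for the whole program with set_as_default().
    writer = summary_ops.create_summary_file_writer(
        logdir, max_queue=0, name='example_writer')

    with writer.as_default(), summary_ops.always_record_summaries():
      summary_ops.scalar('loss', 0.5)  # written: should_record is True here

    with writer.as_default(), summary_ops.never_record_summaries():
      summary_ops.scalar('loss', 0.5)  # skipped: should_record is False here

    # On exiting each block the previous writer resource and recording
    # condition are restored, because both context managers save the old
    # value, yield, and then put it back.

Scoping the writer and the should-record condition this way keeps both state
changes local and reversible. It also lines up with moving the writer resource
into the thread-local _EagerContext in context.py, so each thread can carry
its own default writer.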